diff --git a/.gitattributes b/.gitattributes
index 018b81829131af87e6ecbbcceafc1fb0c1266a0c..f35660ca423c63db9376cac53ad08dc8afbddfee 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -44,3 +44,7 @@ phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_numeric.cpython-39.
 phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_multiarray.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_ufunc.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_regression.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll filter=lfs diff=lfs merge=lfs -text
diff --git a/phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll b/phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll
new file mode 100644
index 0000000000000000000000000000000000000000..0ef7ec3bb7cc4a77a9ec8b88bb84182bd02d52cb
--- /dev/null
+++ b/phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32e6c8100bd62e7a91f50996c2a59692dc796b6f140a2dfa4de313ca43d4c748
+size 618728
diff --git a/phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4ac6532accd31aae3631f77e7747fb9b7c6cd68
--- /dev/null
+++ b/phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bea25f2f50a79675487288df3227f81992a7a68a963be4383a462dd312764a8
+size 191484
diff --git a/phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15c0180497568f51f8587eb939b1529b76c86f36
--- /dev/null
+++ b/phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f099307b1342cc2b4342a99de83f3fcdf70eb441a647f614c7459c3517ff948
+size 130357
diff --git a/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ecdb8ab68b7343e6f5b699a2c6c132817f2f8fc
--- /dev/null
+++ b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73937362f04894b8f9fd2a39f391dfea1b26c210c8b609f418ff42d4372f5b4c
+size 163106
diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/INSTALLER b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/INSTALLER
new file mode 100644
index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68
--- /dev/null
+++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/METADATA b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..5883009e816c08a2672bc24096295aa8b3252a51
--- /dev/null
+++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/METADATA
@@ -0,0 +1,319 @@
+Metadata-Version: 2.4
+Name: sympy
+Version: 1.14.0
+Summary: Computer algebra system (CAS) in Python
+Home-page: https://sympy.org
+Author: SymPy development team
+Author-email: sympy@googlegroups.com
+License: BSD
+Project-URL: Source, https://github.com/sympy/sympy
+Keywords: Math CAS
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python
+Classifier: Topic :: Scientific/Engineering
+Classifier: Topic :: Scientific/Engineering :: Mathematics
+Classifier: Topic :: Scientific/Engineering :: Physics
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: AUTHORS
+Requires-Dist: mpmath<1.4,>=1.1.0
+Provides-Extra: dev
+Requires-Dist: pytest>=7.1.0; extra == "dev"
+Requires-Dist: hypothesis>=6.70.0; extra == "dev"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: license-file
+Dynamic: project-url
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# SymPy
+
+[![pypi version](https://img.shields.io/pypi/v/sympy.svg)](https://pypi.python.org/pypi/sympy)
+[![Join the chat at https://gitter.im/sympy/sympy](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sympy/sympy?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
+[![Zenodo Badge](https://zenodo.org/badge/18918/sympy/sympy.svg)](https://zenodo.org/badge/latestdoi/18918/sympy/sympy)
+[![Downloads](https://pepy.tech/badge/sympy/month)](https://pepy.tech/project/sympy)
+[![GitHub Issues](https://img.shields.io/badge/issue_tracking-github-blue.svg)](https://github.com/sympy/sympy/issues)
+[![Git Tutorial](https://img.shields.io/badge/PR-Welcome-%23FF8300.svg?)](https://git-scm.com/book/en/v2/GitHub-Contributing-to-a-Project)
+[![Powered by NumFocus](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org)
+[![Commits since last release](https://img.shields.io/github/commits-since/sympy/sympy/latest.svg?longCache=true&style=flat-square&logo=git&logoColor=fff)](https://github.com/sympy/sympy/releases)
+
+[![SymPy Banner](https://github.com/sympy/sympy/raw/master/banner.svg)](https://sympy.org/)
+
+
+See the [AUTHORS](AUTHORS) file for the list of authors.
+
+And many more people helped on the SymPy mailing list, reported bugs,
+helped organize SymPy's participation in the Google Summer of Code, the
+Google Highly Open Participation Contest, Google Code-In, wrote and
+blogged about SymPy...
+
+License: New BSD License (see the [LICENSE](LICENSE) file for details) covers all
+files in the sympy repository unless stated otherwise.
+
+Our mailing list is at
+.
+
+We have a community chat at [Gitter](https://gitter.im/sympy/sympy). Feel
+free to ask us anything there. We have a very welcoming and helpful
+community.
+
+## Download
+
+The recommended installation method is through Anaconda,
+
+
+You can also get the latest version of SymPy from
+
+
+To get the git version do
+
+    $ git clone https://github.com/sympy/sympy.git
+
+For other options (tarballs, debs, etc.), see
+.
+
+## Documentation and Usage
+
+For in-depth instructions on installation and building the
+documentation, see the [SymPy Documentation Style Guide](https://docs.sympy.org/dev/documentation-style-guide.html).
+
+Everything is at:
+
+
+
+You can generate everything at the above site in your local copy of
+SymPy by:
+
+    $ cd doc
+    $ make html
+
+Then the docs will be in \_build/html. If
+you don't want to read that, here is a short usage:
+
+From this directory, start Python and:
+
+``` python
+>>> from sympy import Symbol, cos
+>>> x = Symbol('x')
+>>> e = 1/cos(x)
+>>> print(e.series(x, 0, 10))
+1 + x**2/2 + 5*x**4/24 + 61*x**6/720 + 277*x**8/8064 + O(x**10)
+```
+
+SymPy also comes with a console that is a simple wrapper around the
+classic python console (or IPython when available) that loads the SymPy
+namespace and executes some common commands for you.
+
+To start it, issue:
+
+    $ bin/isympy
+
+from this directory, if SymPy is not installed or simply:
+
+    $ isympy
+
+if SymPy is installed.
+
+## Installation
+
+To install SymPy using PyPI, run the following command:
+
+    $ pip install sympy
+
+To install SymPy using Anaconda, run the following command:
+
+    $ conda install -c anaconda sympy
+
+To install SymPy from GitHub source, first clone SymPy using `git`:
+
+    $ git clone https://github.com/sympy/sympy.git
+
+Then, in the `sympy` repository that you cloned, simply run:
+
+    $ pip install .
+
+See for more information.
+
+## Contributing
+
+We welcome contributions from anyone, even if you are new to open
+source. Please read our [Introduction to Contributing](https://docs.sympy.org/dev/contributing/introduction-to-contributing.html)
+page and the [SymPy Documentation Style Guide](https://docs.sympy.org/dev/documentation-style-guide.html). If you
+are new and looking for some way to contribute, a good place to start is
+to look at the issues tagged [Easy to Fix](https://github.com/sympy/sympy/issues?q=is%3Aopen+is%3Aissue+label%3A%22Easy+to+Fix%22).
+
+Please note that all participants in this project are expected to follow
+our Code of Conduct. By participating in this project you agree to abide
+by its terms. See [CODE\_OF\_CONDUCT.md](CODE_OF_CONDUCT.md).
+
+## Tests
+
+To execute all tests, run:
+
+    $./setup.py test
+
+in the current directory.
+
+For the more fine-grained running of tests or doctests, use `bin/test`
+or respectively `bin/doctest`. The master branch is automatically tested
+by GitHub Actions.
+
+To test pull requests, use
+[sympy-bot](https://github.com/sympy/sympy-bot).
+
+## Regenerate Experimental LaTeX Parser/Lexer
+
+The parser and lexer were generated with the [ANTLR4](http://antlr4.org)
+toolchain in `sympy/parsing/latex/_antlr` and checked into the repo.
+Presently, most users should not need to regenerate these files, but
+if you plan to work on this feature, you will need the `antlr4`
+command-line tool (and you must ensure that it is in your `PATH`).
+One way to get it is:
+
+    $ conda install -c conda-forge antlr=4.11.1
+
+Alternatively, follow the instructions on the ANTLR website and download
+the `antlr-4.11.1-complete.jar`. Then export the `CLASSPATH` as instructed
+and instead of creating `antlr4` as an alias, make it an executable file
+with the following contents:
+``` bash
+#!/bin/bash
+java -jar /usr/local/lib/antlr-4.11.1-complete.jar "$@"
+```
+
+After making changes to `sympy/parsing/latex/LaTeX.g4`, run:
+
+    $ ./setup.py antlr
+
+## Clean
+
+To clean everything (thus getting the same tree as in the repository):
+
+    $ git clean -Xdf
+
+which will clear everything ignored by `.gitignore`, and:
+
+    $ git clean -df
+
+to clear all untracked files. You can revert the most recent changes in
+git with:
+
+    $ git reset --hard
+
+WARNING: The above commands will all clear changes you may have made,
+and you will lose them forever. Be sure to check things with `git
+status`, `git diff`, `git clean -Xn`, and `git clean -n` before doing any
+of those.
+
+## Bugs
+
+Our issue tracker is at . Please
+report any bugs that you find. Or, even better, fork the repository on
+GitHub and create a pull request. We welcome all changes, big or small,
+and we will help you make the pull request if you are new to git (just
+ask on our mailing list or Gitter Channel). If you further have any queries, you can find answers
+on Stack Overflow using the [sympy](https://stackoverflow.com/questions/tagged/sympy) tag.
+
+## Brief History
+
+SymPy was started by Ondřej Čertík in 2005, he wrote some code during
+the summer, then he wrote some more code during summer 2006. In February
+2007, Fabian Pedregosa joined the project and helped fix many things,
+contributed documentation, and made it alive again. 5 students (Mateusz
+Paprocki, Brian Jorgensen, Jason Gedge, Robert Schwarz, and Chris Wu)
+improved SymPy incredibly during summer 2007 as part of the Google
+Summer of Code. Pearu Peterson joined the development during the summer
+2007 and he has made SymPy much more competitive by rewriting the core
+from scratch, which has made it from 10x to 100x faster. Jurjen N.E. Bos
+has contributed pretty-printing and other patches. Fredrik Johansson has
+written mpmath and contributed a lot of patches.
+
+SymPy has participated in every Google Summer of Code since 2007. You
+can see for
+full details. Each year has improved SymPy by bounds. Most of SymPy's
+development has come from Google Summer of Code students.
+
+In 2011, Ondřej Čertík stepped down as lead developer, with Aaron
+Meurer, who also started as a Google Summer of Code student, taking his
+place. Ondřej Čertík is still active in the community but is too busy
+with work and family to play a lead development role.
+
+Since then, a lot more people have joined the development and some
+people have also left. You can see the full list in doc/src/aboutus.rst,
+or online at:
+
+
+
+The git history goes back to 2007 when development moved from svn to hg.
+To see the history before that point, look at
+.
+
+You can use git to see the biggest developers. The command:
+
+    $ git shortlog -ns
+
+will show each developer, sorted by commits to the project. The command:
+
+    $ git shortlog -ns --since="1 year"
+
+will show the top developers from the last year.
+
+## Citation
+
+To cite SymPy in publications use
+
+> Meurer A, Smith CP, Paprocki M, Čertík O, Kirpichev SB, Rocklin M,
+> Kumar A, Ivanov S, Moore JK, Singh S, Rathnayake T, Vig S, Granger BE,
+> Muller RP, Bonazzi F, Gupta H, Vats S, Johansson F, Pedregosa F, Curry
+> MJ, Terrel AR, Roučka Š, Saboo A, Fernando I, Kulal S, Cimrman R,
+> Scopatz A. (2017) SymPy: symbolic computing in Python. *PeerJ Computer
+> Science* 3:e103
+
+A BibTeX entry for LaTeX users is
+
+``` bibtex
+@article{10.7717/peerj-cs.103,
+ title = {SymPy: symbolic computing in Python},
+ author = {Meurer, Aaron and Smith, Christopher P. and Paprocki, Mateusz and \v{C}ert\'{i}k, Ond\v{r}ej and Kirpichev, Sergey B. and Rocklin, Matthew and Kumar, Amit and Ivanov, Sergiu and Moore, Jason K. and Singh, Sartaj and Rathnayake, Thilina and Vig, Sean and Granger, Brian E. and Muller, Richard P. and Bonazzi, Francesco and Gupta, Harsh and Vats, Shivam and Johansson, Fredrik and Pedregosa, Fabian and Curry, Matthew J. and Terrel, Andy R. and Rou\v{c}ka, \v{S}t\v{e}p\'{a}n and Saboo, Ashutosh and Fernando, Isuru and Kulal, Sumith and Cimrman, Robert and Scopatz, Anthony},
+ year = 2017,
+ month = Jan,
+ keywords = {Python, Computer algebra system, Symbolics},
+ abstract = {
+   SymPy is an open-source computer algebra system written in pure Python. It is built with a focus on extensibility and ease of use, through both interactive and programmatic applications. These characteristics have led SymPy to become a popular symbolic library for the scientific Python ecosystem. This paper presents the architecture of SymPy, a description of its features, and a discussion of select submodules. The supplementary material provides additional examples and further outlines details of the architecture and features of SymPy.
+ },
+ volume = 3,
+ pages = {e103},
+ journal = {PeerJ Computer Science},
+ issn = {2376-5992},
+ url = {https://doi.org/10.7717/peerj-cs.103},
+ doi = {10.7717/peerj-cs.103}
+}
+```
+
+SymPy is BSD licensed, so you are free to use it whatever you like, be
+it academic, commercial, creating forks or derivatives, as long as you
+copy the BSD statement if you redistribute it (see the LICENSE file for
+details). That said, although not required by the SymPy license, if it
+is convenient for you, please cite SymPy when using it in your work and
+also consider contributing all your changes back, so that we can
+incorporate it and all of us will benefit in the end.
diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/RECORD b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..d3303d783ac67872c99d7a3d785411b2a0210ea9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/RECORD @@ -0,0 +1,3105 @@ +../../Scripts/isympy.exe,sha256=uwTqJwLvnr4ODgeJfROIfwoRbXWNZ89g6AIaLT48lwU,106336 +../../share/man/man1/isympy.1,sha256=9DZdSOIQLikrATHlbkdDZ04LBQigZDUE0_oCXBDvdBs,6659 +__pycache__/isympy.cpython-39.pyc,, +isympy.py,sha256=haEArczb_Ny-LyjOCwx2FNGal_S0Mo4d0deiWL2g_Gc,11220 +sympy-1.14.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +sympy-1.14.0.dist-info/METADATA,sha256=t1bC-_1b4FrFvbDrymH1VhjzD2M-2S2IpWh0KTE6dZU,12816 +sympy-1.14.0.dist-info/RECORD,, +sympy-1.14.0.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91 +sympy-1.14.0.dist-info/entry_points.txt,sha256=Sp-vLJom4PRlhGfY6RpUre7SjYm33JNq9NCwCGeW-fQ,39 +sympy-1.14.0.dist-info/licenses/AUTHORS,sha256=grEDhSGwaWxjyQsx7UoKLJxMFoccZIsVewfKRLbxj40,55779 +sympy-1.14.0.dist-info/licenses/LICENSE,sha256=B6XpgZ9ye0mGrSgpx6KaYyDUJXX3IOsk1xt_71c6AoY,7885 +sympy-1.14.0.dist-info/top_level.txt,sha256=elXb5xfjLdjgSSoQFk4_2Qu3lp2CIaglF9MQtfIoH7o,13 +sympy/__init__.py,sha256=TpR2NIuhBf6rKNgvW89s26Lj6E3m4Fm7_noTcowKSrA,29422 +sympy/__pycache__/__init__.cpython-39.pyc,, +sympy/__pycache__/abc.cpython-39.pyc,, +sympy/__pycache__/conftest.cpython-39.pyc,, +sympy/__pycache__/galgebra.cpython-39.pyc,, +sympy/__pycache__/release.cpython-39.pyc,, +sympy/__pycache__/this.cpython-39.pyc,, +sympy/abc.py,sha256=-RYLRKQ6SZMJGXDIAOP9UWPDAyvYH0zPbp9ZFg0YKBY,3764 +sympy/algebras/__init__.py,sha256=7PRGOW30nlMOTeUPR7iy8l5xGoE2yCBEfRbjqDKWOgU,62 +sympy/algebras/__pycache__/__init__.cpython-39.pyc,, +sympy/algebras/__pycache__/quaternion.cpython-39.pyc,, +sympy/algebras/quaternion.py,sha256=i2Fkvjj14vA6GtXgcSkpDwVmIYc7T5zT1hw49sb__tM,47483 +sympy/algebras/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/algebras/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/algebras/tests/__pycache__/test_quaternion.cpython-39.pyc,, +sympy/algebras/tests/test_quaternion.py,sha256=un4I6IKE2v6YEfUsFBSY5N6qaZ0agn1puX0LElb1k4E,16809 +sympy/assumptions/__init__.py,sha256=PFS8djTqiNbGVMjg7PaPjEfwmjyZVfioXiRVzqqA3E0,550 +sympy/assumptions/__pycache__/__init__.cpython-39.pyc,, +sympy/assumptions/__pycache__/ask.cpython-39.pyc,, +sympy/assumptions/__pycache__/ask_generated.cpython-39.pyc,, +sympy/assumptions/__pycache__/assume.cpython-39.pyc,, +sympy/assumptions/__pycache__/cnf.cpython-39.pyc,, +sympy/assumptions/__pycache__/facts.cpython-39.pyc,, +sympy/assumptions/__pycache__/lra_satask.cpython-39.pyc,, +sympy/assumptions/__pycache__/refine.cpython-39.pyc,, +sympy/assumptions/__pycache__/satask.cpython-39.pyc,, +sympy/assumptions/__pycache__/sathandlers.cpython-39.pyc,, +sympy/assumptions/__pycache__/wrapper.cpython-39.pyc,, +sympy/assumptions/ask.py,sha256=B2W8CW-9ynvYqfBX9_4qCLhCSwE0g_gUGqq7lM50wUk,19376 +sympy/assumptions/ask_generated.py,sha256=whNIU5tj2UkEGHPAfk-_89ahvVERs4npOB0JYVbIQJc,23558 +sympy/assumptions/assume.py,sha256=_gcFc4h_YGs9-tshoD0gmLl_RtPivDQWMWhWWLX9seo,14606 +sympy/assumptions/cnf.py,sha256=mfthFXL8Jhpn7Vgj1IfG5acFzfBgaV02xztQ7KpoLmg,12509 +sympy/assumptions/facts.py,sha256=fimPoHEyusSUr0uI4kiDb4mzxHjEBglvLQW0DGdNFAs,8391 +sympy/assumptions/handlers/__init__.py,sha256=lvjAfPdz0MDjTxjuzbBSGBco2OmpZRiGixSG0oaiZi0,330 
+sympy/assumptions/handlers/__pycache__/__init__.cpython-39.pyc,, +sympy/assumptions/handlers/__pycache__/calculus.cpython-39.pyc,, +sympy/assumptions/handlers/__pycache__/common.cpython-39.pyc,, +sympy/assumptions/handlers/__pycache__/matrices.cpython-39.pyc,, +sympy/assumptions/handlers/__pycache__/ntheory.cpython-39.pyc,, +sympy/assumptions/handlers/__pycache__/order.cpython-39.pyc,, +sympy/assumptions/handlers/__pycache__/sets.cpython-39.pyc,, +sympy/assumptions/handlers/calculus.py,sha256=WF3gysxZa2frvd1kHxdDN7xd0PwYemER8dWutfDN-w8,7799 +sympy/assumptions/handlers/common.py,sha256=xgUR3xHPNgGFkLGyMqLV_nUIfMsQJLP8p6boy8Rc6JI,4314 +sympy/assumptions/handlers/matrices.py,sha256=Gdauk2xk1hKPRr4i6RpvOMHtDnyVD34x1OyhL-Oh8Hc,22321 +sympy/assumptions/handlers/ntheory.py,sha256=lRhW2ceJPK0gNqR00NAwkEq1FBjtRrxz1u6WevjRvjc,7660 +sympy/assumptions/handlers/order.py,sha256=-NCLn8ab7Fo_UzZLZi7m3UwtYGN-XiU1qvJDhi5z7ZQ,12280 +sympy/assumptions/handlers/sets.py,sha256=9P8X6ucARJIJA3jG_gj736L71ZQi5hfLCKL--0_ifrw,25841 +sympy/assumptions/lra_satask.py,sha256=FlmiLERsj6J9w6vygwEEEn7pyGPnD0JkPEFEdoE7bfM,9563 +sympy/assumptions/predicates/__init__.py,sha256=q1C7iWpvdDymEUZNyzJvZLsLtgwSkYtCixME-fYyIDw,110 +sympy/assumptions/predicates/__pycache__/__init__.cpython-39.pyc,, +sympy/assumptions/predicates/__pycache__/calculus.cpython-39.pyc,, +sympy/assumptions/predicates/__pycache__/common.cpython-39.pyc,, +sympy/assumptions/predicates/__pycache__/matrices.cpython-39.pyc,, +sympy/assumptions/predicates/__pycache__/ntheory.cpython-39.pyc,, +sympy/assumptions/predicates/__pycache__/order.cpython-39.pyc,, +sympy/assumptions/predicates/__pycache__/sets.cpython-39.pyc,, +sympy/assumptions/predicates/calculus.py,sha256=vFnlYVYZVd6D9OwA7-3bDK_Q0jf2iCZCZiMlWenw0Vg,1889 +sympy/assumptions/predicates/common.py,sha256=zpByACpa_tF0nVNB0J_rJehnXkHtkxhchn1DvkVVS-s,2279 +sympy/assumptions/predicates/matrices.py,sha256=X3vbkEf3zwJLyanEjf6ijYXuRfFfSv-yatl1tJ25wDk,12142 +sympy/assumptions/predicates/ntheory.py,sha256=wvFNFSf0S4egbY7REw0V0ANC03CuiRU9PLmdi16VfHo,2546 +sympy/assumptions/predicates/order.py,sha256=ez1UZ824KDtimLssUASCZHD_KEQmo8Pv-qofVLhZUrk,9511 +sympy/assumptions/predicates/sets.py,sha256=-bTVXa-X1-yfXlIKzMBW_JxIqueS5PdEwEzChzIne38,9238 +sympy/assumptions/refine.py,sha256=FNv5neAYJh-MgvnHDZ8-tjC9RIKcIxarVj5WiE4EnYg,11946 +sympy/assumptions/relation/__init__.py,sha256=t2tZNEIK7w-xXshRQIRL8tIyiNe1W5fMhN7QNRPnQFo,261 +sympy/assumptions/relation/__pycache__/__init__.cpython-39.pyc,, +sympy/assumptions/relation/__pycache__/binrel.cpython-39.pyc,, +sympy/assumptions/relation/__pycache__/equality.cpython-39.pyc,, +sympy/assumptions/relation/binrel.py,sha256=3iwnSEE53-vRsPv-bOnjydgOkCpbB12FTFR_sQ3CwvE,6313 +sympy/assumptions/relation/equality.py,sha256=ZBnSFpctNeroYy1nvam0kzSJCNpsUR3SlBeqFcAMM0U,7148 +sympy/assumptions/satask.py,sha256=P3iprPjuOyhT5Fwr0hX61xTOcD98M_bzXSAV-pXYhN4,11745 +sympy/assumptions/sathandlers.py,sha256=jx8B0u_N73fMoVoLKIfmXMdtSLz7-ZIKhJrxYl84AJk,9418 +sympy/assumptions/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/assumptions/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_assumptions_2.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_context.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_matrices.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_query.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_refine.cpython-39.pyc,, 
+sympy/assumptions/tests/__pycache__/test_rel_queries.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_satask.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_sathandlers.cpython-39.pyc,, +sympy/assumptions/tests/__pycache__/test_wrapper.cpython-39.pyc,, +sympy/assumptions/tests/test_assumptions_2.py,sha256=oNgIDOoW-GpBbXxbtw05SWnE8I7sGislYmB3MDogwB4,1070 +sympy/assumptions/tests/test_context.py,sha256=I5gES7AY9_vz1-CEaCchy4MXABtX85ncNkvoRuLskG8,1153 +sympy/assumptions/tests/test_matrices.py,sha256=nzSofuawc18hNe9Nj0dN_lTeDwa2KbPjt4K2rvb3xmw,12258 +sympy/assumptions/tests/test_query.py,sha256=ENB3XoNNi7j0p1_-j9Xd9C5YYpRbu1dcAQe46foV8cU,104495 +sympy/assumptions/tests/test_refine.py,sha256=bHxYUnCOEIzA1yPU3B2xbU9JZfhDv6RkmPm8esetisQ,8834 +sympy/assumptions/tests/test_rel_queries.py,sha256=fRn_IMMb942GfEH8CNk6z50YCyDlA_GMXv2xOVFdSdk,6677 +sympy/assumptions/tests/test_satask.py,sha256=IIqqIxzkLfANpTNBKEsCGCp3Bm8zmDnYd23woqKh9EE,15741 +sympy/assumptions/tests/test_sathandlers.py,sha256=jMCZQb3G6pVQ5MHaSTWV_0eULHaCF8Mowu12Ll72rgs,1842 +sympy/assumptions/tests/test_wrapper.py,sha256=iE32j83rrerCz85HHt2hTolgJkqb44KddfEpI3H1Fb8,1159 +sympy/assumptions/wrapper.py,sha256=7GXR39zPCCfV-pcs8ph9KRRwZF3i_T5Lzv156vKFf_I,5434 +sympy/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/benchmarks/__pycache__/bench_discrete_log.cpython-39.pyc,, +sympy/benchmarks/__pycache__/bench_meijerint.cpython-39.pyc,, +sympy/benchmarks/__pycache__/bench_symbench.cpython-39.pyc,, +sympy/benchmarks/bench_discrete_log.py,sha256=CNchIJ5HFMPpNlVZh2vOU0GgQ3bse6hqyqDovpDHlKE,2473 +sympy/benchmarks/bench_meijerint.py,sha256=dSNdZhoc8a4h50wRtbOxLwpmgUiuMFpe6ytTLURcplY,11610 +sympy/benchmarks/bench_symbench.py,sha256=UMD3eYf_Poht0qxjdH2_axGwwON6cZo1Sp700Ci1M1M,2997 +sympy/calculus/__init__.py,sha256=IWDc6qPbEcWyTm9QM6V8vSAs-5OtGNijimykoWz3Clc,828 +sympy/calculus/__pycache__/__init__.cpython-39.pyc,, +sympy/calculus/__pycache__/accumulationbounds.cpython-39.pyc,, +sympy/calculus/__pycache__/euler.cpython-39.pyc,, +sympy/calculus/__pycache__/finite_diff.cpython-39.pyc,, +sympy/calculus/__pycache__/singularities.cpython-39.pyc,, +sympy/calculus/__pycache__/util.cpython-39.pyc,, +sympy/calculus/accumulationbounds.py,sha256=4nJD5JzBarlB5l6yPpVADX3iDamnHVwsor_nQJw9CAM,28682 +sympy/calculus/euler.py,sha256=0QrHD9TYKlSZuO8drnU3bUFJrSu8v5SncqtkRSWLjGM,3436 +sympy/calculus/finite_diff.py,sha256=X7qZJ5GmHlHKokUUMFoaQqrqX2jLRq4b7W2G5aWntzM,17053 +sympy/calculus/singularities.py,sha256=wBQ7WiJ1amuZStBJ-iMTiIHJexjzJHHwrc0tU2XVT10,12184 +sympy/calculus/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/calculus/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/calculus/tests/__pycache__/test_accumulationbounds.cpython-39.pyc,, +sympy/calculus/tests/__pycache__/test_euler.cpython-39.pyc,, +sympy/calculus/tests/__pycache__/test_finite_diff.cpython-39.pyc,, +sympy/calculus/tests/__pycache__/test_singularities.cpython-39.pyc,, +sympy/calculus/tests/__pycache__/test_util.cpython-39.pyc,, +sympy/calculus/tests/test_accumulationbounds.py,sha256=a_Ry2nKX5WbhSe1Bk2k0W6-VWOpVTg0FnA9u8rNSIV4,11195 +sympy/calculus/tests/test_euler.py,sha256=YWpts4pWSiYEwRsi5DLQ16JgC9109-9NKZIL_IO6_Aw,2683 +sympy/calculus/tests/test_finite_diff.py,sha256=9mIvXB8DaxJh2gzmhwi5eE3eBRYTFxyqYG5uIUpYg2w,7670 +sympy/calculus/tests/test_singularities.py,sha256=Zj4WPkT-KlXo7TF3Ir3ug3IkS_qhnFBS7VXAMnIuCso,5272 
+sympy/calculus/tests/test_util.py,sha256=IJIhgudR9dym1VRAdu33G2tfSDoexNdWdz-Pgf-kh4o,18557 +sympy/calculus/util.py,sha256=qBDTlb5VP-xJPfh2Ndz37zNMVpseXW4d77hxBCLX5AA,28880 +sympy/categories/__init__.py,sha256=XiKBVC6pbDED-OVtNlSH-fGB8dB_jWLqwCEO7wBTAyA,984 +sympy/categories/__pycache__/__init__.cpython-39.pyc,, +sympy/categories/__pycache__/baseclasses.cpython-39.pyc,, +sympy/categories/__pycache__/diagram_drawing.cpython-39.pyc,, +sympy/categories/baseclasses.py,sha256=1Kn7PIegQCbF78s0rhf1Bx1mbxwfQcfQi6v-QqloSoE,31360 +sympy/categories/diagram_drawing.py,sha256=lfR29f6QX-rK-cCAom-2mRJ4xLyhi55ch91CyEPCXAQ,95173 +sympy/categories/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/categories/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/categories/tests/__pycache__/test_baseclasses.cpython-39.pyc,, +sympy/categories/tests/__pycache__/test_drawing.cpython-39.pyc,, +sympy/categories/tests/test_baseclasses.py,sha256=SwD6QsfSlrEdpD2dbkcN62CPVIRP5SadjCplLrMAoa8,5767 +sympy/categories/tests/test_drawing.py,sha256=IELPpadmnQyQ2x5a5qHC8ioq5kfT1UnAl4h1vO3gbqg,27848 +sympy/codegen/__init__.py,sha256=sQcJsyLyoRh9ccOPhv2eZ-wHjQrArByOON9ndj-MYgQ,974 +sympy/codegen/__pycache__/__init__.cpython-39.pyc,, +sympy/codegen/__pycache__/abstract_nodes.cpython-39.pyc,, +sympy/codegen/__pycache__/algorithms.cpython-39.pyc,, +sympy/codegen/__pycache__/approximations.cpython-39.pyc,, +sympy/codegen/__pycache__/ast.cpython-39.pyc,, +sympy/codegen/__pycache__/cfunctions.cpython-39.pyc,, +sympy/codegen/__pycache__/cnodes.cpython-39.pyc,, +sympy/codegen/__pycache__/cutils.cpython-39.pyc,, +sympy/codegen/__pycache__/cxxnodes.cpython-39.pyc,, +sympy/codegen/__pycache__/fnodes.cpython-39.pyc,, +sympy/codegen/__pycache__/futils.cpython-39.pyc,, +sympy/codegen/__pycache__/matrix_nodes.cpython-39.pyc,, +sympy/codegen/__pycache__/numpy_nodes.cpython-39.pyc,, +sympy/codegen/__pycache__/pynodes.cpython-39.pyc,, +sympy/codegen/__pycache__/pyutils.cpython-39.pyc,, +sympy/codegen/__pycache__/rewriting.cpython-39.pyc,, +sympy/codegen/__pycache__/scipy_nodes.cpython-39.pyc,, +sympy/codegen/abstract_nodes.py,sha256=TY4ecftqnym5viYInnb59zGPPFXdeSGQwi--xTz6Pvo,490 +sympy/codegen/algorithms.py,sha256=GEvnadOTZ3afVrbN5WE52OhCP8gzMII_AN4oHQxvEpo,6383 +sympy/codegen/approximations.py,sha256=r7zyShIWp40QngshFQwqKVKWcA5KmHmQcJarXVtOATQ,6448 +sympy/codegen/ast.py,sha256=R_JhC-uryx9NoJ1ISQU6fb3aV4aRL1SModmV_Dc2z2U,56788 +sympy/codegen/cfunctions.py,sha256=jrFWaRjrQnpHzu6-mwFtd6cRspK8Jo7B3RDo1AOoTLo,12331 +sympy/codegen/cnodes.py,sha256=lsqy-JeRvr9WCk2fwDiRqPhAMFk0snInF7WlAlk9-Zg,3409 +sympy/codegen/cutils.py,sha256=vlzMs8OkC5Bu4sIP-AF2mYf_tIo7Uo4r2DAI_LNhZzM,383 +sympy/codegen/cxxnodes.py,sha256=Om-EBfYduFF97tgXOF68rr8zYbngem9kBRm9SJiKLSM,342 +sympy/codegen/fnodes.py,sha256=mdOPkMVmjjNp22XfHscvx8uMtUxbUqTl66Wki7ORMa8,18936 +sympy/codegen/futils.py,sha256=k-mxMJKr_Q_afTy6NrKNl_N2XQLBmSdZAssO5hBonNY,1792 +sympy/codegen/matrix_nodes.py,sha256=0d3qXy2zaq3isyklE48lP7NP5LTF7SkLXMHMbweVGXU,2284 +sympy/codegen/numpy_nodes.py,sha256=H6FoBnarEuO7R1sbX8Qa9hK4jCju90R2qN66_-1N6Ug,4543 +sympy/codegen/pynodes.py,sha256=Neo1gFQ9kC31T-gH8TeeCaDDNaDe5deIP97MRZFgMHk,243 +sympy/codegen/pyutils.py,sha256=HfF6SP710Y7yExZcSesI0usVaDiWdEPEmMtyMD3JtOY,838 +sympy/codegen/rewriting.py,sha256=8JtiMFgv0sA7uGu2MUU7L3uzldGw5xG1ksuk4zh2ZDE,11585 +sympy/codegen/scipy_nodes.py,sha256=hYlxtGyTM0Z64Nazm1TeMZ3Y8dMsiD_HNhNvbU9eiQY,2508 +sympy/codegen/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+sympy/codegen/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_abstract_nodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_algorithms.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_applications.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_approximations.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_ast.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_cfunctions.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_cnodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_cxxnodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_fnodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_matrix_nodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_numpy_nodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_pynodes.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_pyutils.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_rewriting.cpython-39.pyc,, +sympy/codegen/tests/__pycache__/test_scipy_nodes.cpython-39.pyc,, +sympy/codegen/tests/test_abstract_nodes.py,sha256=a_GKf3FpeNN8zfMc-V8AaSrQtEI1oiLfJOco2VKiSKI,451 +sympy/codegen/tests/test_algorithms.py,sha256=lgl6w3d7WaT8LJrp716wWicxbgnR4hQk4alzBCdDrIw,6919 +sympy/codegen/tests/test_applications.py,sha256=RFWS7EoE3DytUrr4HzkPaXbOtnaStpzJalVKrY1rVRE,2277 +sympy/codegen/tests/test_approximations.py,sha256=SZpOUzahb_bJOceD0DLdmeiw-jN37OPmf5TRp1dyRgM,2035 +sympy/codegen/tests/test_ast.py,sha256=aAWk-yAVVNAmFMkyUlYBbVA8mPlTFqULOtmXMEi3LO8,21688 +sympy/codegen/tests/test_cfunctions.py,sha256=4jJM7dCLPbfEU3CB1ZUKGC1pRuRvjqZfDkFfYSkrzsY,5119 +sympy/codegen/tests/test_cnodes.py,sha256=FlI5XP39K3kC1QWKQ-QKkzNQw8TROjj5mKXJhK1UU2c,3039 +sympy/codegen/tests/test_cxxnodes.py,sha256=5OwN8D_ZtKN9z5uNeUwbUkyAGzNLrTgIKUlcRWmOSpE,366 +sympy/codegen/tests/test_fnodes.py,sha256=dKSw6w7i4Z8NHVDu_ug4V9Fa49CfnSX1FVahDIdNPgI,6650 +sympy/codegen/tests/test_matrix_nodes.py,sha256=CTMSwQkW0JHaMgHR9Lys8kDk5UgQidTl4VQhWI8gw7s,1896 +sympy/codegen/tests/test_numpy_nodes.py,sha256=tEeyF9vDmw-EyKnHV2dYM6MWGfA2I2fmf3kBsYxqad0,2199 +sympy/codegen/tests/test_pynodes.py,sha256=Gso18KKzSwA-1AHC55SgHPAfH1GrGUCGaN6QR7iuEO0,432 +sympy/codegen/tests/test_pyutils.py,sha256=yvqif7d6EpsnaBjP8XXjVo3wEENBxI6Vm01I1Wow-js,299 +sympy/codegen/tests/test_rewriting.py,sha256=ELPziNI3CsJ4VS7mUbk4QWyG_94FbgZCdBKieMN20Vc,15852 +sympy/codegen/tests/test_scipy_nodes.py,sha256=LBWpjTRfgWN5NLTchLZEp6m7IMtu7HbiKoztLc6KNGY,1495 +sympy/combinatorics/__init__.py,sha256=Dx9xakpHuTIgy4G8zVjAY6pTu8J9_K3d_jKPizRMdVo,1500 +sympy/combinatorics/__pycache__/__init__.cpython-39.pyc,, +sympy/combinatorics/__pycache__/coset_table.cpython-39.pyc,, +sympy/combinatorics/__pycache__/fp_groups.cpython-39.pyc,, +sympy/combinatorics/__pycache__/free_groups.cpython-39.pyc,, +sympy/combinatorics/__pycache__/galois.cpython-39.pyc,, +sympy/combinatorics/__pycache__/generators.cpython-39.pyc,, +sympy/combinatorics/__pycache__/graycode.cpython-39.pyc,, +sympy/combinatorics/__pycache__/group_constructs.cpython-39.pyc,, +sympy/combinatorics/__pycache__/group_numbers.cpython-39.pyc,, +sympy/combinatorics/__pycache__/homomorphisms.cpython-39.pyc,, +sympy/combinatorics/__pycache__/named_groups.cpython-39.pyc,, +sympy/combinatorics/__pycache__/partitions.cpython-39.pyc,, +sympy/combinatorics/__pycache__/pc_groups.cpython-39.pyc,, +sympy/combinatorics/__pycache__/perm_groups.cpython-39.pyc,, +sympy/combinatorics/__pycache__/permutations.cpython-39.pyc,, +sympy/combinatorics/__pycache__/polyhedron.cpython-39.pyc,, 
+sympy/combinatorics/__pycache__/prufer.cpython-39.pyc,, +sympy/combinatorics/__pycache__/rewritingsystem.cpython-39.pyc,, +sympy/combinatorics/__pycache__/rewritingsystem_fsm.cpython-39.pyc,, +sympy/combinatorics/__pycache__/schur_number.cpython-39.pyc,, +sympy/combinatorics/__pycache__/subsets.cpython-39.pyc,, +sympy/combinatorics/__pycache__/tensor_can.cpython-39.pyc,, +sympy/combinatorics/__pycache__/testutil.cpython-39.pyc,, +sympy/combinatorics/__pycache__/util.cpython-39.pyc,, +sympy/combinatorics/coset_table.py,sha256=soppBd8i6UV-sppdOzykfNE8evMeclU4l1TwfL5UlXI,43296 +sympy/combinatorics/fp_groups.py,sha256=0iLx7mtwxn1kjYFysgAzILAdtlS9e1VWPtWnszubhd8,47779 +sympy/combinatorics/free_groups.py,sha256=gt33RoIByKDZNaf661ay4pQMilPPMPY4hFNGIKi2ho0,39740 +sympy/combinatorics/galois.py,sha256=cpd2iHaSdim5-3-cqvk58YP-AuS6CvJeCdUWpiOdIZI,17867 +sympy/combinatorics/generators.py,sha256=UAdv0bvqtN6i9bfFgB873g-A-zssc0ZoiHckDOEI0M0,7479 +sympy/combinatorics/graycode.py,sha256=xbtr8AaFYb4SMmwUi7mf7913U87jH-XEYF_3pGZfj0o,11207 +sympy/combinatorics/group_constructs.py,sha256=IKx12_yWJqEQ7g-oBuAWd5VRLbCOWyL0LG4PQu43BS8,2021 +sympy/combinatorics/group_numbers.py,sha256=eOiK0RyVHewiGiPXHihfM7MIKlAlg7AhGnBL4f1D54M,9163 +sympy/combinatorics/homomorphisms.py,sha256=dCpmPM3V2ReRuYDdXDfdMTU3pt7zkjKBkYSF6X2MfE8,18844 +sympy/combinatorics/named_groups.py,sha256=zd_C9epKDrMG0drafGUcHuuJJkcMaDt1Nf2ik4NXNq8,8378 +sympy/combinatorics/partitions.py,sha256=ZXqVmVNjmauhMeiTWtCCqOP38b9MJg7UlBdZa-7aICQ,20841 +sympy/combinatorics/pc_groups.py,sha256=5AwEeYQCMdYcybhlEykvJZqEYtJQLgYbRxAnNQeajfE,21414 +sympy/combinatorics/perm_groups.py,sha256=sf0FfxmJ5uWCwlvzRD5m5HHduG5qqHzh3Y0n1Dax5Yo,184825 +sympy/combinatorics/permutations.py,sha256=nQoc4YgTeoT80eHkqetTBUxwAQQW40US2AnR2u4-E_k,87737 +sympy/combinatorics/polyhedron.py,sha256=-1y5GhorUK62_gJpn4tXTLye7BcG0hAup74waDQ8y2I,35928 +sympy/combinatorics/prufer.py,sha256=yN6d4w_ZVXNFhBoevA84gor4Xb5ttG529xbVgHxzKDo,12061 +sympy/combinatorics/rewritingsystem.py,sha256=cT1JrAuKj9rWI3IhaHekYYt0rdG56pwFLg32pcGC9aI,17095 +sympy/combinatorics/rewritingsystem_fsm.py,sha256=CKGhLqyvxY0mlmy8_Hb4WzkSdWYPUaU2yZYhz-0iZ5w,2433 +sympy/combinatorics/schur_number.py,sha256=YdsyA7n_z9tyfRTSRfIjEjtnGo5EuDGBMUS09AQ2MxU,4437 +sympy/combinatorics/subsets.py,sha256=oxuExuGyFnvunkmktl-vBYiLbiN66A2Q2MyzwWfy46A,16047 +sympy/combinatorics/tensor_can.py,sha256=JywytVypkN6aWdCpeEYletMW4vxyUvU29RfKq9-Lu0c,40738 +sympy/combinatorics/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/combinatorics/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_coset_table.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_fp_groups.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_free_groups.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_galois.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_generators.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_graycode.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_group_constructs.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_group_numbers.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_homomorphisms.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_named_groups.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_partitions.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_pc_groups.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_perm_groups.cpython-39.pyc,, 
+sympy/combinatorics/tests/__pycache__/test_permutations.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_polyhedron.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_prufer.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_rewriting.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_schur_number.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_subsets.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_tensor_can.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_testutil.cpython-39.pyc,, +sympy/combinatorics/tests/__pycache__/test_util.cpython-39.pyc,, +sympy/combinatorics/tests/test_coset_table.py,sha256=cEUF0OH6SNhN_kh069wMsq6h4eSVqbDLghrg2r9Ht48,28474 +sympy/combinatorics/tests/test_fp_groups.py,sha256=C4rWnSJyk7L6uZnI-kAYvtbfFXxba0o1XjpOTkWr97s,10203 +sympy/combinatorics/tests/test_free_groups.py,sha256=g47dnRsNXVG9lFcqppJyrGryGIUCw9dRASkcMa4QujA,6411 +sympy/combinatorics/tests/test_galois.py,sha256=w35JRx8lmlXCdzUBNdocgATPYWBOEZ6LH-tAxOPwCQ8,2763 +sympy/combinatorics/tests/test_generators.py,sha256=6YpOp0i5PRGtySPNZseQ8mjSXbwpfGfz0hDB4kfk40Q,3567 +sympy/combinatorics/tests/test_graycode.py,sha256=pI4e7Y615d5Bmmxui6fdEeyca6j6KSD0YmeychV6ORk,2800 +sympy/combinatorics/tests/test_group_constructs.py,sha256=jJLwMdhuUalKv4Aql9SzV2utK8Ex-IYdMecggr95pi8,450 +sympy/combinatorics/tests/test_group_numbers.py,sha256=spHwsvjzEDXsdAFxkrQYOnioTgtWoVF5K7k7FbgMvfg,4149 +sympy/combinatorics/tests/test_homomorphisms.py,sha256=UwBj5loCuZAiuvmqy5VAbwhCQTph8o6BzTaGrH0rzB4,3745 +sympy/combinatorics/tests/test_named_groups.py,sha256=tsuDVGv4iHGEZ0BVR87_ENhyAfZvFIl0M6Dv_HX1VoY,1931 +sympy/combinatorics/tests/test_partitions.py,sha256=oppszKJLLSpcEzHgespIveSmEC3fDZ0qkus1k7MBt4E,4097 +sympy/combinatorics/tests/test_pc_groups.py,sha256=wfkY_ilpG0XWrhaWMVK6r7yWMeXfM8WNTyti5oE9bdk,2728 +sympy/combinatorics/tests/test_perm_groups.py,sha256=70_AedMpr-YxiYDHHjw6fjvmYW7-YoznEGwUeGtJo0E,41195 +sympy/combinatorics/tests/test_permutations.py,sha256=E8J-WLCW2z9hTyTJcm1dsFgFXh8YAesIppYsXRu5pAs,20189 +sympy/combinatorics/tests/test_polyhedron.py,sha256=3SWkFQKeF-p1QWP4Iu9NIA1oTxAFo1BLRrrLerBFAhw,4180 +sympy/combinatorics/tests/test_prufer.py,sha256=OTJp0NxjiVswWkOuCIlnGFU2Gw4noRsrPpUJtp2XhEs,2649 +sympy/combinatorics/tests/test_rewriting.py,sha256=3COHq74k6knt2rqE7hfd4ZP_6whf0Kg14tYxFmTtYrI,1787 +sympy/combinatorics/tests/test_schur_number.py,sha256=wg13uTumFltWIGbVg_PEr6nhXIru19UWitsEZiakoRI,1727 +sympy/combinatorics/tests/test_subsets.py,sha256=6pyhLYV5HuXvx63r-gGVHr8LSrGRXcpDudhFn9fBqX8,2635 +sympy/combinatorics/tests/test_tensor_can.py,sha256=olH5D5wwTBOkZXjtqvLO6RKbvCG9KoMVK4__wDe95N4,24676 +sympy/combinatorics/tests/test_testutil.py,sha256=NYEI8X2qLWf7Gv1O7UthOHUSyWpbNCRrGmG_vdyEgF8,1733 +sympy/combinatorics/tests/test_util.py,sha256=sOYMWHxlbM0mqalqA7jNrYMm8DKcf_GwL5YBjs96_C4,4499 +sympy/combinatorics/testutil.py,sha256=GNnqy0mb6yPMa3zpGEzz2p6uxY7VtobPtwUialhfYEQ,11142 +sympy/combinatorics/util.py,sha256=lkOaITBImqB9yyLvN8DU0G-vraw27cSx2XaPdAPVBhg,16296 +sympy/concrete/__init__.py,sha256=2HDmg3VyLgM_ZPw3XsGpkOClGiQnyTlUNHSwVTtizA0,144 +sympy/concrete/__pycache__/__init__.cpython-39.pyc,, +sympy/concrete/__pycache__/delta.cpython-39.pyc,, +sympy/concrete/__pycache__/expr_with_intlimits.cpython-39.pyc,, +sympy/concrete/__pycache__/expr_with_limits.cpython-39.pyc,, +sympy/concrete/__pycache__/gosper.cpython-39.pyc,, +sympy/concrete/__pycache__/guess.cpython-39.pyc,, +sympy/concrete/__pycache__/products.cpython-39.pyc,, 
+sympy/concrete/__pycache__/summations.cpython-39.pyc,, +sympy/concrete/delta.py,sha256=38M0E24b6V-dn706WXnylM4DAvvyb6tFy7eRE4oR5Bo,9870 +sympy/concrete/expr_with_intlimits.py,sha256=vj4PjttB9xE5aUYu37R1A4_KtGgxcPa65jzjv8-krsc,11352 +sympy/concrete/expr_with_limits.py,sha256=dL5u-b_CzwghTWIhNsGc4md8jPWfhXOvna5Lig6XVr0,21834 +sympy/concrete/gosper.py,sha256=JUFbnOWSls8Xo0Zj198Xm4adopE-rAJjXv6M2q9x7us,5499 +sympy/concrete/guess.py,sha256=6EejdhNgqjMK7lVTXms8-_xGmmpI3vtliLW7-L2eaDM,17474 +sympy/concrete/products.py,sha256=I_aGMeKtSuFU07PGfZPHHTwS1WxDRNbnjedNC_m1k24,18456 +sympy/concrete/summations.py,sha256=D5n4x-fzqhgRwTIWpdN3Xun4I8u847tB7Dcex7lh7lI,55998 +sympy/concrete/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/concrete/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/concrete/tests/__pycache__/test_delta.cpython-39.pyc,, +sympy/concrete/tests/__pycache__/test_gosper.cpython-39.pyc,, +sympy/concrete/tests/__pycache__/test_guess.cpython-39.pyc,, +sympy/concrete/tests/__pycache__/test_products.cpython-39.pyc,, +sympy/concrete/tests/__pycache__/test_sums_products.cpython-39.pyc,, +sympy/concrete/tests/test_delta.py,sha256=uI7xjMx7JuVb3kkN7cLR6_pGsKS4Ulq22p-Z9oti5Jc,23869 +sympy/concrete/tests/test_gosper.py,sha256=ZHiZfYGCeCS9I-0oqN6sFbiYa-284GeFoGsNbhIWq4I,7987 +sympy/concrete/tests/test_guess.py,sha256=TPW6Hy11Po6VLZG_dx95x3sMBYl5kcQH8wjJ6TOtu-k,3370 +sympy/concrete/tests/test_products.py,sha256=caYc-xlEIrX9I_A-KPQdwp5oDprVJSbfcOaKg_qUnsM,14521 +sympy/concrete/tests/test_sums_products.py,sha256=rg0cSziPNgkvyehA0wCf-3jj8fSmNT7ECOOkiXWN64o,66390 +sympy/conftest.py,sha256=wvibr2FiWTlUcLTYAFEL8tSAKWPnv_Q6Cj4lARBCM0U,2775 +sympy/core/__init__.py,sha256=0L72TGngrIg2JknW3elaPSIDmkpPjWGNVfNk33wKXJ0,3123 +sympy/core/__pycache__/__init__.cpython-39.pyc,, +sympy/core/__pycache__/_print_helpers.cpython-39.pyc,, +sympy/core/__pycache__/add.cpython-39.pyc,, +sympy/core/__pycache__/alphabets.cpython-39.pyc,, +sympy/core/__pycache__/assumptions.cpython-39.pyc,, +sympy/core/__pycache__/assumptions_generated.cpython-39.pyc,, +sympy/core/__pycache__/backend.cpython-39.pyc,, +sympy/core/__pycache__/basic.cpython-39.pyc,, +sympy/core/__pycache__/cache.cpython-39.pyc,, +sympy/core/__pycache__/compatibility.cpython-39.pyc,, +sympy/core/__pycache__/containers.cpython-39.pyc,, +sympy/core/__pycache__/core.cpython-39.pyc,, +sympy/core/__pycache__/coreerrors.cpython-39.pyc,, +sympy/core/__pycache__/decorators.cpython-39.pyc,, +sympy/core/__pycache__/evalf.cpython-39.pyc,, +sympy/core/__pycache__/expr.cpython-39.pyc,, +sympy/core/__pycache__/exprtools.cpython-39.pyc,, +sympy/core/__pycache__/facts.cpython-39.pyc,, +sympy/core/__pycache__/function.cpython-39.pyc,, +sympy/core/__pycache__/intfunc.cpython-39.pyc,, +sympy/core/__pycache__/kind.cpython-39.pyc,, +sympy/core/__pycache__/logic.cpython-39.pyc,, +sympy/core/__pycache__/mod.cpython-39.pyc,, +sympy/core/__pycache__/mul.cpython-39.pyc,, +sympy/core/__pycache__/multidimensional.cpython-39.pyc,, +sympy/core/__pycache__/numbers.cpython-39.pyc,, +sympy/core/__pycache__/operations.cpython-39.pyc,, +sympy/core/__pycache__/parameters.cpython-39.pyc,, +sympy/core/__pycache__/power.cpython-39.pyc,, +sympy/core/__pycache__/random.cpython-39.pyc,, +sympy/core/__pycache__/relational.cpython-39.pyc,, +sympy/core/__pycache__/rules.cpython-39.pyc,, +sympy/core/__pycache__/singleton.cpython-39.pyc,, +sympy/core/__pycache__/sorting.cpython-39.pyc,, +sympy/core/__pycache__/symbol.cpython-39.pyc,, +sympy/core/__pycache__/sympify.cpython-39.pyc,, 
+sympy/core/__pycache__/trace.cpython-39.pyc,, +sympy/core/__pycache__/traversal.cpython-39.pyc,, +sympy/core/_print_helpers.py,sha256=GQo9dI_BvAJtYHVFFfmroNr0L8d71UeI-tU7SGJgctk,2388 +sympy/core/add.py,sha256=TddwsicZFR4yEtVTkZmkzFtn9ndAlhfad7vJiECr6o4,43407 +sympy/core/alphabets.py,sha256=vWBs2atOvfRK6Xfg6hc5IKiB7s_0sZIiVJpcCUJL0N4,266 +sympy/core/assumptions.py,sha256=8K9rhYtT-kqS7hmc9ltkYC-wujYT3pe0BaUJNrOr8fg,23595 +sympy/core/assumptions_generated.py,sha256=0TJKYIHSIFyQcVHZdIHZ19b7tqst_sY7iZwjKzcvZBM,42817 +sympy/core/backend.py,sha256=hF9cftjzB92QEcwimekj0gQ9CEMXTH-leH3l8QcguGo,5273 +sympy/core/basic.py,sha256=IzZvs7iIcMvV8_athdOBnIvBYJg38khuO2fAJtd7uTE,76789 +sympy/core/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/core/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/core/benchmarks/__pycache__/bench_arit.cpython-39.pyc,, +sympy/core/benchmarks/__pycache__/bench_assumptions.cpython-39.pyc,, +sympy/core/benchmarks/__pycache__/bench_basic.cpython-39.pyc,, +sympy/core/benchmarks/__pycache__/bench_expand.cpython-39.pyc,, +sympy/core/benchmarks/__pycache__/bench_numbers.cpython-39.pyc,, +sympy/core/benchmarks/__pycache__/bench_sympify.cpython-39.pyc,, +sympy/core/benchmarks/bench_arit.py,sha256=gfrnvKSXLCaUoFFxMgJhnLUp7rG9Pa_YT7OKgOrPP8E,412 +sympy/core/benchmarks/bench_assumptions.py,sha256=evfZzTgOUUvvvlK0DRdDZQRqxIlGLfJYzKu8QDMxSks,177 +sympy/core/benchmarks/bench_basic.py,sha256=YF0tTJ_AN_Wz11qidzM4bIhlwEhEqVc-IGVGrUx6SaA,210 +sympy/core/benchmarks/bench_expand.py,sha256=xgQYQMwqgXJtKajM4JVhuL-7AW8TLY-vdBpO6uyMDoQ,427 +sympy/core/benchmarks/bench_numbers.py,sha256=_1yHSbYmxriiTTzKtxFh7kJm-rXPL6roGs8MW1E5-sg,1135 +sympy/core/benchmarks/bench_sympify.py,sha256=G5iGInhhbkkxSY2pS08BNG945m9m4eZlNT1aJutGt5M,138 +sympy/core/cache.py,sha256=AyG7kganyV0jVx-aNBEUFogqRLHQqqFn8xU3ZSfJoaM,6172 +sympy/core/compatibility.py,sha256=XQH7ezmRi6l3R23qMHN2wfA-YMRWbh2YYjPY7LRo3lo,1145 +sympy/core/containers.py,sha256=IA8PWQeCl78Pn_1WczoYUzOhBoT9X2hM3_hJ232gyKc,11411 +sympy/core/core.py,sha256=lX3Af31Qyff-gwPiAH-vijjuOzBFSuQtN06nUqIcSTQ,547 +sympy/core/coreerrors.py,sha256=imD9TCL-UEM5hrHcqGQbbfLeyZCSJMx6vB1NBQITQIQ,642 +sympy/core/decorators.py,sha256=NO3oe7yIO4q-6Vvy2vTpIAS8GGI-JHjEOEGeAv97voY,8759 +sympy/core/evalf.py,sha256=X0BinrUYeCKyHPHJQnu3FGASTUB8iz5dweVvqeArBrU,62089 +sympy/core/expr.py,sha256=HR_xzh6273FBmdfNbogGzkWoqaI_k_CeIOnpmyzwCrQ,143770 +sympy/core/exprtools.py,sha256=Ft7C1K2kD4_EdpRrGxii_WDub4pK_rYnLTdeVhlBoeo,51544 +sympy/core/facts.py,sha256=54pFKhJwEzU8LkO7rL25TwGjIb5y5CvZleHEy_TpD68,19546 +sympy/core/function.py,sha256=y0vEdXaUD8thP3XasZVkkogqXOGB0cs2JnfHEa4DNjQ,117102 +sympy/core/intfunc.py,sha256=TigJI9bvsb4aN0aZJgCzxIHlv_eh0rElWeTurier3yI,14219 +sympy/core/kind.py,sha256=9kQvtDxm-SSRGi-155XsBl_rs-oN_7dw7fNNT3mDu2Q,11540 +sympy/core/logic.py,sha256=Qbj2WUh5lS0GjxvtPaNmcQXz7mbWXuWj5ZNFZ-N-YPg,10827 +sympy/core/mod.py,sha256=XDfpth2r4AWc04oLqtDWwFuYRLzI0l0MhLMwGSMu99g,8363 +sympy/core/mul.py,sha256=rjCiyOqxMgUTWR8sAcmilRfQLup9wv8Ij7s7JIfUkVM,79201 +sympy/core/multidimensional.py,sha256=NWX1okybO_nZCl9IhIOE8QYalY1WoC0zlzsvBg_E1eE,4233 +sympy/core/numbers.py,sha256=WEkS1JJ9ov-n4hXvxwrH1NPikNlyBCjLN5v-jOZxvsw,136741 +sympy/core/operations.py,sha256=_wjhMX-MD6cyUNKCc9K_QxyxpeG4WQehdypo0Fp76GM,25828 +sympy/core/parameters.py,sha256=EoT2S3W1dS2-HoV6WN7szBexXvn5_w43e2JFouKuvkU,3854 +sympy/core/power.py,sha256=FJSYRf0S3pk1JGkVy3kdWWQENhmpVme-Pf9SS2Rc6SE,73244 +sympy/core/random.py,sha256=miFdVpNKfutbkpYiIOzG9kVNUm5GTk-_nnmQqUhVDZs,6647 
+sympy/core/relational.py,sha256=htbGT0uvzoYD07Q6_hdT0iDDsns0NDI2aml0W8mtAgM,51896 +sympy/core/rules.py,sha256=AJuZztmYKZ_yUITLZB6rhZjDy6ROBCtajcYqPa50sjc,1496 +sympy/core/singleton.py,sha256=2Ueja8GJ-0-cvBzlqPwL9hB9eOtjrJFXV6mkDZO1Bp4,7115 +sympy/core/sorting.py,sha256=PoL2-MtVeuYTu-DISGlbvqW2mt787BmWfzrwV1ibavE,10827 +sympy/core/symbol.py,sha256=0xwgxT7ZfhPjq87glw2CQq5MtIEO-DOrAoXu7bHwWFo,29820 +sympy/core/sympify.py,sha256=1vWheH469HqICHrADMY_k4T4eDQlGORNjEh7Re4WbmE,20546 +sympy/core/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/core/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_args.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_arit.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_assumptions.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_basic.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_cache.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_compatibility.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_complex.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_constructor_postprocessor.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_containers.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_count_ops.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_diff.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_equal.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_eval.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_evalf.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_expand.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_expr.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_exprtools.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_facts.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_function.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_kind.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_logic.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_match.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_multidimensional.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_noncommutative.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_numbers.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_operations.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_parameters.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_power.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_priority.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_random.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_relational.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_rules.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_singleton.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_sorting.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_subs.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_symbol.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_sympify.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_traversal.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_truediv.cpython-39.pyc,, +sympy/core/tests/__pycache__/test_var.cpython-39.pyc,, +sympy/core/tests/test_args.py,sha256=gsvYkFoqSWM0G8sjwBkwwNpT_9jc-io1SyDcchI3fA4,184252 +sympy/core/tests/test_arit.py,sha256=TH4b9Z_sMGWlYqiU5MqpY5LmhzYRJSTHWtTNIrYv7fk,78611 +sympy/core/tests/test_assumptions.py,sha256=MjJdF_ymVL6mtgQx-aSr_rsNNxaTi2pHFLjyaPCBq5Q,41573 +sympy/core/tests/test_basic.py,sha256=skFOWZRNqDujmQ8GAskWH7OYcsrO0miBdM92Gy01k0U,10230 +sympy/core/tests/test_cache.py,sha256=p6Ci75a_T-bBXE_5HVxRKla62uSay_0Vuf57gUuH6sI,2001 
+sympy/core/tests/test_compatibility.py,sha256=7pvNUEGIcRrfWl3doqHlm3AdNkGlcChO69gos3Fk09A,240 +sympy/core/tests/test_complex.py,sha256=koNGFMt6UMmzahJADSja_eD24gr-GG5gGCtyDgCRtPI,21906 +sympy/core/tests/test_constructor_postprocessor.py,sha256=0d7vbVuKi3GCm3PKLtiNqv_Au7v6RYt1rzRdHiD08tM,2441 +sympy/core/tests/test_containers.py,sha256=ijteBC6cjqzejfxXZZIELZJfMrgTQa1n6AlxHGCEQJs,7432 +sympy/core/tests/test_count_ops.py,sha256=eIA2WvCuWKXVBJEGfWoJrn6WfUshX_NXttrrfyLbNnI,5665 +sympy/core/tests/test_diff.py,sha256=6j4Vk9UCNRv8Oyx_4iv1ePjocwBg7_-3ftrSJ8u0cPo,5421 +sympy/core/tests/test_equal.py,sha256=RoOJuu4kMe4Rkk7eNyVOJov5S1770YHiVAiziNIKd2o,1678 +sympy/core/tests/test_eval.py,sha256=o0kZn3oaMidVYdNjeZYtx4uUKBoE3A2tWn2NS4hu72Q,2366 +sympy/core/tests/test_evalf.py,sha256=aIf2xMuFomr828s4qv6Q16xLHWNmBnDfC4FUv9g619M,28574 +sympy/core/tests/test_expand.py,sha256=jbIjBEmdsPPArsxJ9206YzMS7mPUVZo7j-7alM795eU,13886 +sympy/core/tests/test_expr.py,sha256=249jMUV5GnkjmVVcgPSkwWvzrT00ppWeygkK_9RO6kc,78428 +sympy/core/tests/test_exprtools.py,sha256=L7fi319z1EeFag6pH8myqDQYQ32H193QLKMdqlxACsY,19021 +sympy/core/tests/test_facts.py,sha256=YEZMZ-116VFnFqJ48h9bQsF2flhiB65trnZvJsRSh_o,11579 +sympy/core/tests/test_function.py,sha256=_YRRGTDthS85Bq4zoLo8bHMHNESPLUZW-4_VQaClCLs,51450 +sympy/core/tests/test_kind.py,sha256=NLJbwCpugzlNbaSyUlbb6NHoT_9dHuoXj023EDQMrNI,2048 +sympy/core/tests/test_logic.py,sha256=_YKSIod6Q0oIz9lDs78UQQrv9LU-uKaztd7w8LWwuwY,5634 +sympy/core/tests/test_match.py,sha256=yjVNccCLutvov8Ru_p38Hl_HDS0ltdsWierSi54m5Ts,22714 +sympy/core/tests/test_multidimensional.py,sha256=Fr-lagme3lwLrBpdaWP7O7oPezhIatn5X8fYYs-8bN8,848 +sympy/core/tests/test_noncommutative.py,sha256=IkGPcvLO4ACVj5LMT2IUgyj68F1RBvMKbm01iqTOK04,4436 +sympy/core/tests/test_numbers.py,sha256=SvTVZUpMWfeGvaqyJUp4e9FYIO8lFJLcbRcmBNWO8m0,78001 +sympy/core/tests/test_operations.py,sha256=mRxftKlrxxrn3zS3UPwqkF6Nr15l5Cv6j3c2RJX46s4,2859 +sympy/core/tests/test_parameters.py,sha256=wO9D-LcMMEyf5u5-EmDwVeQ02YzYbYwtFFR_o-M4ybQ,3560 +sympy/core/tests/test_power.py,sha256=H7SFxpWVnF4vgpF95ad-Fo5xHhIvKhtZJuGBg0J2SAk,24862 +sympy/core/tests/test_priority.py,sha256=6QlWz52qWwOU0CknBb4KErs4R87jp4wH3hVSQNAgdVI,3252 +sympy/core/tests/test_random.py,sha256=H58NfH5BYeQ3RIscbDct6SZkHQVRJjichVUSuSrhvAU,1233 +sympy/core/tests/test_relational.py,sha256=7ne7GIG_i0uVT8FfpCJ4ixnuIyQoVy-PwJm3kYNWbKI,43525 +sympy/core/tests/test_rules.py,sha256=iwmMX7hxC_73CuX9BizeAci-cO4JDq-y1sicKBXEGA4,349 +sympy/core/tests/test_singleton.py,sha256=xLJJgXwmkbKhsot_qTs-o4dniMjHUh3_va0xsA5h-KA,3036 +sympy/core/tests/test_sorting.py,sha256=6BZKYqUedAR-jeHcIgsJelJHFWuougml2c1NNilxGZg,902 +sympy/core/tests/test_subs.py,sha256=7ITJFDplgWBRImkcHfjRdnHqaKgjTxWb4j4WoRysvR8,30106 +sympy/core/tests/test_symbol.py,sha256=zYhPWsdyQp7_NiLVthpoCB1RyP9pmJcNlTdTN2kMdfY,13043 +sympy/core/tests/test_sympify.py,sha256=FEOsRduYX3EqE9EhMILwE9wu6jDj_o2WU9XEZHsXFUo,28195 +sympy/core/tests/test_traversal.py,sha256=cmgvMW8G-LZ20ZXy-wg5Vz5ogI_oq2p2bJSwMy9IMF0,4311 +sympy/core/tests/test_truediv.py,sha256=RYfJX39-mNhekRE3sj5TGFZXKra4ML9vGvObsRYuD3k,854 +sympy/core/tests/test_var.py,sha256=hexP-0q2nN9h_dyhKLCuvqFXgLC9e_Hroni8Ldb16Ko,1594 +sympy/core/trace.py,sha256=9WC8p3OpBL6TdHmZWMDK9jaCG-16f4uZV2VptduVH98,348 +sympy/core/traversal.py,sha256=PP4HUp-2gtsoCtGZwqxpbqZ33vtoRX-yYTbIqOQiaR0,8860 +sympy/crypto/__init__.py,sha256=i8GcbScXhIPbMEe7uuMgXqh_cU2mZm2f6hspIgmW5uM,2158 +sympy/crypto/__pycache__/__init__.cpython-39.pyc,, +sympy/crypto/__pycache__/crypto.cpython-39.pyc,, 
+sympy/crypto/crypto.py,sha256=iAERU5Qf9GizEUMUyxTgTUBI0U9QpSkbRir2nvKIKUA,89685 +sympy/crypto/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/crypto/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/crypto/tests/__pycache__/test_crypto.cpython-39.pyc,, +sympy/crypto/tests/test_crypto.py,sha256=gPsRBNNHwc4QDBjgT5tvWTCuC5IA5d4aZRTblxXjngk,19758 +sympy/diffgeom/__init__.py,sha256=cWj4N7AfNgrYcGIBexX-UrWxfd1bP9DTNqUmLWUJ9nA,991 +sympy/diffgeom/__pycache__/__init__.cpython-39.pyc,, +sympy/diffgeom/__pycache__/diffgeom.cpython-39.pyc,, +sympy/diffgeom/__pycache__/rn.cpython-39.pyc,, +sympy/diffgeom/diffgeom.py,sha256=DgxgxP9B7NT-M1Hb1-1AwHYHBvXZk_Y6_MqZcGPsmOg,72274 +sympy/diffgeom/rn.py,sha256=kvgth6rNJWt94kzVospZwiH53C-s4VSiorktQNmMobQ,6264 +sympy/diffgeom/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/diffgeom/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/diffgeom/tests/__pycache__/test_class_structure.cpython-39.pyc,, +sympy/diffgeom/tests/__pycache__/test_diffgeom.cpython-39.pyc,, +sympy/diffgeom/tests/__pycache__/test_function_diffgeom_book.cpython-39.pyc,, +sympy/diffgeom/tests/__pycache__/test_hyperbolic_space.cpython-39.pyc,, +sympy/diffgeom/tests/test_class_structure.py,sha256=LbRyxhhp-NnnfJ2gTn1SdlgCBQn2rhyB7xApOgcd_rM,1048 +sympy/diffgeom/tests/test_diffgeom.py,sha256=3BepCr6ned-4C_3me4zScu06HXG9Qx_dBBxIpiXAvy4,14145 +sympy/diffgeom/tests/test_function_diffgeom_book.py,sha256=GwhUAiAPtUv5I9oghdElMhtc6KX184tySgLSpVHhT0Q,5254 +sympy/diffgeom/tests/test_hyperbolic_space.py,sha256=c4xQJ_bBS4xrMj3pfx1Ms3oC2_LwuJuNYXNZxs-cVG8,2598 +sympy/discrete/__init__.py,sha256=A_Seud0IRr2gPYlz6JMQZa3sBhRL3O7gVqhIvMRRvE0,772 +sympy/discrete/__pycache__/__init__.cpython-39.pyc,, +sympy/discrete/__pycache__/convolutions.cpython-39.pyc,, +sympy/discrete/__pycache__/recurrences.cpython-39.pyc,, +sympy/discrete/__pycache__/transforms.cpython-39.pyc,, +sympy/discrete/convolutions.py,sha256=9L2d2Rrn6jqGfV2lBxCV6LmcTNBZUuOIqP_fuMIzPzk,18341 +sympy/discrete/recurrences.py,sha256=FqU5QG4qNNLSVBqcpL7HtKa7rQOlmHMXDQRzHZ_P_s0,5124 +sympy/discrete/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/discrete/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/discrete/tests/__pycache__/test_convolutions.cpython-39.pyc,, +sympy/discrete/tests/__pycache__/test_recurrences.cpython-39.pyc,, +sympy/discrete/tests/__pycache__/test_transforms.cpython-39.pyc,, +sympy/discrete/tests/test_convolutions.py,sha256=K4S9bA1E2tg0VcFQ8SYtlxjL6HB666RxD5rriTuAK4k,17856 +sympy/discrete/tests/test_recurrences.py,sha256=s5ZEZQ262gcnBLpCjJVmeKlTKQByRTQBrc-N9p_4W8c,3019 +sympy/discrete/tests/test_transforms.py,sha256=vEORFaPvxmPSsw0f4Z2hLEN1wD0FdyQOYHDEY9aVm5A,5546 +sympy/discrete/transforms.py,sha256=lf-n6IN881uCfTUAxPNjdUaSguiRbYW0omuR96vKNlE,11681 +sympy/external/__init__.py,sha256=C6s4654Elc_X-D9UgI2cUQWiQyGDt9LG3IKUc8qqzuo,578 +sympy/external/__pycache__/__init__.cpython-39.pyc,, +sympy/external/__pycache__/gmpy.cpython-39.pyc,, +sympy/external/__pycache__/importtools.cpython-39.pyc,, +sympy/external/__pycache__/ntheory.cpython-39.pyc,, +sympy/external/__pycache__/pythonmpq.cpython-39.pyc,, +sympy/external/gmpy.py,sha256=89vXc1KD5TiQCTwlL-98r01KdqsOsJpzc8XN1R7TwM4,9768 +sympy/external/importtools.py,sha256=Q7tS2cdGZ9a4NI_1sgGuoVcSDv_rIk-Av0BpFTa6EzA,7671 +sympy/external/ntheory.py,sha256=jnm13KGII1p6X42T-u2jVoBxiOCXZ8XGxjzhCuWXthE,16555 +sympy/external/pythonmpq.py,sha256=DBQYdZsCUajqIjV5ECYpzpCSbe4RR9qr4D1x_ZTuZaY,11259 
+sympy/external/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/external/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_autowrap.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_codegen.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_gmpy.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_importtools.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_ntheory.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_numpy.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_pythonmpq.cpython-39.pyc,, +sympy/external/tests/__pycache__/test_scipy.cpython-39.pyc,, +sympy/external/tests/test_autowrap.py,sha256=nUWYeQd1QF-Ou4A63iWypsn6pJMi0uMLyba8aRCOhnM,9759 +sympy/external/tests/test_codegen.py,sha256=t991vYFfQC6loodk7Vp1oKTCklBn8SwY7qLshVVou4w,12598 +sympy/external/tests/test_gmpy.py,sha256=8ITpuWYeitCymWwuQLpOOVBmRb3CJsO5sbqiZcVDulE,397 +sympy/external/tests/test_importtools.py,sha256=KrfontKYv11UvpazQ0vS1qyhxIvgZrCOXh1JFeACjeo,1394 +sympy/external/tests/test_ntheory.py,sha256=BJWirDnX7Y7McBoXreMomGQx33YlAjiuBTYQQLhobLU,8881 +sympy/external/tests/test_numpy.py,sha256=tuEji5l6GqbNjv74T6a3e8LDzI2zKUaLzvfluNXOFE0,10291 +sympy/external/tests/test_pythonmpq.py,sha256=L_FdZmmk5N-VEivE_O_qZa98BZhT1WSxRfdmG817bA0,5797 +sympy/external/tests/test_scipy.py,sha256=CVaw7D0-6DORgg78Q6b35SNKn05PlKwWJuqXOuU-qdY,1172 +sympy/functions/__init__.py,sha256=-u5IzcQAPk9emytXfMK22EVXqpXUtSau3pQtevZYmFo,5565 +sympy/functions/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/combinatorial/__init__.py,sha256=WqXI3qU_TTJ7nJA8m3Z-7ZAYKoApT8f9Xs0u2bTwy_c,53 +sympy/functions/combinatorial/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/combinatorial/__pycache__/factorials.cpython-39.pyc,, +sympy/functions/combinatorial/__pycache__/numbers.cpython-39.pyc,, +sympy/functions/combinatorial/factorials.py,sha256=cYMN1Bz76XninzSxzdLYxTN7-Mt1mzfUcSWIU-qpYAw,39184 +sympy/functions/combinatorial/numbers.py,sha256=PcW4rvWQ66POd1FcL8SJav2iUmGv2Ip_Vbbp_3qYIAY,101088 +sympy/functions/combinatorial/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/functions/combinatorial/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/combinatorial/tests/__pycache__/test_comb_factorials.cpython-39.pyc,, +sympy/functions/combinatorial/tests/__pycache__/test_comb_numbers.cpython-39.pyc,, +sympy/functions/combinatorial/tests/test_comb_factorials.py,sha256=B59sfZxTUmKtnG62LW9sPn7HnRxFBP6x_rnbKqbaaZ0,26317 +sympy/functions/combinatorial/tests/test_comb_numbers.py,sha256=GWeIeb9zJo3WJr9NGW7VD-PmCIt13ce-VufICHs1-A8,47593 +sympy/functions/elementary/__init__.py,sha256=Fj8p5qE-Rr1lqAyHI0aSgC3RYX56O-gWwo6wu-eUQYA,50 +sympy/functions/elementary/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/_trigonometric_special.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/complexes.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/exponential.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/hyperbolic.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/integers.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/miscellaneous.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/piecewise.cpython-39.pyc,, +sympy/functions/elementary/__pycache__/trigonometric.cpython-39.pyc,, +sympy/functions/elementary/_trigonometric_special.py,sha256=FvrgSXisxjXjyBC4-NLLya6q2YyTMNMAUqqYzuYl34g,7271 
+sympy/functions/elementary/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/functions/elementary/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/elementary/benchmarks/__pycache__/bench_exp.cpython-39.pyc,, +sympy/functions/elementary/benchmarks/bench_exp.py,sha256=PFBYa9eMovH5XOFN5XTxWr1VDj1EBoKwn4mAtj-_DdM,185 +sympy/functions/elementary/complexes.py,sha256=HQm4ViBWN5OkCtJZQPdLdChTnmzbk4Cb2A--MUzTgUs,44168 +sympy/functions/elementary/exponential.py,sha256=GoDlUBPcmpRjYehRZOooF0bISSK_oxpJlkqaE_KNTeQ,42610 +sympy/functions/elementary/hyperbolic.py,sha256=zewRSHNzxyvSFD9sP9XBKX-evG6Rds4k9jAAcfmDbIk,69329 +sympy/functions/elementary/integers.py,sha256=Dc5Bq6ds85NJYWkfHJPJccUeVTxbhzVloTeMo5HGG98,22394 +sympy/functions/elementary/miscellaneous.py,sha256=cPpv71po3HCQGsWIHgcQEql6RNb4Y3RrL24ujgtdTIQ,27944 +sympy/functions/elementary/piecewise.py,sha256=Nv32I9lktMPHF3k0JpifbcgCgIz7CyQ8KwI_t1BaVsI,58285 +sympy/functions/elementary/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/functions/elementary/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_complexes.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_exponential.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_hyperbolic.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_integers.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_interface.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_miscellaneous.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_piecewise.cpython-39.pyc,, +sympy/functions/elementary/tests/__pycache__/test_trigonometric.cpython-39.pyc,, +sympy/functions/elementary/tests/test_complexes.py,sha256=BaPIrUSkoguP74x7uLSt9bA9kxmB58fJ4wgNtx5RnCg,34016 +sympy/functions/elementary/tests/test_exponential.py,sha256=ojjC0aafWaQegZVYV7bgxyvWAES_erpeYTb3R0m9V5I,29733 +sympy/functions/elementary/tests/test_hyperbolic.py,sha256=TBJZbiV-o9c2uVUqnkaRuHr5FdSGBOpqKK5_2bWg92s,56898 +sympy/functions/elementary/tests/test_integers.py,sha256=_QaECu3JQRQa9uPi10jXlu9AohvtLTKB4rSe920g0tY,23266 +sympy/functions/elementary/tests/test_interface.py,sha256=dW0aHvLR5UPyyXzBS-tgKrQIp-H6Vu8sKjlQrBQgtZ4,2513 +sympy/functions/elementary/tests/test_miscellaneous.py,sha256=eCL30UmsusBhjvqICQNmToa1aJTML8fXav1L1J6b7FU,17148 +sympy/functions/elementary/tests/test_piecewise.py,sha256=b1c0IMbXHDLjbi_ESs7fQhJiMaEro30GEu9XYreicYA,62876 +sympy/functions/elementary/tests/test_trigonometric.py,sha256=ebRorj-WThNp56qvZOAGkvgpvWFnIlB6kFZyV1gCii4,90340 +sympy/functions/elementary/trigonometric.py,sha256=b2mE2SpNj0LZM3AcipW0FrjturAi9ozijK3PhqyfCbM,115945 +sympy/functions/special/__init__.py,sha256=5pjIq_RVCMsuCe1b-FlwIty30KxoUowZYKLmpIT9KHQ,59 +sympy/functions/special/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/special/__pycache__/bessel.cpython-39.pyc,, +sympy/functions/special/__pycache__/beta_functions.cpython-39.pyc,, +sympy/functions/special/__pycache__/bsplines.cpython-39.pyc,, +sympy/functions/special/__pycache__/delta_functions.cpython-39.pyc,, +sympy/functions/special/__pycache__/elliptic_integrals.cpython-39.pyc,, +sympy/functions/special/__pycache__/error_functions.cpython-39.pyc,, +sympy/functions/special/__pycache__/gamma_functions.cpython-39.pyc,, +sympy/functions/special/__pycache__/hyper.cpython-39.pyc,, +sympy/functions/special/__pycache__/mathieu_functions.cpython-39.pyc,, 
+sympy/functions/special/__pycache__/polynomials.cpython-39.pyc,, +sympy/functions/special/__pycache__/singularity_functions.cpython-39.pyc,, +sympy/functions/special/__pycache__/spherical_harmonics.cpython-39.pyc,, +sympy/functions/special/__pycache__/tensor_functions.cpython-39.pyc,, +sympy/functions/special/__pycache__/zeta_functions.cpython-39.pyc,, +sympy/functions/special/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/functions/special/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/special/benchmarks/__pycache__/bench_special.cpython-39.pyc,, +sympy/functions/special/benchmarks/bench_special.py,sha256=wzAoKTccuEaG4xrEYTlYfIJuLi3kUTMTEJ9iA113Wog,164 +sympy/functions/special/bessel.py,sha256=r-jJETBLp8a4JOKV9dW5EWr_yvERL0HMn0su328bug8,68656 +sympy/functions/special/beta_functions.py,sha256=de-ypTWZEjvFkq9RgPq3ghGOvFOfPmb9JfUDAPKQoXk,12777 +sympy/functions/special/bsplines.py,sha256=q7iZSeEMlRrx5gpXOzlk4uSTPIollduNZuGwMHB6fRM,10060 +sympy/functions/special/delta_functions.py,sha256=jopb1BX803pZtDAtLmKfFqRVjo3mySekyC80Lu-sJkM,19887 +sympy/functions/special/elliptic_integrals.py,sha256=YpCU-xyTvrmBxdO7Rj2pNYDbS2-bGfCvCaOHon9nSq0,14921 +sympy/functions/special/error_functions.py,sha256=GvKrm5fVMpat0Iierh9UsxsTeQknJyiHgMrWS-e7fhw,79556 +sympy/functions/special/gamma_functions.py,sha256=gkPQdMNR7k5uY1vSUu6zzyoJKMoQZfC90YjLRcyft5c,42943 +sympy/functions/special/hyper.py,sha256=oqeAhFALAN_H6eS0Fs_axdfL9jIYQ-aRFGGgQky2HYk,38829 +sympy/functions/special/mathieu_functions.py,sha256=kFmt0RUA70vhGphI7lhwUsOQ7KIy8Y95DU6gGkVQU5Y,6620 +sympy/functions/special/polynomials.py,sha256=IdxRf8D1iFA-oYfLyTcUmsSCeuWja4h2ExNb8UShwxg,46753 +sympy/functions/special/singularity_functions.py,sha256=ounzMgFQiOW-e1Rb5y_dCSopJGegXMdkj-EpjrP8tzU,8346 +sympy/functions/special/spherical_harmonics.py,sha256=MNAJ4ABTy_EIZe2XSK_MZ1SPnk1OeDnJAtrICAst4yw,11018 +sympy/functions/special/tensor_functions.py,sha256=UNoDpLu1EBR-2tX_dLoRpKhN0Yhm-z_hh8EtW50dNSE,12303 +sympy/functions/special/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/functions/special/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_bessel.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_beta_functions.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_bsplines.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_delta_functions.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_elliptic_integrals.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_error_functions.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_gamma_functions.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_hyper.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_mathieu.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_singularity_functions.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_spec_polynomials.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_spherical_harmonics.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_tensor_functions.cpython-39.pyc,, +sympy/functions/special/tests/__pycache__/test_zeta_functions.cpython-39.pyc,, +sympy/functions/special/tests/test_bessel.py,sha256=c6OKhj4MCjv4kpDG7udueaCOLykblsKHMOyP4CdFToA,37373 +sympy/functions/special/tests/test_beta_functions.py,sha256=yxfgu-wmNEeMfaFABiDHYmuZpZup9FTp0ZYerlc6hhc,3786 
+sympy/functions/special/tests/test_bsplines.py,sha256=6UYg7IqXTi8fcSOut8TEzNVkxIA4ff-CyG22qJnbIYA,7145 +sympy/functions/special/tests/test_delta_functions.py,sha256=8xhSWG4SLL86z1QKFfLk_3b--bCrxjvCaxHlODBVToE,7138 +sympy/functions/special/tests/test_elliptic_integrals.py,sha256=scu7KemJ7Q2nsRcZtckQNruuthI-vla9M1sDVkLcbKM,6991 +sympy/functions/special/tests/test_error_functions.py,sha256=TkhkIKiuRZS78-F4ZTt5rvr3hfknp654IEhmHhRXT48,33565 +sympy/functions/special/tests/test_gamma_functions.py,sha256=aveGPdG1X9OdlC6xpdZbSsJ3S4M4DaG2osyOhm1fAYo,29908 +sympy/functions/special/tests/test_hyper.py,sha256=SEnQ9TtyO_dSQRc94AHrbQai6Q7-tmRfDIh40gNE3FE,16694 +sympy/functions/special/tests/test_mathieu.py,sha256=pqoFbnC84NDL6EQkigFtx5OQ1RFYppckTjzsm9XT0PY,1282 +sympy/functions/special/tests/test_singularity_functions.py,sha256=GEXzHHeR5F0swrZzhF7pEGR0tk9Vp3Ve2c-sE4WSpsY,6249 +sympy/functions/special/tests/test_spec_polynomials.py,sha256=wuiZaR_LwaM8SlNuGl3B1p4eOHC_-zZVSXMPNfzKRB4,19561 +sympy/functions/special/tests/test_spherical_harmonics.py,sha256=pUFtFpNPBnJTdnqou0jniSchijyh1rdzKv8H24RT9FU,3850 +sympy/functions/special/tests/test_tensor_functions.py,sha256=bblSDkPABZ6N1j1Rb2Bb5TZIzZoK1D8ks3fHizi69ZI,5546 +sympy/functions/special/tests/test_zeta_functions.py,sha256=2r59_aC0QOXQsBNXqxsHPr2PkJExusI6qvSydZBPbfw,10474 +sympy/functions/special/zeta_functions.py,sha256=QPRYVp0-WKs02CUSQiRIJUWD4N_3PmqrOsk1MGzBUfU,24091 +sympy/galgebra.py,sha256=yEosUPSnhLp9a1NWXvpCLoU20J6TQ58XNIvw07POkVk,123 +sympy/geometry/__init__.py,sha256=BU2MiKm8qJyZJ_hz1qC-3nFJTPEcuvx4hYd02jHjqSM,1240 +sympy/geometry/__pycache__/__init__.cpython-39.pyc,, +sympy/geometry/__pycache__/curve.cpython-39.pyc,, +sympy/geometry/__pycache__/ellipse.cpython-39.pyc,, +sympy/geometry/__pycache__/entity.cpython-39.pyc,, +sympy/geometry/__pycache__/exceptions.cpython-39.pyc,, +sympy/geometry/__pycache__/line.cpython-39.pyc,, +sympy/geometry/__pycache__/parabola.cpython-39.pyc,, +sympy/geometry/__pycache__/plane.cpython-39.pyc,, +sympy/geometry/__pycache__/point.cpython-39.pyc,, +sympy/geometry/__pycache__/polygon.cpython-39.pyc,, +sympy/geometry/__pycache__/util.cpython-39.pyc,, +sympy/geometry/curve.py,sha256=F7b6XrlhUZ0QWLDoZJVojWfC5LeyOU-69OTFnYAREg8,10170 +sympy/geometry/ellipse.py,sha256=So0gH9e-1Oh3HGzrXoLqANyP4lTTBfGHHVo5vCdbfdk,50305 +sympy/geometry/entity.py,sha256=fvHhtSb6RvE6v-8yMyCNvm0ekLPoO7EO9J8TEsGyQGU,20668 +sympy/geometry/exceptions.py,sha256=XtUMA44UTdrBWt771jegFC-TXsobhDiI-10TDH_WNFM,131 +sympy/geometry/line.py,sha256=Yw8ns_w8rn6FEak6MmfsdccgNNjMQHhodDik2UX7jJM,80397 +sympy/geometry/parabola.py,sha256=YhtQ9C5Yco_iSQ6ZYRs0lZBWHmDkcvKUr-Y8RghY3M8,10716 +sympy/geometry/plane.py,sha256=z3hjEf3ibO735DV0Mb4DVjvthYlDyiPfekCeePSCDMM,26770 +sympy/geometry/point.py,sha256=vQB40V1fvhZze3D_D54HlMSS8gXeMIWerd2hBozmPFA,36661 +sympy/geometry/polygon.py,sha256=aooJyJVwf6ZPuxStYgTc-2jNjVaM2YHSvpVY3XRjAuo,82027 +sympy/geometry/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/geometry/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_curve.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_ellipse.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_entity.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_geometrysets.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_line.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_parabola.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_plane.cpython-39.pyc,, 
+sympy/geometry/tests/__pycache__/test_point.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_polygon.cpython-39.pyc,, +sympy/geometry/tests/__pycache__/test_util.cpython-39.pyc,, +sympy/geometry/tests/test_curve.py,sha256=xL4uRWAal4mXZxuQhcs9QOhs6MheCbFNyH1asq_a2IQ,4479 +sympy/geometry/tests/test_ellipse.py,sha256=f-EorUkNl_cg63bDvm4Qrbe1PWrBaCCRg8jAmdIYBuU,26509 +sympy/geometry/tests/test_entity.py,sha256=KSsncs7afLpqUAYTs9b0VH1dt6FhPrX5rrj2rdLpBUo,3896 +sympy/geometry/tests/test_geometrysets.py,sha256=vvOWrFrJuNAFgbrVh1wPY94o-H-85FWlnIyyo2Kst9c,1911 +sympy/geometry/tests/test_line.py,sha256=fLWGfHQNC3YqjxtT7zQbP--Cu0mjcPwSBfji4DA7BP0,38054 +sympy/geometry/tests/test_parabola.py,sha256=kd0RU5sGOcfp6jgwgXMtvT2B6kG1-M3-iGOLnUJfZOw,6150 +sympy/geometry/tests/test_plane.py,sha256=QRcfoDsJtCtcvjFb18hBEHupycLgAT2OohF6GpNShyQ,12525 +sympy/geometry/tests/test_point.py,sha256=YKXQdlBTQWsIVf9l3G6iPMq0OqMTo8nZlibhJDsq_Ic,16410 +sympy/geometry/tests/test_polygon.py,sha256=4D0t98SELQF42dYR-CbFqDU5yJds0LAE2RE8xqEdqag,27602 +sympy/geometry/tests/test_util.py,sha256=sbh1QvkQG1OqvE-kt4fNNIkMWnOFi5EpaBmnZS3pzNc,7044 +sympy/geometry/util.py,sha256=zF6Xx0ciab6Rxm4ekK5RawWXPWAk5oug8eDvgfNKoRc,20671 +sympy/holonomic/__init__.py,sha256=BgHIokaSOo3nwJlGO_caJHz37n6yoA8GeM9Xjn4zMpc,784 +sympy/holonomic/__pycache__/__init__.cpython-39.pyc,, +sympy/holonomic/__pycache__/holonomic.cpython-39.pyc,, +sympy/holonomic/__pycache__/holonomicerrors.cpython-39.pyc,, +sympy/holonomic/__pycache__/numerical.cpython-39.pyc,, +sympy/holonomic/__pycache__/recurrence.cpython-39.pyc,, +sympy/holonomic/holonomic.py,sha256=vP5XabVIezsAURbK65p8P_QfPINQwhol0d62B7Br3s0,91887 +sympy/holonomic/holonomicerrors.py,sha256=qDyUoGbrRjPtVax4SeEEf_o6-264mASEZO_rZETXH5o,1193 +sympy/holonomic/numerical.py,sha256=MlP8xTBqjVQqM1-jwviS_yPICKj6tEVsa9f6wyMCSfo,2625 +sympy/holonomic/recurrence.py,sha256=HWSMokhjcdGTLW_oku0Xi8W8RHxWYS28bv8V7eubC58,10377 +sympy/holonomic/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/holonomic/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/holonomic/tests/__pycache__/test_holonomic.cpython-39.pyc,, +sympy/holonomic/tests/__pycache__/test_recurrence.cpython-39.pyc,, +sympy/holonomic/tests/test_holonomic.py,sha256=NuoJtJi3RlpMYhMU_bGH1tuFuFoOZZrrSdMpgdHGhf0,35333 +sympy/holonomic/tests/test_recurrence.py,sha256=qHv0kn1Q4-aCD7XmbDK2xIdkjF0XkeZUKD2yeLajiq0,1342 +sympy/integrals/__init__.py,sha256=pZ-C3tDP_8woKActNoS7IwyW-9AuB5PMHB5B0ohlpw8,1970 +sympy/integrals/__pycache__/__init__.cpython-39.pyc,, +sympy/integrals/__pycache__/deltafunctions.cpython-39.pyc,, +sympy/integrals/__pycache__/heurisch.cpython-39.pyc,, +sympy/integrals/__pycache__/integrals.cpython-39.pyc,, +sympy/integrals/__pycache__/intpoly.cpython-39.pyc,, +sympy/integrals/__pycache__/laplace.cpython-39.pyc,, +sympy/integrals/__pycache__/manualintegrate.cpython-39.pyc,, +sympy/integrals/__pycache__/meijerint.cpython-39.pyc,, +sympy/integrals/__pycache__/meijerint_doc.cpython-39.pyc,, +sympy/integrals/__pycache__/prde.cpython-39.pyc,, +sympy/integrals/__pycache__/quadrature.cpython-39.pyc,, +sympy/integrals/__pycache__/rationaltools.cpython-39.pyc,, +sympy/integrals/__pycache__/rde.cpython-39.pyc,, +sympy/integrals/__pycache__/risch.cpython-39.pyc,, +sympy/integrals/__pycache__/singularityfunctions.cpython-39.pyc,, +sympy/integrals/__pycache__/transforms.cpython-39.pyc,, +sympy/integrals/__pycache__/trigonometry.cpython-39.pyc,, +sympy/integrals/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+sympy/integrals/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/integrals/benchmarks/__pycache__/bench_integrate.cpython-39.pyc,, +sympy/integrals/benchmarks/__pycache__/bench_trigintegrate.cpython-39.pyc,, +sympy/integrals/benchmarks/bench_integrate.py,sha256=vk6wAO1bqzFT9oW4qsW7nKGfc_gP0XaB5PMYKx5339Q,396 +sympy/integrals/benchmarks/bench_trigintegrate.py,sha256=8XU3uB3mcavigvzHQZA7H1sHI32zgT-9RkSnLa-Y3Vc,305 +sympy/integrals/deltafunctions.py,sha256=ysIQLdRBcG_YR-bVDoxt-sxEVU8TG77oSgM-J0gI0mE,7435 +sympy/integrals/heurisch.py,sha256=Huq3dBJEXvYGzkBboXnG6xFCGorc5Uftf5ssbiltyIs,26706 +sympy/integrals/integrals.py,sha256=DV_ditazAyVBpUCzsUFNiYvvQRqJfx3UGjnSEB2HKcg,64887 +sympy/integrals/intpoly.py,sha256=SXjd_f295YrYsvoQpzE2EQM5xaQnnj0zvHdYW5KEdn0,43266 +sympy/integrals/laplace.py,sha256=y5NxZ3_nNtc-NETYWw4iha7LTVM9Ok9NWPtmMUXtZwg,87038 +sympy/integrals/manualintegrate.py,sha256=sO_5-9pbeRKhZSuEwXm6okxpvRsbpQZerbLBdRdDtp4,75721 +sympy/integrals/meijerint.py,sha256=yOsrHJA7doo1uDsmsiTqK2Tly4JF30sA5RElT92lfec,80775 +sympy/integrals/meijerint_doc.py,sha256=Sis9bZvwciU2_BZzr3GCvXLkBcXVUhP9ycqGIDu99xU,1204 +sympy/integrals/prde.py,sha256=3VqBQOX7SZv_lETmhoVgb9gRjb1z8hiZAJ6vv5AgNEo,52073 +sympy/integrals/quadrature.py,sha256=6Bg3JmlIjIduIfaGfNVcwNfSrgEiLOszcN8WPzsXNqE,17064 +sympy/integrals/rationaltools.py,sha256=ZAizq9unSt-_wQkCIAhl_FaemzbQcfBcE9A3mDiU788,11694 +sympy/integrals/rde.py,sha256=EKXK6bf4m-2O0NZ-rtHq0Ye2a_o9LmUrkoV8gNxZKVk,27393 +sympy/integrals/risch.py,sha256=4UsTHizZtAMe3VUmRHugdOuPMilw8sKiVOlSUMWj9oE,67352 +sympy/integrals/singularityfunctions.py,sha256=ONI8x-ed-IcqOF4K2l0LVUvEUN2_dHztvL4auRsi67U,2235 +sympy/integrals/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/integrals/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_deltafunctions.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_failing_integrals.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_heurisch.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_integrals.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_intpoly.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_laplace.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_lineintegrals.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_manual.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_meijerint.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_prde.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_quadrature.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_rationaltools.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_rde.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_risch.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_singularityfunctions.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_transforms.cpython-39.pyc,, +sympy/integrals/tests/__pycache__/test_trigonometry.cpython-39.pyc,, +sympy/integrals/tests/test_deltafunctions.py,sha256=ivFjS-WlLQ4aMqjVS7ZzMChP2Mmw_JUPnwI9otiLnvs,3709 +sympy/integrals/tests/test_failing_integrals.py,sha256=HKMA6O26exCIRWQ43KyKugOvJL3gOei3r9JObcKo8p0,7438 +sympy/integrals/tests/test_heurisch.py,sha256=-btpZS_LI9EO-2EmA0eHV6l7xykQjkHIcpDU_GziTlo,14299 +sympy/integrals/tests/test_integrals.py,sha256=8bVIN4ztTwOvpLrlDIk1SURQMtiNXoqu-LLMkZp-Wrk,81192 +sympy/integrals/tests/test_intpoly.py,sha256=NzGhkR2pUMfd8lIU2cFR9bFa0J89RzpHs3zDggAWtXo,37445 +sympy/integrals/tests/test_laplace.py,sha256=pWn5UEm15gJv_hD35lvJh03Uuw4o79rbCQYvbXQS2xk,38186 
+sympy/integrals/tests/test_lineintegrals.py,sha256=zcPJ2n7DYt9KsgAe38t0gq3ARApUlb-kBahLThuRcq8,450 +sympy/integrals/tests/test_manual.py,sha256=unQ6Ew81SCHvVMva_aAOndh7hWpLeDwBtXs6W-yC0jE,34551 +sympy/integrals/tests/test_meijerint.py,sha256=G22dppQGMFU3JGexHZs65nce8UAK9wTilZn0D1XvZu0,32594 +sympy/integrals/tests/test_prde.py,sha256=2BZmEDasdx_3l64-9hioArysDj6Nl520GpQN2xnEE_A,16360 +sympy/integrals/tests/test_quadrature.py,sha256=iFMdqck36gkL-yksLflawIOYmw-0PzO2tFj_qdK6Hjg,19919 +sympy/integrals/tests/test_rationaltools.py,sha256=7QiPnpBXl7lo32RmhXo7ED6FYj5I1gjEFOziJYlqPtI,5669 +sympy/integrals/tests/test_rde.py,sha256=4d3vJupa-hRN4yNDISY8IC3rSI_cZW5BbtxoZm14y-Y,9571 +sympy/integrals/tests/test_risch.py,sha256=HaWg0JnErdrNzNmVfyz2Zz4XAgZPVVpZPt6Map3sQ58,38630 +sympy/integrals/tests/test_singularityfunctions.py,sha256=CSrHie59_NjNZ9B2GaHzKPNsMzxm5Kh6GuxlYk8zTuI,1266 +sympy/integrals/tests/test_transforms.py,sha256=mLl5RfuENf11oenV1OXgg6TfWdQTEY23txiGl7grCyo,27152 +sympy/integrals/tests/test_trigonometry.py,sha256=moMYr_Prc7gaYPjBK0McLjRpTEes2veUlN0vGv9UyEA,3869 +sympy/integrals/transforms.py,sha256=_rwRkrOnV4zby27MZJ9dpJrtQFG4gTbL7naLpFLXKDo,51750 +sympy/integrals/trigonometry.py,sha256=iOoBDGFDZx8PNbgL3XeZEd80I8ro0WAizNuC4P-u8x0,11083 +sympy/interactive/__init__.py,sha256=yokwEO2HF3eN2Xu65JSpUUsN4iYmPvvU4m_64f3Q33o,251 +sympy/interactive/__pycache__/__init__.cpython-39.pyc,, +sympy/interactive/__pycache__/printing.cpython-39.pyc,, +sympy/interactive/__pycache__/session.cpython-39.pyc,, +sympy/interactive/__pycache__/traversal.cpython-39.pyc,, +sympy/interactive/printing.py,sha256=EZBO3YCga2dXBI1EfFHVDmUKBkF9UmrLMgwI76B0nyE,21562 +sympy/interactive/session.py,sha256=sG546e0mAtT0OrFkYNVM7QGvkWrDhAQZ5E1hfx03iBQ,15329 +sympy/interactive/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/interactive/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/interactive/tests/__pycache__/test_interactive.cpython-39.pyc,, +sympy/interactive/tests/__pycache__/test_ipython.cpython-39.pyc,, +sympy/interactive/tests/test_interactive.py,sha256=Pbopy9lODrd_P46_xxlWxLwqPfG6_4J3CWWC4IqfDL4,485 +sympy/interactive/tests/test_ipython.py,sha256=QZUvr77TxJHjdgQVNpihAuoglg1ghVHOTDFOVXJ9pcE,11797 +sympy/interactive/traversal.py,sha256=XbccdO6msNAvrG6FFJl2n4XmIiRISnvda4QflfEPg7U,3189 +sympy/liealgebras/__init__.py,sha256=K8tw7JqG33_y6mYl1LTr8ZNtKH5L21BqkjCHfLhP4aA,79 +sympy/liealgebras/__pycache__/__init__.cpython-39.pyc,, +sympy/liealgebras/__pycache__/cartan_matrix.cpython-39.pyc,, +sympy/liealgebras/__pycache__/cartan_type.cpython-39.pyc,, +sympy/liealgebras/__pycache__/dynkin_diagram.cpython-39.pyc,, +sympy/liealgebras/__pycache__/root_system.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_a.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_b.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_c.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_d.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_e.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_f.cpython-39.pyc,, +sympy/liealgebras/__pycache__/type_g.cpython-39.pyc,, +sympy/liealgebras/__pycache__/weyl_group.cpython-39.pyc,, +sympy/liealgebras/cartan_matrix.py,sha256=yr2LoZi_Gxmu-EMKgFuPOPNMYPOsxucLAS6oRpSYi2U,524 +sympy/liealgebras/cartan_type.py,sha256=xLklg8Y5s40je6sXwmLmG9iyYi9YEk9KoxTSFz1GtdI,1790 +sympy/liealgebras/dynkin_diagram.py,sha256=ZzGuBGNOJ3lPDdJDs4n8hvGbz6wLhC5mwb8zFkDmyPw,535 +sympy/liealgebras/root_system.py,sha256=n1TKAZyWDPSlUDmjAga9Sa2Cm9qeRXddYPbl2XKyFq4,6673 
+sympy/liealgebras/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/liealgebras/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_cartan_matrix.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_cartan_type.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_dynkin_diagram.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_root_system.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_A.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_B.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_C.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_D.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_E.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_F.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_type_G.cpython-39.pyc,, +sympy/liealgebras/tests/__pycache__/test_weyl_group.cpython-39.pyc,, +sympy/liealgebras/tests/test_cartan_matrix.py,sha256=KCsakn0fHKHRbIUcrUkHBIKkudl3_ISUdHrfJy-UOd4,303 +sympy/liealgebras/tests/test_cartan_type.py,sha256=t5PvYYDXbNIFL3CV59Je7SBIAeLLf-W3mOINPUoHK6E,339 +sympy/liealgebras/tests/test_dynkin_diagram.py,sha256=DSixbnt_yd0zrhKzXW_XqkXWXYe1Dk2MmXN-Rjb1dGg,260 +sympy/liealgebras/tests/test_root_system.py,sha256=YmGBdUeJ4PkLSfAfRgTF7GW62RCEd5nH27FSX9UaG5Q,927 +sympy/liealgebras/tests/test_type_A.py,sha256=x7QmpjxsGmXol-IYVtN1lmIOmM3HLYwpX1tSG5h6FMM,657 +sympy/liealgebras/tests/test_type_B.py,sha256=Gw0GP24wP2rPn38Wwla9W7BwWH4JtCGpaprZb5W6JVY,642 +sympy/liealgebras/tests/test_type_C.py,sha256=ysSy-vzE9lNwzAunrmvnFkLBoJwF7W2On7QpqS6RI1s,927 +sympy/liealgebras/tests/test_type_D.py,sha256=qrO4oCjrjkp1uDvrNtbgANVyaOExqOLNtIpIxD1uH0U,764 +sympy/liealgebras/tests/test_type_E.py,sha256=IQ46Bo75Kivz0O45mNKNvDO8jen7X2NwxVBTwvr5lQo,987 +sympy/liealgebras/tests/test_type_F.py,sha256=yUQJ7LzTemv4Cd1XW_dr3x7KEI07BahsWAyJfXLS1eA,1378 +sympy/liealgebras/tests/test_type_G.py,sha256=wVa6qcAHbdrc9dA63samexHL35cWWJS606pom-6mH2Q,548 +sympy/liealgebras/tests/test_weyl_group.py,sha256=HrzojRECbhNUsdLFQAXYnJEt8LfktOSJZuqVE45aRnc,1501 +sympy/liealgebras/type_a.py,sha256=Etn_aeyE5FdYrUXrGnjNiTL5k4r5-O_CysaO-_UKASA,4294 +sympy/liealgebras/type_b.py,sha256=ERH1aM38eHZUuTJYRsWLKFDm74G4y1MQtweSfkDFQYc,4547 +sympy/liealgebras/type_c.py,sha256=Al1pA5cw5IKjvBzx9LdMORd5-JZXgA0kaMv8aHhnYJ0,4426 +sympy/liealgebras/type_d.py,sha256=Cagn4GXrYuUP32W6jauIPnEv8IcD3H88sbn0fKS7pmU,4681 +sympy/liealgebras/type_e.py,sha256=KR84bp4Ga7d9itKRCsD3AYrarVewGiNNe7AHWNUfTAM,8280 +sympy/liealgebras/type_f.py,sha256=gIyk4zA04uIpwizcYdvKsZ_xkcpPLId6LIHR6tp38dk,4423 +sympy/liealgebras/type_g.py,sha256=Ife98dGPtarGd-ii8hJbXdB0SMsct4okDkSX2wLN8XI,2965 +sympy/liealgebras/weyl_group.py,sha256=5YFA8qC4GWDM0WLNR_6VgpuNFZDfyDA7fBFjBcZaLgA,14557 +sympy/logic/__init__.py,sha256=RfoXrq9MESnXdL7PkwpYEfWeaxH6wBPHiE4zCgLKvk0,456 +sympy/logic/__pycache__/__init__.cpython-39.pyc,, +sympy/logic/__pycache__/boolalg.cpython-39.pyc,, +sympy/logic/__pycache__/inference.cpython-39.pyc,, +sympy/logic/algorithms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/logic/algorithms/__pycache__/__init__.cpython-39.pyc,, +sympy/logic/algorithms/__pycache__/dpll.cpython-39.pyc,, +sympy/logic/algorithms/__pycache__/dpll2.cpython-39.pyc,, +sympy/logic/algorithms/__pycache__/lra_theory.cpython-39.pyc,, +sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-39.pyc,, +sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-39.pyc,, 
+sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-39.pyc,, +sympy/logic/algorithms/dpll.py,sha256=zqiZDm1oD5sNxFqm_0Hen6NjfILIDp5uRgEOad1vYXI,9188 +sympy/logic/algorithms/dpll2.py,sha256=PEVvFUrEf-IwTwfnTYtAFLPLPjjqA1nmm39nASgsM1I,21497 +sympy/logic/algorithms/lra_theory.py,sha256=6JZAZMJhhHSbUXZuc0ddYwp_x7efVuthIYV7Dc9eUyc,31769 +sympy/logic/algorithms/minisat22_wrapper.py,sha256=uINcvkIHGWYJb8u-Q0OgnSgaHfVUd9tYYFbBAVNiASo,1317 +sympy/logic/algorithms/pycosat_wrapper.py,sha256=0vNFTbu9-YhSfjwYTsZsP_Z4HM8WpL11-xujLBS1kYg,1207 +sympy/logic/algorithms/z3_wrapper.py,sha256=mFmf7DWDV0Zu7006EdFK4qEDTP7sfmcXrfGZkMK97vo,3747 +sympy/logic/boolalg.py,sha256=GqxLtqUangGoSjsncKDzlLEE7iYVe5qsMtPhztniI9c,114972 +sympy/logic/inference.py,sha256=J2D8t9iHCSdotzS9iq6g3EvLPsI2B10kiNbmHsIz_oY,8983 +sympy/logic/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/logic/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/logic/tests/__pycache__/test_boolalg.cpython-39.pyc,, +sympy/logic/tests/__pycache__/test_dimacs.cpython-39.pyc,, +sympy/logic/tests/__pycache__/test_inference.cpython-39.pyc,, +sympy/logic/tests/__pycache__/test_lra_theory.cpython-39.pyc,, +sympy/logic/tests/test_boolalg.py,sha256=0ZjG56WpIhIwFLedUzg5IC-c0VF_g9qoU9aAbFydPYU,49748 +sympy/logic/tests/test_dimacs.py,sha256=EK_mA_k9zBLcQLTOKTZVrGhnGuQNza5mwXDQD_f-X1c,3886 +sympy/logic/tests/test_inference.py,sha256=9BXfPbJs6sBEgx1nwkmg2HcLl82lcfFXngKJL_ByzUI,16116 +sympy/logic/tests/test_lra_theory.py,sha256=hNCH66hP_32x5ioyqM6ltd0DbEl72kdctCBu-H6egG0,16834 +sympy/logic/utilities/__init__.py,sha256=WTn2vBgHcmhONRWI79PdMYNk8UxYDzsxRlZWuc-wtNI,55 +sympy/logic/utilities/__pycache__/__init__.cpython-39.pyc,, +sympy/logic/utilities/__pycache__/dimacs.cpython-39.pyc,, +sympy/logic/utilities/dimacs.py,sha256=gfrGWQpn4F7gLJH1zHAnLB8-CMGa1XtXl9NCGhwAW9k,1671 +sympy/matrices/__init__.py,sha256=i2a37WlcCj8-AKG_Yy8BBdOHFgAuWois_2IcD_Ih00s,2634 +sympy/matrices/__pycache__/__init__.cpython-39.pyc,, +sympy/matrices/__pycache__/common.cpython-39.pyc,, +sympy/matrices/__pycache__/decompositions.cpython-39.pyc,, +sympy/matrices/__pycache__/dense.cpython-39.pyc,, +sympy/matrices/__pycache__/determinant.cpython-39.pyc,, +sympy/matrices/__pycache__/eigen.cpython-39.pyc,, +sympy/matrices/__pycache__/exceptions.cpython-39.pyc,, +sympy/matrices/__pycache__/graph.cpython-39.pyc,, +sympy/matrices/__pycache__/immutable.cpython-39.pyc,, +sympy/matrices/__pycache__/inverse.cpython-39.pyc,, +sympy/matrices/__pycache__/kind.cpython-39.pyc,, +sympy/matrices/__pycache__/matrices.cpython-39.pyc,, +sympy/matrices/__pycache__/matrixbase.cpython-39.pyc,, +sympy/matrices/__pycache__/normalforms.cpython-39.pyc,, +sympy/matrices/__pycache__/reductions.cpython-39.pyc,, +sympy/matrices/__pycache__/repmatrix.cpython-39.pyc,, +sympy/matrices/__pycache__/solvers.cpython-39.pyc,, +sympy/matrices/__pycache__/sparse.cpython-39.pyc,, +sympy/matrices/__pycache__/sparsetools.cpython-39.pyc,, +sympy/matrices/__pycache__/subspaces.cpython-39.pyc,, +sympy/matrices/__pycache__/utilities.cpython-39.pyc,, +sympy/matrices/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/matrices/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-39.pyc,, +sympy/matrices/benchmarks/bench_matrix.py,sha256=vGMlg-2il2cFeAWrf0NJ6pzPX3Yd3ZQMxFgQ4q5ILQE,306 +sympy/matrices/common.py,sha256=v_jMjH9t6mrHl7Z5QluhO-G07QUaJSdiioF2YC1xlhY,95395 
+sympy/matrices/decompositions.py,sha256=5OjwFe5Z_eM2XoZPqrNFY9-Esq3cO_8yqFxT9QBPHgQ,47865 +sympy/matrices/dense.py,sha256=HTPNMKbi2M0W0nHImNtHib9JM5VC8si99ldL_FylKwQ,30461 +sympy/matrices/determinant.py,sha256=pDTDYCLBSWJov_j_B-ZZHdqH3-WQ_uwhZawrX3EDE9k,34550 +sympy/matrices/eigen.py,sha256=v09mqb2hZ0HqGbGzYpELOatIHnGTP5XADE-1RGcua-g,39811 +sympy/matrices/exceptions.py,sha256=8diN_ojMGC93XpqvZeS5ow4crlFxSkm6WqtnIK5j81E,503 +sympy/matrices/expressions/__init__.py,sha256=IMqXCSsPh0Vp_MC9HZTudA5DGM4WBq_yB-Bst0azyM8,1692 +sympy/matrices/expressions/__pycache__/__init__.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/_shape.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/adjoint.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/applyfunc.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/blockmatrix.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/companion.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/determinant.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/diagonal.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/dotproduct.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/factorizations.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/fourier.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/funcmatrix.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/hadamard.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/inverse.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/kronecker.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/matadd.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/matexpr.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/matmul.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/matpow.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/permutation.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/sets.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/slice.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/special.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/trace.cpython-39.pyc,, +sympy/matrices/expressions/__pycache__/transpose.cpython-39.pyc,, +sympy/matrices/expressions/_shape.py,sha256=TyQSwGx41aaMyAYs5Q7Er6atKVAdWK7DJ6YIVsiEAZg,3062 +sympy/matrices/expressions/adjoint.py,sha256=HV4OIWgmS2f_1_4PHRhQpgNmKJVCoX0eyHmrM6gfN5g,1515 +sympy/matrices/expressions/applyfunc.py,sha256=8scpWjZp7yzuvUNr0mxN03KgyTzLRa5NkGl9s81YPgY,6751 +sympy/matrices/expressions/blockmatrix.py,sha256=TPR1ZDTSzVGtT-Qo23XNzjnDVl1Js8NCi_2qV3fQA9U,31643 +sympy/matrices/expressions/companion.py,sha256=lXUJRbjQR6e1mdHQdJwNIJXMW80XmKbOVqNvUXjB57U,1705 +sympy/matrices/expressions/determinant.py,sha256=RyQXgUgqJkv_rvUPrn1_rOY45wzp5zYH6ZOf9S8NK8s,3281 +sympy/matrices/expressions/diagonal.py,sha256=XHWoT-Jv5QwJVsGNbfxHnNG2sygPy1CeR_t6zr8oUoM,6328 +sympy/matrices/expressions/dotproduct.py,sha256=sKdUhwVKTB3LEvd8xMwCDexNoQ1Dz43DCYsmm3UwFWw,1911 +sympy/matrices/expressions/factorizations.py,sha256=zFNjMBsJqhsIcDD8Me4W8-Q-TV89WptfG3Dd9yK_tPE,1456 +sympy/matrices/expressions/fourier.py,sha256=dvaftgB9jgkR_8ETyhzyVLtf1ZJu_wQC-ZbpTYMXZGE,2094 +sympy/matrices/expressions/funcmatrix.py,sha256=q6R75wLn0UdV4xJdVJUrNaofV1k1egXLLQdBeZcPtiY,3520 +sympy/matrices/expressions/hadamard.py,sha256=feXSZy0vbqxTzg0JeLmkSegiF4T2v5dOdcv0UQczK38,13920 +sympy/matrices/expressions/inverse.py,sha256=4UwgHWSIHgEoKOniObkClMYN9DrO2xNyvOSVToXSpj8,2963 +sympy/matrices/expressions/kronecker.py,sha256=kzCHqXDtcZGVQPln521lfN5redNwj6IjXJwjbv_Dkhg,13404 
+sympy/matrices/expressions/matadd.py,sha256=0MSanal1HKVEuCBEpehKwfUX4fuM9UMy6Fg2H5noA0s,4773 +sympy/matrices/expressions/matexpr.py,sha256=nfV9MhDNBR9HkOP7zoXOtv9wWR4OFxNczprik4S3Uh8,27549 +sympy/matrices/expressions/matmul.py,sha256=n9LcrgYoEhEQuq2htALHuboIWN0p8P_3lZdKNo92Rf0,15509 +sympy/matrices/expressions/matpow.py,sha256=puuYB49wr1WzFxhT9DcTWtF2bGDFWcbJ7oh14ATQARs,5140 +sympy/matrices/expressions/permutation.py,sha256=Xe7yOx-EgeD6JrqWc4L-ApdN-3ZiV8XS_LQPmc1lhGw,8050 +sympy/matrices/expressions/sets.py,sha256=3zase_rDn2QdaXETX78BgkfKiWcRC7FmwVjjIU-WmdY,2033 +sympy/matrices/expressions/slice.py,sha256=aNdY1Ey4VJR-UCvoORX2kh2DmA6QjOp-waENvWg8WVE,3355 +sympy/matrices/expressions/special.py,sha256=hywkygOQjcJEGQn2fG4dF8oUqXL3N42U1TxbdJr4-6E,7499 +sympy/matrices/expressions/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/matrices/expressions/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_special.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-39.pyc,, +sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-39.pyc,, +sympy/matrices/expressions/tests/test_adjoint.py,sha256=cxOc334yNSI9MazhG9HT8s1OCXjkDWr3Zj2JnyHS3Z4,1065 +sympy/matrices/expressions/tests/test_applyfunc.py,sha256=mxTJaoB4Ze50lk-2TgVopmrrbuQbEqUsZwc3K1H8w-Q,3522 +sympy/matrices/expressions/tests/test_blockmatrix.py,sha256=ANNR7e2eiGIvabMNezxRVzipmA8oUwmDTrBTz5ALMzU,16541 +sympy/matrices/expressions/tests/test_companion.py,sha256=Lam6r-cSOokjhSlJws55Kq-gL5_pHfeV_Xuvmn5PkRU,1657 +sympy/matrices/expressions/tests/test_derivatives.py,sha256=9mBeaAZDX7-JbYs6tMClNuGDygETVN_dCXSlHmyAhwg,15991 +sympy/matrices/expressions/tests/test_determinant.py,sha256=JSgptLz9KNC4_X27qnuFq-rscgHk6144s5TEUQpLxr0,2067 
+sympy/matrices/expressions/tests/test_diagonal.py,sha256=3L6Vs_Yr36a8dgIqAeIcNEf0xcVyeyGhANNu0dlIpwI,4516 +sympy/matrices/expressions/tests/test_dotproduct.py,sha256=Zkv2N6oRPm0-sN4PFwsVFrM5Y_qv4x2gWqQQQD86hBY,1171 +sympy/matrices/expressions/tests/test_factorizations.py,sha256=6UPA_UhCL5JPbaQCOatMnxhGnQ-aIHmb3lXqbwrSoIE,786 +sympy/matrices/expressions/tests/test_fourier.py,sha256=0eD69faoHXBcuQ7g2Q31fqs-gyR_Xfe-gv-7DXhJh_c,1638 +sympy/matrices/expressions/tests/test_funcmatrix.py,sha256=uN9r0ECMIBqsIzOezg_n9uDYNs6ebYS8Yf5yexUjmAM,2230 +sympy/matrices/expressions/tests/test_hadamard.py,sha256=rR0l1howrI8SaJOnLb0fsCXS5cIx1rzahwFTfGldp3Y,4614 +sympy/matrices/expressions/tests/test_indexing.py,sha256=wwYQa7LNlzhBA5fU50gPyE8cqaJf0s3O70PUx4eNCEA,12038 +sympy/matrices/expressions/tests/test_inverse.py,sha256=n4gwv-GH0LPXZDVgzEB0lA_fk8MmNFK_BVNJb0FEUfY,2320 +sympy/matrices/expressions/tests/test_kronecker.py,sha256=e5H6av3ioOn8jkjyDBrT3NEmCkyHbN6ZEHOlyB9OYLk,5366 +sympy/matrices/expressions/tests/test_matadd.py,sha256=U1fL5YLP_cYEOsdi2uaGGrzm8qOsKcXn69BC1UV6RMM,1866 +sympy/matrices/expressions/tests/test_matexpr.py,sha256=pQkxhi8okFDbF6Uea-dbeRrjNfmlc6nzJdIXOJxAqUI,18441 +sympy/matrices/expressions/tests/test_matmul.py,sha256=2ofS4YHHIcI9DamZNU-8RhBQTnNogqGknKxqkyKXO8U,6187 +sympy/matrices/expressions/tests/test_matpow.py,sha256=dRbwvZ3vxwnqw09lnRqOu475f46xDPNo7V3oc1d_P2U,7308 +sympy/matrices/expressions/tests/test_permutation.py,sha256=93Cqjj2k3aoR3ayMJLdJUa5h1u87bRRxT3I8B4FQsvU,5607 +sympy/matrices/expressions/tests/test_sets.py,sha256=DfFGe6W1ppUs6bgo3vB3DSJvFemrT68s0F3QbyoIJiE,1408 +sympy/matrices/expressions/tests/test_slice.py,sha256=C7OGAQQTz0YZxZCa7g0m8_0Bqq8jaPRa22JHVSqK7tY,2027 +sympy/matrices/expressions/tests/test_special.py,sha256=Mhg71vnjjb4fm0jZgjDoWW8rAJMBeh8aDCM75gjEpKQ,6496 +sympy/matrices/expressions/tests/test_trace.py,sha256=fRlrw9CfdO3z3SI4TQb1fCUb_zVAndbtyOErEeCTCQ0,3383 +sympy/matrices/expressions/tests/test_transpose.py,sha256=P3wPPRywKnrAppX6gssgD66v0RIcolxqDkCaKGGPVcM,1987 +sympy/matrices/expressions/trace.py,sha256=skr53LvstLV5Yg9hkaRb0yWrTdxC9u95G-YzIc80aTs,5362 +sympy/matrices/expressions/transpose.py,sha256=QGQ1bgqvYmRNs6QiVolhtFlbluPYpwW3UNvkRZUqUHU,2645 +sympy/matrices/graph.py,sha256=4UBv9SI5Z8Xjc5jPJVOoPVGQDtfbyOmS8f0ENrPOuU8,9080 +sympy/matrices/immutable.py,sha256=okpJZ41FnaHp1PpwnPNCnB_o3afAU7DgRsr2NqBKtvg,5530 +sympy/matrices/inverse.py,sha256=aVjDn_SjZUfi-jdM_1uJ6u6lVItuvbnb3CppayHQ-Gs,13166 +sympy/matrices/kind.py,sha256=EJxDdD4gFgvVfC3lRex-bczhcGjwBhglf1hPDk2WzXE,2843 +sympy/matrices/matrices.py,sha256=iqgi7x7cjnLT6OuRp6TYjoObJtUtbcPpeGTOC-Enh9Q,23536 +sympy/matrices/matrixbase.py,sha256=snpv-ZBFi-CRtfMwY9lukSTPpELVXmJF16Gn8xPV3yI,165154 +sympy/matrices/normalforms.py,sha256=mX19wvJqQAirRiPY-WdSfEdljFOhst18LcDX4MvPM8A,4636 +sympy/matrices/reductions.py,sha256=MZ8LtryawSm1jXsWzRPsYBmir867lRUWX4LFNGVIHwI,12500 +sympy/matrices/repmatrix.py,sha256=d2vne9qNn1cu7dQbCdkftMPadA87evGMaoE7LkxaYXE,29914 +sympy/matrices/solvers.py,sha256=kWkUdy0XoWbn7RY4S1N0qNgxKAcUrHiTlicb25rR9z4,25163 +sympy/matrices/sparse.py,sha256=ER4tfnpvAMeockzWi10mKdKZJSYJNPISyoUL_m29UJ8,14673 +sympy/matrices/sparsetools.py,sha256=tzI541P8QW_v1eVJAXgOlo_KK1Xp6u1geawX_tdlBxY,9182 +sympy/matrices/subspaces.py,sha256=uLo4qnP0xvFcFo5hhf6g7pHSHiRbcQ1ATDKwGBxW7CE,3761 +sympy/matrices/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/matrices/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_commonmatrix.cpython-39.pyc,, 
+sympy/matrices/tests/__pycache__/test_decompositions.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_determinant.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_domains.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_eigen.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_graph.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_immutable.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_interactions.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_matrices.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_matrixbase.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_reductions.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_repmatrix.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_solvers.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_sparse.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_sparsetools.cpython-39.pyc,, +sympy/matrices/tests/__pycache__/test_subspaces.cpython-39.pyc,, +sympy/matrices/tests/test_commonmatrix.py,sha256=CEYYa5frGDGPAVz78uYVTBC3IHoQLG18VRnq289XtBQ,40772 +sympy/matrices/tests/test_decompositions.py,sha256=LTm3KjUHh0QkGmR8UA1CwGWzpFu8o-qcl4sv2VKBi8g,14419 +sympy/matrices/tests/test_determinant.py,sha256=FxbLYbiO1wj12YYfsbQPoEOcgWpRGVu1CzVKLx4e8oQ,9560 +sympy/matrices/tests/test_domains.py,sha256=gUnMLr_LeEsN5unFjIwHCZje0URC9uVBN_Q76JbX9f4,3276 +sympy/matrices/tests/test_eigen.py,sha256=Xm-Xe7qGACLpqwOx5W2mDjve2KyJf-q8PyEi40jMgAk,22722 +sympy/matrices/tests/test_graph.py,sha256=ckfGDCg2M6gluv9XFnfURga8gxd2HTL7aX281s6wy6c,3213 +sympy/matrices/tests/test_immutable.py,sha256=JSu6YlGtPP-5iialCeatCKbZ4ScLDUzhQ-TMGhsalp8,4616 +sympy/matrices/tests/test_interactions.py,sha256=6T6wkHyTW5v2fwg0rz2HULoDElfA_NttApU2t-pZFKI,2070 +sympy/matrices/tests/test_matrices.py,sha256=YrwrczVsTuLoPAAkh4dlZYbfUcVEUe88wyM53fmUOtw,163196 +sympy/matrices/tests/test_matrixbase.py,sha256=2EUoz0Dl3hPFNvWU--ls6RPGfllZJbVvZSsDA9D6jPo,168868 +sympy/matrices/tests/test_normalforms.py,sha256=leFZ2rJGrWKv8Zz8aMaHQiNrCEQWOG1RqQ87omvlqIk,3730 +sympy/matrices/tests/test_reductions.py,sha256=tKv_KufpVc6qwH-MDz8XLpgVS4z6snnlNdd1ECAQSXM,13385 +sympy/matrices/tests/test_repmatrix.py,sha256=enu29IYEPzvKvGFeN6Bi9KYKqcXdy0EucifN9fV420Q,2084 +sympy/matrices/tests/test_solvers.py,sha256=55Zmvp2KC6OpdXFHBkXZRBRls_8wHfYNWO2UOQh9AoE,22311 +sympy/matrices/tests/test_sparse.py,sha256=GvXN6kBVldjqoR8WN8I_PjblKhRmyRWvVuLUgZEgugY,23281 +sympy/matrices/tests/test_sparsetools.py,sha256=pjQR6UaEMR92NolB_IGZ9Umk6FPZjvI0vk1Fd4H_C5I,4877 +sympy/matrices/tests/test_subspaces.py,sha256=poY6k6l2LSL7OCixQNGzrauLZIYbrjDul7J-yEE02S8,3465 +sympy/matrices/utilities.py,sha256=mMnNsDTxGKqiG0JATsM4W9b5jglhacy-vmRw2aZojgY,2117 +sympy/multipledispatch/__init__.py,sha256=aV2NC2cO_KmD6QFiwy4oC1D8fm3pFuPbaiTMeWmNWak,259 +sympy/multipledispatch/__pycache__/__init__.cpython-39.pyc,, +sympy/multipledispatch/__pycache__/conflict.cpython-39.pyc,, +sympy/multipledispatch/__pycache__/core.cpython-39.pyc,, +sympy/multipledispatch/__pycache__/dispatcher.cpython-39.pyc,, +sympy/multipledispatch/__pycache__/utils.cpython-39.pyc,, +sympy/multipledispatch/conflict.py,sha256=rR6tKn58MfhMMKZ4ZrhVduylXd9f5PjT2TpzM9LMB6o,2117 +sympy/multipledispatch/core.py,sha256=I4WOnmu1VtlaCnn2oD9R2-xckkYLRZPNFEWtCOTAYfM,2261 +sympy/multipledispatch/dispatcher.py,sha256=A2I4upt4qNollXGpwzrqg7M0oKHJhZx1BUMIBnjRIow,12226 +sympy/multipledispatch/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+sympy/multipledispatch/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/multipledispatch/tests/__pycache__/test_conflict.cpython-39.pyc,, +sympy/multipledispatch/tests/__pycache__/test_core.cpython-39.pyc,, +sympy/multipledispatch/tests/__pycache__/test_dispatcher.cpython-39.pyc,, +sympy/multipledispatch/tests/test_conflict.py,sha256=msNVSiikuPOqsEm_MMGmjsNbA2CAR0F1FZaHskzzo04,1786 +sympy/multipledispatch/tests/test_core.py,sha256=UfH_7cyvZ6PHjdH8vmLG49CG7E30W8uxm3FthuMc1Jk,4048 +sympy/multipledispatch/tests/test_dispatcher.py,sha256=saJPpGXLpLOuRfw-ekzZGzY-Rys0NsS5ke0n33i9j0U,6228 +sympy/multipledispatch/utils.py,sha256=39wB9i8jNhlLFZyCTFnioLx5N_CNWv4r5VZwKrxswIE,3097 +sympy/ntheory/__init__.py,sha256=WpYPsusKb9OQ1xmXxZFla7kXacIWv4lNIpFIEVVzTLA,2810 +sympy/ntheory/__pycache__/__init__.cpython-39.pyc,, +sympy/ntheory/__pycache__/bbp_pi.cpython-39.pyc,, +sympy/ntheory/__pycache__/continued_fraction.cpython-39.pyc,, +sympy/ntheory/__pycache__/digits.cpython-39.pyc,, +sympy/ntheory/__pycache__/ecm.cpython-39.pyc,, +sympy/ntheory/__pycache__/egyptian_fraction.cpython-39.pyc,, +sympy/ntheory/__pycache__/elliptic_curve.cpython-39.pyc,, +sympy/ntheory/__pycache__/factor_.cpython-39.pyc,, +sympy/ntheory/__pycache__/generate.cpython-39.pyc,, +sympy/ntheory/__pycache__/modular.cpython-39.pyc,, +sympy/ntheory/__pycache__/multinomial.cpython-39.pyc,, +sympy/ntheory/__pycache__/partitions_.cpython-39.pyc,, +sympy/ntheory/__pycache__/primetest.cpython-39.pyc,, +sympy/ntheory/__pycache__/qs.cpython-39.pyc,, +sympy/ntheory/__pycache__/residue_ntheory.cpython-39.pyc,, +sympy/ntheory/bbp_pi.py,sha256=ILur1c9Ja-1F_blgnInUx-WopQ_WSvK-2OvNLEe2Zx8,5998 +sympy/ntheory/continued_fraction.py,sha256=bQW7PvdgDtWnbpCmkOwyz3mNYPOXh9_ehq3_ZpJO8Rw,10717 +sympy/ntheory/digits.py,sha256=ea3xSLy8RMaMHsekg8nq2gnVTug7SOvaDjw0lmB6NMU,3831 +sympy/ntheory/ecm.py,sha256=zg_GhL0XSyG5ZVSdGv3YceCguPZrFaS-hC-vXqOVOGQ,11793 +sympy/ntheory/egyptian_fraction.py,sha256=hW886hPWJtARqgZIrH1WjZFC0uvf9CHxMIn0X9MWZro,6923 +sympy/ntheory/elliptic_curve.py,sha256=ZT677EHi26BkZdRTLs1Tf4VSyMxDzJrwk1xECoMT7Qg,11544 +sympy/ntheory/factor_.py,sha256=MQQJSnjSpz-B9ex9fqMCnrrNTfqFOlhJV9RAWy3gJTw,83551 +sympy/ntheory/generate.py,sha256=B6LVkK627xWzQ1UE5XoeMxNM7Q1O8EwcP7fwKs5vkZk,33376 +sympy/ntheory/modular.py,sha256=wYNfNr5S8DqipQNBLVMR7cPUNg7twM25xcLhXBD14I4,8474 +sympy/ntheory/multinomial.py,sha256=rbm3STjgfRbNVbcPeH69qtWktthSCk0sC373NuDM6fU,5073 +sympy/ntheory/partitions_.py,sha256=S81zrWeTVHmYAphcP6yUGLz5CrlUwW_JFtkQRoeLlV4,8995 +sympy/ntheory/primetest.py,sha256=9EVL1yeBkOHDfa2phHiDwUnmidSJJrTl6Wzo0y1Sck0,25423 +sympy/ntheory/qs.py,sha256=8ZKxvqc-oGPqZNbEH-XcPsjXLy5-5aygVB5rR1BkSO8,14938 +sympy/ntheory/residue_ntheory.py,sha256=dDBaogSqroJHlyXRqnLmuzNVyXfy4rWB_hlhOCNQHpk,54359 +sympy/ntheory/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/ntheory/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_digits.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_ecm.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_factor_.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_generate.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-39.pyc,, 
+sympy/ntheory/tests/__pycache__/test_modular.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_multinomial.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_partitions.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_primetest.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_qs.cpython-39.pyc,, +sympy/ntheory/tests/__pycache__/test_residue.cpython-39.pyc,, +sympy/ntheory/tests/test_bbp_pi.py,sha256=TgNpZOtKAfU-IVzO62Ko7Oy3LLNfdzEo9gWMTTDZL6c,9486 +sympy/ntheory/tests/test_continued_fraction.py,sha256=JJsyEXatgjNxYZTKkKubvymyVs0WDpLfbEiAY5SYk8g,3583 +sympy/ntheory/tests/test_digits.py,sha256=mrTfwboMCQkiOEpMgYg8Nrk12WE5pEtpmbQthRVt4Xc,1968 +sympy/ntheory/tests/test_ecm.py,sha256=yco77gknWe6co4VKTCoRNKHzd3jdqGKQWQFwuziYNWI,2290 +sympy/ntheory/tests/test_egyptian_fraction.py,sha256=tpHcwteuuQAahcPqvgBm4Mwq-efzcHOn8mldijynjlE,2378 +sympy/ntheory/tests/test_elliptic_curve.py,sha256=wc0EOsGo-qGpdevRq1o64htwTOT_YSUzUfyhJC-JVbg,624 +sympy/ntheory/tests/test_factor_.py,sha256=Z4B33C0blhcG9KuQzmJcDw7TtrTYJ6jAsWQd8Oi3YZw,26382 +sympy/ntheory/tests/test_generate.py,sha256=lhnJPjlz1TYrDHJ4Jq0F64P4KV8C7ngKmm3Jxtz7Wsk,9868 +sympy/ntheory/tests/test_hypothesis.py,sha256=Ztg-QoiBxpUp6euPy1RcPbF6yaLK_ij-Jcl637GGhNY,728 +sympy/ntheory/tests/test_modular.py,sha256=g73sUXtYNxzbDcq5UnMWT8NodAU8unwRj_E-PpvJqDs,1425 +sympy/ntheory/tests/test_multinomial.py,sha256=8uuj6XlatNyIILOpjJap13CMZmDwrCyGKn9LiIUiLV0,2344 +sympy/ntheory/tests/test_partitions.py,sha256=qkd-84AO0rpJv-MQW0lnXtIPtGTAfplFbu2ezKSfEQI,1088 +sympy/ntheory/tests/test_primetest.py,sha256=Lp4elAnnBDYSIPfVXENOUjZYuwAynsNQrE_CEiYDErU,9495 +sympy/ntheory/tests/test_qs.py,sha256=zbwQ1k5ywsUVNZzPewWul_yUnoqWxxvsSWOBAU5Albc,3956 +sympy/ntheory/tests/test_residue.py,sha256=7VGXsslZWVd3R7tAtIcLjYyh0TjZpb5BtG2GQehlAUY,16809 +sympy/parsing/__init__.py,sha256=KHuyDeHY1ifpVxT4aTOhomazCBYVIrKWd28jqp6YNJ8,125 +sympy/parsing/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/__pycache__/ast_parser.cpython-39.pyc,, +sympy/parsing/__pycache__/mathematica.cpython-39.pyc,, +sympy/parsing/__pycache__/maxima.cpython-39.pyc,, +sympy/parsing/__pycache__/sym_expr.cpython-39.pyc,, +sympy/parsing/__pycache__/sympy_parser.cpython-39.pyc,, +sympy/parsing/ast_parser.py,sha256=iJvr6bhm1RjM5rhWzZA4c4LGTH5lAFazN5zu8y8q-aY,2734 +sympy/parsing/autolev/Autolev.g4,sha256=980mo25mLWrQFmhRIg-aqIalUuwktYYaBGTXZ5_XZwA,4195 +sympy/parsing/autolev/__init__.py,sha256=sp5hzv5siVW3xUmhkp0S0iaA0Cz-PVB0HO1zC04pxYs,3611 +sympy/parsing/autolev/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-39.pyc,, +sympy/parsing/autolev/__pycache__/_listener_autolev_antlr.cpython-39.pyc,, +sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-39.pyc,, +sympy/parsing/autolev/_antlr/__init__.py,sha256=MQ4ZacpTuP-NmruFXKdWLQatoeVJQ8SaBQ2DnYvtyE8,203 +sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-39.pyc,, +sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-39.pyc,, +sympy/parsing/autolev/_antlr/__pycache__/autolevparser.cpython-39.pyc,, +sympy/parsing/autolev/_antlr/autolevlexer.py,sha256=K7HF_-5dUyAIv1_7GkhTmxqSCanEhCpzJG8fayAEB3Q,13609 +sympy/parsing/autolev/_antlr/autolevlistener.py,sha256=EDb3XkH9Y7CLzxGM-tY-nGqxMGfBHVkqKdVCPxABgRE,12821 +sympy/parsing/autolev/_antlr/autolevparser.py,sha256=BZYJ7IkurRmm44S50pYp_9JHCjT8fr1w5HeksAEPjtg,106291 
+sympy/parsing/autolev/_build_autolev_antlr.py,sha256=0oIrC3sWUDe18zydeNyRmQrGGwPtRYQzIul8YcmMpk4,2578 +sympy/parsing/autolev/_listener_autolev_antlr.py,sha256=jv1jYvmrGmom6F3fhYTHG1sqLK6FFCO1-jrvDW-nPEM,104760 +sympy/parsing/autolev/_parse_autolev_antlr.py,sha256=b9hIaluJUd1V2XIAp1erak6U-c-CwKyDLH1UkYQuvKE,1736 +sympy/parsing/autolev/test-examples/README.txt,sha256=0C4m_nLROeV5J8nMfm3RYEfYgQJqmlHZaCpVD24boQY,528 +sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-39.pyc,, +sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al,sha256=HpTcX2wXzLqmgpp8fcSqNweKjxljk43iYK0wQmBbCDI,690 +sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py,sha256=FSu4TP2BDTQjzYhMkcpRhXbb3kAD27XCyO_EoL55Ack,2274 +sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al,sha256=wjeeRdCS3Es6ldX9Ug5Du1uaijUTyoXpfTqmhL0uYfk,427 +sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py,sha256=uU9azTUGrY15BSDtw5T_V-7gmjyhHbXslzkmwBvFjGk,1583 +sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al,sha256=Gf7OhgRlwqUEXq7rkfbf89yWA23u4uIUJ-buXTyOuXM,505 +sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py,sha256=9ReCAqcUH5HYBgHmop9h5Zx54mfScWZN5L5F6rCHk4w,1366 +sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al,sha256=p5v40h1nVFrWNqnB0K7GiNQT0b-MqwayYjZxXOY4M8M,362 +sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py,sha256=DdxcWrm3HMQuyyY3Pk6sKHb4RXhQEM_EKY3HYZCP8ec,1503 +sympy/parsing/autolev/test-examples/ruletest1.al,sha256=mDJ02Q1Qm-ShVmGoyjzSfgDJHUOuDrsUg3YMnkpKdUw,176 +sympy/parsing/autolev/test-examples/ruletest1.py,sha256=eIKEFzEwkCFhPF0GTmf6SLuxXT384GqdCJnhiL2U0BQ,555 +sympy/parsing/autolev/test-examples/ruletest10.al,sha256=jKpV8BgX91iQsQDLFOJyaS396AyE5YQlUMxih5o9RK0,781 +sympy/parsing/autolev/test-examples/ruletest10.py,sha256=I1tsQcSAW6wqIguF-7lwlj9D4YZ8kCZqPqTKPUHR9oI,2726 +sympy/parsing/autolev/test-examples/ruletest11.al,sha256=j_q7giq2KIuXVRLWwNlwIlpbhNO6SqBMnLGLcxIkzwk,188 +sympy/parsing/autolev/test-examples/ruletest11.py,sha256=dYTRtXvMDXHiKzXHD2Sh0fcEukob3wr_GbSeqaZrrO8,475 +sympy/parsing/autolev/test-examples/ruletest12.al,sha256=drr2NLrK1ewn4FjMppXycpAUNbZEQ0IAMsdVx8nxk6I,185 
+sympy/parsing/autolev/test-examples/ruletest12.py,sha256=ZG36s3PnkT0aKBM9Nx6H0sdJrtoLwaebU9386YSUql8,472 +sympy/parsing/autolev/test-examples/ruletest2.al,sha256=d-QjPpW0lzugaGBg8F6pDl_5sZHOR_EDJ8EvWLcz4FY,237 +sympy/parsing/autolev/test-examples/ruletest2.py,sha256=jrJfb0Jk2FP4GS5pDa0UB5ph0ijEVd1X8meKeZrTVng,820 +sympy/parsing/autolev/test-examples/ruletest3.al,sha256=1TAaOe8GI8-yBWJddfIxwnvScHNmOjSzSaQn0RS_v5k,308 +sympy/parsing/autolev/test-examples/ruletest3.py,sha256=O3K3IQo-HCjAIOSkfz3bDlst7dVUiRwhOZ0q_3jb5LU,1574 +sympy/parsing/autolev/test-examples/ruletest4.al,sha256=qPGlPbdDRrzTDUBeWydAIa7mbjs2o3uX938QAsWJ7Qk,302 +sympy/parsing/autolev/test-examples/ruletest4.py,sha256=WHod5yzKF4TNbEf4Yfxmx9WnimA7NOXqtTjZXR8FsP0,682 +sympy/parsing/autolev/test-examples/ruletest5.al,sha256=VuiKjiFmLK3uEdho0m3pk-n0qm4SNLoLPMRJqjMJ4GY,516 +sympy/parsing/autolev/test-examples/ruletest5.py,sha256=WvUtno1D3BrmFNPYYIBKR_gOA-PaHoxLlSTNDX67dcQ,1991 +sympy/parsing/autolev/test-examples/ruletest6.al,sha256=-HwgTmh_6X3wHjo3PQi7378t8YdizRJClc5Eb5DmjhE,703 +sympy/parsing/autolev/test-examples/ruletest6.py,sha256=vEO0jMOD-KIevAcVexmpvac0MGjN7O_dNipOBJJNzF0,1473 +sympy/parsing/autolev/test-examples/ruletest7.al,sha256=wR9S9rTzO9fyKL6Ofgwzw8XCFCV_p2hBpYotC8TvADI,773 +sympy/parsing/autolev/test-examples/ruletest7.py,sha256=_XvMrMe5r9RLopTrIqMGLhaYvHL1qjteWz9CKcotCL8,1696 +sympy/parsing/autolev/test-examples/ruletest8.al,sha256=P7Nu3Pq2R1mKcuFRc9dRO5jJ1_e5fwWdtqYG8NHVVds,682 +sympy/parsing/autolev/test-examples/ruletest8.py,sha256=8tgbwJ-ir0wiOCsgIFCAu4uD8SieYRrLoLzEfae5YQY,2690 +sympy/parsing/autolev/test-examples/ruletest9.al,sha256=txtZ5RH2p1FvAe6etwetSCH8rLktnpk5z0W72sCOdAA,755 +sympy/parsing/autolev/test-examples/ruletest9.py,sha256=GtqV-Wq2GGJzfblMscAz-KXCzs0P_4XqvA3FIdlPe04,1965 +sympy/parsing/c/__init__.py,sha256=J9CvkNRY-qy6CA06GZYuwTuxdnqas6oUP2g0qLztGro,65 +sympy/parsing/c/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/c/__pycache__/c_parser.cpython-39.pyc,, +sympy/parsing/c/c_parser.py,sha256=fM1GGRIbjzrVQ7ErfkWz6Xzfjvym030BOoWlQLx37Oc,38124 +sympy/parsing/fortran/__init__.py,sha256=KraiVw2qxIgYeMRTFjs1vkMi-hqqDkxUBv8Rc2gwkCI,73 +sympy/parsing/fortran/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/fortran/__pycache__/fortran_parser.cpython-39.pyc,, +sympy/parsing/fortran/fortran_parser.py,sha256=RpNQR3eNx5vgfzdt0nEZDCB56kF__SnYMaqWN3zla00,11483 +sympy/parsing/latex/LICENSE.txt,sha256=AHvDClj6QKmW53IEcSDeTq8x9REOT5w7X5P8374urKE,1075 +sympy/parsing/latex/LaTeX.g4,sha256=fG0ZUQPwYQOIbcyaPDAkGvcfGs3ZwwMB8ZnKW5yHUDY,5821 +sympy/parsing/latex/__init__.py,sha256=CfRRwZZo2qMF2HksyjGer8qSwhx0T4wnXsa2HKxCvUE,8968 +sympy/parsing/latex/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-39.pyc,, +sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-39.pyc,, +sympy/parsing/latex/__pycache__/errors.cpython-39.pyc,, +sympy/parsing/latex/_antlr/__init__.py,sha256=TAb79senorEsoYLCLwUa8wg8AUCHzmmZ7tLdi0XGNaE,384 +sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-39.pyc,, +sympy/parsing/latex/_antlr/__pycache__/latexparser.cpython-39.pyc,, +sympy/parsing/latex/_antlr/latexlexer.py,sha256=Y1hmY1VGL5FTSSlToTRQydPnyaLLNy1mDSWx76HaYwM,30502 +sympy/parsing/latex/_antlr/latexparser.py,sha256=ZvonpvTS3vLSOVpas88M3CfNnUhPUDsCCPPk4wBYUGE,123655 +sympy/parsing/latex/_build_latex_antlr.py,sha256=3Mwip9f_yXUNhiwwUTCG5Nk8l2uzAw8oZHU0HSl5gkE,2765 
+sympy/parsing/latex/_parse_latex_antlr.py,sha256=sVaO04oSeHe_TaMeM-6toheCR88G_RmJYpUIx-Sef1g,20712 +sympy/parsing/latex/errors.py,sha256=adSpvQyWjTLsbN_2KHJ4HuXpY7_U9noeWiG0lskYLgE,45 +sympy/parsing/latex/lark/__init__.py,sha256=hhhvfKRGP3ON36wRJwVfxMWw_GA6rl0JKsIzjuaUX38,120 +sympy/parsing/latex/lark/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-39.pyc,, +sympy/parsing/latex/lark/__pycache__/transformer.cpython-39.pyc,, +sympy/parsing/latex/lark/grammar/greek_symbols.lark,sha256=-G8JGrBredhWAzCaurr1UmqgRMRrAJfs_pANub8kXyA,937 +sympy/parsing/latex/lark/grammar/latex.lark,sha256=yz3DXIHllvRDGHeBmpz49PzDCU92pKexg2RTKX4KhTE,13221 +sympy/parsing/latex/lark/latex_parser.py,sha256=lGFEo_RxSKXYmjVEb_E3DiZ-aXl5tAZyfHklzEuJ9VA,4430 +sympy/parsing/latex/lark/transformer.py,sha256=RuXe8enUnLrzUFC9ZfKQyQ1_IlgA-qPINNng6UnY1x4,25754 +sympy/parsing/mathematica.py,sha256=P2vWEgx_HXv4I_C-6vPE3RXtbpXNT2iIJQBrA6QfqkU,39614 +sympy/parsing/maxima.py,sha256=DhTnXRSAceijyA1OAm86c6TyW9-aeUVoZEELGu0oZtY,1835 +sympy/parsing/sym_expr.py,sha256=-hxarp961eyLtuwUhbg3D3qzy06HrEPZEYpGVcJzAv0,8895 +sympy/parsing/sympy_parser.py,sha256=PRRhNS_0LKBVcHFmg7LGygzxRyQt6-vyCSMisXkRhVE,43832 +sympy/parsing/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/parsing/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_ast_parser.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_autolev.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_c_parser.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_custom_latex.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_latex.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_latex_deps.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_latex_lark.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_mathematica.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_maxima.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_sym_expr.cpython-39.pyc,, +sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-39.pyc,, +sympy/parsing/tests/test_ast_parser.py,sha256=lcT8w7mn6UEZ8T-xfA4TqG4Mt7JxY00oHhOW7JtHQfY,803 +sympy/parsing/tests/test_autolev.py,sha256=tQuUFa8YqVdsHPOcUhAwlMKB8Uk08HejDhDCda8lXs0,6647 +sympy/parsing/tests/test_c_parser.py,sha256=VYl3K4if_23iIS-Be8MBSG0OKZo-6xgxHiN22laeAyo,155354 +sympy/parsing/tests/test_custom_latex.py,sha256=vX5uVHw9-UgEcRl0XVNOMgrjBbb1sXDlPLL4D8AeOiQ,2060 +sympy/parsing/tests/test_fortran_parser.py,sha256=SGbawrJ4a780TJAFVMONc7Y3Y8VYgVqsIHxVGaicbxE,11828 +sympy/parsing/tests/test_implicit_multiplication_application.py,sha256=xUlc9TKH4HjimhnvWMwCtTabHAJcHCupFdsuiOsZhUs,7389 +sympy/parsing/tests/test_latex.py,sha256=WvKNJ5mtxfzl-rBiE9hc2aEz81b5BmI3SKO7sRZiNbI,11765 +sympy/parsing/tests/test_latex_deps.py,sha256=oe5vm2eIKn05ZiCcXUaO8X6HCcRmN1qCuTsz6tB7Qrk,426 +sympy/parsing/tests/test_latex_lark.py,sha256=kXmZP-gMLwa5XAcNsd-rF_zrPok0RZuJsa6Ny6zfoLw,36059 +sympy/parsing/tests/test_mathematica.py,sha256=vAxwquc8ArTQE9UNbsO21FqSa6J17sCF-A4vhTPXLY0,13395 +sympy/parsing/tests/test_maxima.py,sha256=iIwnFm0lYD0-JcraUIymogqEMN3ji0c-0JeNFFGTEDs,1987 +sympy/parsing/tests/test_sym_expr.py,sha256=-wNR7GwvJHVmPSZxSuAuoX1_FJk83O0tcDi09qYY6Jk,5668 +sympy/parsing/tests/test_sympy_parser.py,sha256=H_dDL89ASex3sI8F92ceXIJy5Qd137rxTa20Rj-9l2A,12946 
+sympy/physics/__init__.py,sha256=F_yvUMCuBq3HR-3Ai6W4oktBsXRg8KdutFLwT9FFJlY,220 +sympy/physics/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/__pycache__/hydrogen.cpython-39.pyc,, +sympy/physics/__pycache__/matrices.cpython-39.pyc,, +sympy/physics/__pycache__/paulialgebra.cpython-39.pyc,, +sympy/physics/__pycache__/pring.cpython-39.pyc,, +sympy/physics/__pycache__/qho_1d.cpython-39.pyc,, +sympy/physics/__pycache__/secondquant.cpython-39.pyc,, +sympy/physics/__pycache__/sho.cpython-39.pyc,, +sympy/physics/__pycache__/wigner.cpython-39.pyc,, +sympy/physics/biomechanics/__init__.py,sha256=dG1IoRAnmfXvSyPciqJVrPn5LLnuvbVnBt78hBG0maQ,1520 +sympy/physics/biomechanics/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/biomechanics/__pycache__/_mixin.cpython-39.pyc,, +sympy/physics/biomechanics/__pycache__/activation.cpython-39.pyc,, +sympy/physics/biomechanics/__pycache__/curve.cpython-39.pyc,, +sympy/physics/biomechanics/__pycache__/musculotendon.cpython-39.pyc,, +sympy/physics/biomechanics/_mixin.py,sha256=0D3iBqlCRmR4HXKMxyC2LvYpKGHreOuZY5dXMJioQ4A,1493 +sympy/physics/biomechanics/activation.py,sha256=rUJegXrdO1LPHnzJpVRAYnA6DsvF46ndctkBt1JHDwI,25522 +sympy/physics/biomechanics/curve.py,sha256=PezKOUytiC15oLV11GC4M6u7ndXHpwtbilqdbIiqvgw,63154 +sympy/physics/biomechanics/musculotendon.py,sha256=LrYwFnwkShDWWZQtCgtSUqvc1H5IlTj_BNVKxfFNhE0,58274 +sympy/physics/biomechanics/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/biomechanics/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/biomechanics/tests/__pycache__/test_activation.cpython-39.pyc,, +sympy/physics/biomechanics/tests/__pycache__/test_curve.cpython-39.pyc,, +sympy/physics/biomechanics/tests/__pycache__/test_mixin.cpython-39.pyc,, +sympy/physics/biomechanics/tests/__pycache__/test_musculotendon.cpython-39.pyc,, +sympy/physics/biomechanics/tests/test_activation.py,sha256=hdnMsFBLjloJylu8-cLZ44oamORG3kSsI0q2-eLQ_-I,13395 +sympy/physics/biomechanics/tests/test_curve.py,sha256=3jbWEHr8ieZYCfVlwJ1KTA1WJ9dcTCyQ1p3r8G2jp1I,75800 +sympy/physics/biomechanics/tests/test_mixin.py,sha256=ds-EoUCvfiSjVGnC_mBwjn5mI7z5W5wi2UTZZ4-pIIQ,1322 +sympy/physics/biomechanics/tests/test_musculotendon.py,sha256=Ls59mtJQ83W0fdpDGFGNeTgtxL8yAZ6ODW96N9mvFtM,32906 +sympy/physics/continuum_mechanics/__init__.py,sha256=GzYypKq3so15xfvY4qFerHYCNGFzZthfQ1aOcgu8yCM,191 +sympy/physics/continuum_mechanics/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/continuum_mechanics/__pycache__/arch.cpython-39.pyc,, +sympy/physics/continuum_mechanics/__pycache__/beam.cpython-39.pyc,, +sympy/physics/continuum_mechanics/__pycache__/cable.cpython-39.pyc,, +sympy/physics/continuum_mechanics/__pycache__/truss.cpython-39.pyc,, +sympy/physics/continuum_mechanics/arch.py,sha256=TKaoHk7-ToNP0n6ZGFF_cW6x965C-D3iENiqFo74ZJU,39243 +sympy/physics/continuum_mechanics/beam.py,sha256=Vq1xJHrjatQEUigMsxlAk8qoZSyp_Z_0GiJ5y5AIMIg,159946 +sympy/physics/continuum_mechanics/cable.py,sha256=1IOpa_eLYwdyixNkDhLXJQSg4fQhwOH7Ybm1jWR4aQA,31420 +sympy/physics/continuum_mechanics/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/continuum_mechanics/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/continuum_mechanics/tests/__pycache__/test_arch.cpython-39.pyc,, +sympy/physics/continuum_mechanics/tests/__pycache__/test_beam.cpython-39.pyc,, +sympy/physics/continuum_mechanics/tests/__pycache__/test_cable.cpython-39.pyc,, +sympy/physics/continuum_mechanics/tests/__pycache__/test_truss.cpython-39.pyc,, 
+sympy/physics/continuum_mechanics/tests/test_arch.py,sha256=t_V-Pt0fxoPqYDZl5LB6ehQ4hBBVlveRWLsEXdyo_3k,2571 +sympy/physics/continuum_mechanics/tests/test_beam.py,sha256=Wmi8BCju21H1RaomKDqFATvCUqsrca_43o5IyA5s4ok,42085 +sympy/physics/continuum_mechanics/tests/test_cable.py,sha256=aANZVJxfihRkjAauCvFPGkXKD_zxMcJFb7vXU443dsQ,3832 +sympy/physics/continuum_mechanics/tests/test_truss.py,sha256=wgtF1GbQX5hzx20UrBg2ZyEvplVpsio3DiU8CS_bAk8,3269 +sympy/physics/continuum_mechanics/truss.py,sha256=hkLLnXA9aRaE2YvA0ods7T40zJQXUc_jLI108NjhtTI,44962 +sympy/physics/control/__init__.py,sha256=U0BeOmFRXOw-cpZ_F-RzNuEh78Tqzx31nuvZMZW2Fo8,1334 +sympy/physics/control/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/control/__pycache__/control_plots.cpython-39.pyc,, +sympy/physics/control/__pycache__/lti.cpython-39.pyc,, +sympy/physics/control/control_plots.py,sha256=lYa42ATTrSXMV44Q3POfEyYQCFSTV1W5Rb4UejYFdT8,37515 +sympy/physics/control/lti.py,sha256=o6n-qFXriBeETr4anCgOuYViUpQXDtBejckIdYdlxx0,177322 +sympy/physics/control/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/control/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/control/tests/__pycache__/test_control_plots.cpython-39.pyc,, +sympy/physics/control/tests/__pycache__/test_lti.cpython-39.pyc,, +sympy/physics/control/tests/test_control_plots.py,sha256=dsXkabVUMzU6b1FmptRz2sGgXfe6y_lu3Mm0KN8is90,16899 +sympy/physics/control/tests/test_lti.py,sha256=YAvwA11zSKRcHmgLPuCJhBJXbFrPfjUNytaQsnEsEr8,104043 +sympy/physics/hep/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/hep/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/hep/__pycache__/gamma_matrices.cpython-39.pyc,, +sympy/physics/hep/gamma_matrices.py,sha256=WlSHLUtMU7NrgLyKEvTntMSYxMZq1r_6o2kqUEAdPaA,24253 +sympy/physics/hep/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/hep/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/hep/tests/__pycache__/test_gamma_matrices.cpython-39.pyc,, +sympy/physics/hep/tests/test_gamma_matrices.py,sha256=iKqICj0bP7EK0sSuYFsPdPkDTbHGa6J_LMPZAzv1j4o,14722 +sympy/physics/hydrogen.py,sha256=R2wnNi1xB-WTQ8Z9aPUhX9Z8mQ8TdhCM1JAZIkyXgjw,7594 +sympy/physics/matrices.py,sha256=jHfbWkzL2myFt-39kodQo5wPubBxNZKXlljuSxZL4bE,3836 +sympy/physics/mechanics/__init__.py,sha256=7rgPJp1YItAMOc29gWqkRp68AKCdZLu4tFQG-CWgjOQ,2874 +sympy/physics/mechanics/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/actuator.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/body.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/body_base.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/functions.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/inertia.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/joint.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/jointsmethod.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/kane.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/lagrange.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/linearize.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/loads.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/method.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/models.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/particle.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/pathway.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/rigidbody.cpython-39.pyc,, +sympy/physics/mechanics/__pycache__/system.cpython-39.pyc,, 
+sympy/physics/mechanics/__pycache__/wrapping_geometry.cpython-39.pyc,, +sympy/physics/mechanics/actuator.py,sha256=d7Ffb9vo1Tmrcfiy71_g0r1btscfvvabE-vznku3Auk,43633 +sympy/physics/mechanics/body.py,sha256=Z9ZOReEab9FHEDsyv2RfIBQM4_ripXhFXmQtu1OmN_Y,24617 +sympy/physics/mechanics/body_base.py,sha256=bwP04lWmD0iY_T0Vsn6NWMbnWyzUMFFqkAoAyKsow_c,2491 +sympy/physics/mechanics/functions.py,sha256=KsX3z33rZdMCifwyVlyKIHVJh0Fve--zs5wd1rkGMl0,25190 +sympy/physics/mechanics/inertia.py,sha256=FKcEbpYPx_26LS9sSp8lcsJ83spqwSYIyxM5q_z1mYA,6172 +sympy/physics/mechanics/joint.py,sha256=o9SWJ9pXDO-sVh0gJ6VTI0LgY019gFB5w5FP1PN--UI,84621 +sympy/physics/mechanics/jointsmethod.py,sha256=nqkXawtuxeyP0D8DEAYURCJlDYS4Eka_vXx6UZP7y74,10415 +sympy/physics/mechanics/kane.py,sha256=XyvEDNF09i5DnHf-YlCecEHDdKK9j486N2P5YD8_XaM,37118 +sympy/physics/mechanics/lagrange.py,sha256=UTmClOP-PkY8S-LcPlcOAUmshndlq6mpmZailbvy4E4,20202 +sympy/physics/mechanics/linearize.py,sha256=YtAap6TzsZJd794cdDCEDJVlMOXd_wFAMaCZZacHwbE,17242 +sympy/physics/mechanics/loads.py,sha256=jnajZOg631Aqtd0-BplehohUL991C3Cji2OZZ3MVHdk,5406 +sympy/physics/mechanics/method.py,sha256=2vFRhA79ra4HR6AzVBHMr3oNncrcqgLLMRqdyif0DrI,660 +sympy/physics/mechanics/models.py,sha256=9q1g3I2xYpuTMi-v9geswEqxJWTP3RjcOquRfzMhHzM,6463 +sympy/physics/mechanics/particle.py,sha256=YKiEBkPLVRI9foXlEe8eu8Ys5W1GyQO1oOBf0T8fHFg,5985 +sympy/physics/mechanics/pathway.py,sha256=2T71x06X8CZ20NUyXkMnP83gKsx2wJJEbPJS-1gMXFA,26555 +sympy/physics/mechanics/rigidbody.py,sha256=s5b7ynsdcMsq9PN2nM1yozNcZd08RA-MgzMjcs7jnBI,10287 +sympy/physics/mechanics/system.py,sha256=VsPPQfJs0gBLKJYNQ2ZLmD98JWLd1IZy0QGRwI21_TM,59457 +sympy/physics/mechanics/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/mechanics/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_actuator.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_body.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_functions.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_inertia.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_joint.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_jointsmethod.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_kane.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_kane2.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_kane3.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_kane4.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_kane5.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_lagrange.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_lagrange2.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_linearity_of_velocity_constraints.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_linearize.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_loads.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_method.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_models.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_particle.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_pathway.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_rigidbody.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_system.cpython-39.pyc,, +sympy/physics/mechanics/tests/__pycache__/test_system_class.cpython-39.pyc,, 
+sympy/physics/mechanics/tests/__pycache__/test_wrapping_geometry.cpython-39.pyc,, +sympy/physics/mechanics/tests/test_actuator.py,sha256=WoTnD_IGhewVOzcL7ybJXf_mrK0_E71gDcnon-3Mayg,41081 +sympy/physics/mechanics/tests/test_body.py,sha256=ygHfWeL6-3LW9gqXHfzdvqQhJWxg2-U2KQIHKqGNqKQ,12067 +sympy/physics/mechanics/tests/test_functions.py,sha256=Cl1aT3qT-0Ik6vJgS6at6PUaeWRQfpEG88QvEIxz1Bk,10447 +sympy/physics/mechanics/tests/test_inertia.py,sha256=On4K78tmq4tyr_QFs3Uo01efdlVj5R6yfn6CVs7ksnE,2726 +sympy/physics/mechanics/tests/test_joint.py,sha256=ojhQdBsWN6nR7egEhpw-LAyDRdFIPfSEsDoFL1bhzqs,57922 +sympy/physics/mechanics/tests/test_jointsmethod.py,sha256=8IUO7ntLPsydDT4YgA1FPyO5TXNpo4CIWC7O_JdWg9Q,10226 +sympy/physics/mechanics/tests/test_kane.py,sha256=K1HOm7cuPzXeO9ZmW7oAHgu-jxOg1IFqyKB3YheAbus,21545 +sympy/physics/mechanics/tests/test_kane2.py,sha256=qNLkbSV6UYclQzMDoRi_TJrJEz0-MUi7yf1e7XW4ba4,19256 +sympy/physics/mechanics/tests/test_kane3.py,sha256=-9MbXLfYUDhA_9D2o4NrNLgDTOnKDLLsgryQW3AHacs,14959 +sympy/physics/mechanics/tests/test_kane4.py,sha256=qJYUfvnq1F5UN_AQqTu_3BT5XqGceqxnCyP3d2gE04A,4709 +sympy/physics/mechanics/tests/test_kane5.py,sha256=gZvAyxJ8fkkLtH60xHlb_Gxrbt9d7VmJTckXSGmF0j4,5693 +sympy/physics/mechanics/tests/test_lagrange.py,sha256=iuHomulBF8MafLeorKGaLHUEF8CvFhXcxEtN0hk1akM,10119 +sympy/physics/mechanics/tests/test_lagrange2.py,sha256=VBonbWg0uLduE4AwNUWCsGGa_2KmaDbg9jROZL3DTWE,1402 +sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py,sha256=Y9JYoaob8yyL7PS-VN7g-DC8YEfCxBGdFSSkQAKlOBI,1341 +sympy/physics/mechanics/tests/test_linearize.py,sha256=6yFFGEhJW60fx9Ny1duc6eyvGDg6rtmJVo_V1mgHgGk,13273 +sympy/physics/mechanics/tests/test_loads.py,sha256=Kw94kP0tfwKsV-jCDHGTQsyc-1dKQl3ABJfqJtR8AJg,2698 +sympy/physics/mechanics/tests/test_method.py,sha256=L7CnsvbQC-U7ijbSZdu7DEr03p88OLj4IPvFJ_3kCDo,154 +sympy/physics/mechanics/tests/test_models.py,sha256=GcsfCm5G4PPYQXsHCiAKI1dEW42RaZOh-x6aEouTYo4,5078 +sympy/physics/mechanics/tests/test_particle.py,sha256=JL6QAA6T3POQkSutUnungrVkR3xt6ZVX-hp75-EufQw,2682 +sympy/physics/mechanics/tests/test_pathway.py,sha256=oGCxlUOviyNc1GBMvhk5sYVZfu8C4o7lJMbqatBie3A,24944 +sympy/physics/mechanics/tests/test_rigidbody.py,sha256=ezMW5BWt9cWdNeY1B9aYcL4NsPcVkaKZuUS1C7S1qPk,6188 +sympy/physics/mechanics/tests/test_system.py,sha256=Dihu77qM5_QkDQg-zavHbVhh_nvaGEVztXgPNl2_enk,8700 +sympy/physics/mechanics/tests/test_system_class.py,sha256=Xe4VLrxWaYL3oRwP2SBaanWf5DY46d2yPlwf9H1XJ4M,38219 +sympy/physics/mechanics/tests/test_wrapping_geometry.py,sha256=aXwuEaprstnSW7BwLr3OKUyxSagRPO58L-tUMaq3I9s,10502 +sympy/physics/mechanics/wrapping_geometry.py,sha256=XkI331oQcBme4jDmH0RqQDWh_pq49o38nbRsXwT9O5E,21599 +sympy/physics/optics/__init__.py,sha256=0UmqIt2-u8WwNkAqsnOVt9VlkB9K0CRIJYiQaltJ73w,1647 +sympy/physics/optics/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/optics/__pycache__/gaussopt.cpython-39.pyc,, +sympy/physics/optics/__pycache__/medium.cpython-39.pyc,, +sympy/physics/optics/__pycache__/polarization.cpython-39.pyc,, +sympy/physics/optics/__pycache__/utils.cpython-39.pyc,, +sympy/physics/optics/__pycache__/waves.cpython-39.pyc,, +sympy/physics/optics/gaussopt.py,sha256=mVQ-JX7xmAp9XbNOYIlwsPAxkUukTw_QjbjIxuKWZW8,20898 +sympy/physics/optics/medium.py,sha256=cys0tWGi1VCPWMTZuKadcN_bToz_bqKsDHSEVzuV3CE,7124 +sympy/physics/optics/polarization.py,sha256=mIrZiOVXetGtKkLxl8Llaf2Z9coWenf6JKrClh4W8yU,21434 +sympy/physics/optics/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+sympy/physics/optics/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-39.pyc,, +sympy/physics/optics/tests/__pycache__/test_medium.cpython-39.pyc,, +sympy/physics/optics/tests/__pycache__/test_polarization.cpython-39.pyc,, +sympy/physics/optics/tests/__pycache__/test_utils.cpython-39.pyc,, +sympy/physics/optics/tests/__pycache__/test_waves.cpython-39.pyc,, +sympy/physics/optics/tests/test_gaussopt.py,sha256=QMXJw_6mFCC3918b-pc_4b_zgO8Hsk7_SBvMupbEi5I,4222 +sympy/physics/optics/tests/test_medium.py,sha256=RxG7N3lzmCO_8hIoKyPnDKffmk8QFzA9yamu1_mr_dE,2194 +sympy/physics/optics/tests/test_polarization.py,sha256=81MzyA29HZckg_Ss-88-5o0g9augDqCr_LwcJIiXuA0,2605 +sympy/physics/optics/tests/test_utils.py,sha256=SjicjAptcZGwuX-ib_Lq7PlGONotvo2XJ4p3JA9iNVI,8553 +sympy/physics/optics/tests/test_waves.py,sha256=PeFfrl7MBkWBHdc796sDDYDuhGepat3DQk7PmyTXVnw,3397 +sympy/physics/optics/utils.py,sha256=BqfuvtrjO3PEcDQ1DecNyt2Th9Yps6xued1tEY4ysvk,22172 +sympy/physics/optics/waves.py,sha256=Iw-9gGksvWhPmQ_VepmI90ekKyzHdPlq6U41wdM4ikI,10042 +sympy/physics/paulialgebra.py,sha256=1r_qDBbVyl836qIXlVDdoF89Z9wedGvWIkHAbwQaK-4,6002 +sympy/physics/pring.py,sha256=SCMGGIcEhVoD7dwhY7_NWL1iKwo7OfgKdmm2Ok_9Xl0,2240 +sympy/physics/qho_1d.py,sha256=ZXemUsa_b0rLtPVTUkgAkZQ1Ecu2eIZxaiNSSXW0PDk,2005 +sympy/physics/quantum/__init__.py,sha256=vzWc2d2aXbMdZGJRx4tyS6FDCtjbtl7SqCJE7nSXN9A,1952 +sympy/physics/quantum/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/anticommutator.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/boson.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/cartesian.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/cg.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/circuitplot.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/circuitutils.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/commutator.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/constants.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/dagger.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/density.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/fermion.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/gate.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/grover.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/hilbert.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/identitysearch.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/innerproduct.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/kind.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/matrixcache.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/matrixutils.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/operator.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/operatorordering.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/operatorset.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/pauli.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/piab.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/qapply.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/qasm.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/qexpr.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/qft.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/qubit.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/represent.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/sho1d.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/shor.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/spin.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/state.cpython-39.pyc,, 
+sympy/physics/quantum/__pycache__/tensorproduct.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/trace.cpython-39.pyc,, +sympy/physics/quantum/__pycache__/transforms.cpython-39.pyc,, +sympy/physics/quantum/anticommutator.py,sha256=kXr7DWZZy0Sqcnxx-QEeA58mUCRXTL_VHDzQBjNV0DI,5072 +sympy/physics/quantum/boson.py,sha256=Nn4TN917TQv4tdFNZddlT1eEGg7O1uybnU9K5NI6AyM,5834 +sympy/physics/quantum/cartesian.py,sha256=9R9VDYLV1Xe-GkA9TQbj8PVlBLaD0fF6KXfHJ1ze5as,9092 +sympy/physics/quantum/cg.py,sha256=WK7HkAIRFejQQLjRsCC7rH0L--0fmXAoeL1JdTHb3GA,23319 +sympy/physics/quantum/circuitplot.py,sha256=SacQMhPyDhizKmGRNEs1vtXph8lR6bMn5bVJI4rJiXg,11799 +sympy/physics/quantum/circuitutils.py,sha256=mrQNUDbwM3LV1NZ1EqVpXyOY2mOXCBVZW7cQTiCxUaM,13882 +sympy/physics/quantum/commutator.py,sha256=MKk8OqwIi8LCzxkD8URmKCFM9OthAAxS8xFNgP_i09U,8139 +sympy/physics/quantum/constants.py,sha256=20VRATCkSprSnGFR5ejvMEYlWwEcv1B-dE3RPqPTQ9k,1420 +sympy/physics/quantum/dagger.py,sha256=P58FL8fLFfMhtu79jGsG69oEMjLjutRDrXsLwBat-9U,2484 +sympy/physics/quantum/density.py,sha256=ie_NJO-YNJ3PydAMp3jlmhC5H8bN3ldhn2SiCKY5N0s,9546 +sympy/physics/quantum/fermion.py,sha256=1ipn3FItUJ_ruLnflpp9MN_6t5w8CgHAJRJCOsukGGI,4983 +sympy/physics/quantum/gate.py,sha256=Iv7-qhSCe_An9qaJcYRDgwr8ClNraP47E75UlS5fCoQ,42588 +sympy/physics/quantum/grover.py,sha256=17KC5MJR3cCavmzqRqi9dB5OFTOpsYjfrTZuv03HiuE,10452 +sympy/physics/quantum/hilbert.py,sha256=qrja92vF7BUeSyHOLKVX8-XKcPGT7QaQMWrqWXjRNus,19632 +sympy/physics/quantum/identitysearch.py,sha256=Zh_ji5J0YeAy2AezsQcHV9W2icWoaa3ZwTbfjCCQmJo,27607 +sympy/physics/quantum/innerproduct.py,sha256=NKlJufh068C3MeKBORJG8k7zdH_vGJeo7jlH67PqU7E,4303 +sympy/physics/quantum/kind.py,sha256=odJyriPZKyievddvlY-kks8iBCSi1yZ7JEsaLANHoe4,2831 +sympy/physics/quantum/matrixcache.py,sha256=S6fPkkYmfX8ELBOc9EST-8XnQ1gtpSOBfd2KwLGKdYo,3587 +sympy/physics/quantum/matrixutils.py,sha256=53VGeUNYNpU1eA1sNcYFaX1_45I9SVabCfpTyHUDFog,8214 +sympy/physics/quantum/operator.py,sha256=Q-uiUjAZgMOXHKoBFe2oW_qJsJSjC8fM246yzxMmZ0A,19657 +sympy/physics/quantum/operatorordering.py,sha256=byAyZCNKTCeFWIFThmNx0NgdI4u32O4ydodYSa6Wrr8,10296 +sympy/physics/quantum/operatorset.py,sha256=h8nkScpQcUzCO3zemqKpgQfJDWiBbfj33IJzcl4J2_4,9563 +sympy/physics/quantum/pauli.py,sha256=lzxWFHXqxKWRiYK99QCo9zuVG9eVXiB8vFya7TvrVxQ,17250 +sympy/physics/quantum/piab.py,sha256=Zjb2cRGniVDV6e35gjP4uEpI4w0C7YGQIEXReaq_z-E,1912 +sympy/physics/quantum/qapply.py,sha256=5vqWe__WRvJuD6sWC2NTDe1qEcToBGTGaye9lzirgf4,9494 +sympy/physics/quantum/qasm.py,sha256=UWpcUIBgkK55SmEBZlpmz-1KGHZvW7dNeSVG8tHr44A,6288 +sympy/physics/quantum/qexpr.py,sha256=N87sU5AqFbGEXgHJof-B7ua46rk4SeysK16abU1GfYI,13940 +sympy/physics/quantum/qft.py,sha256=ua4qBQAm-gi923lRRxOgAebTsTCoR93pz8ZHPXBdcus,6425 +sympy/physics/quantum/qubit.py,sha256=mE0xebz8zqeKEc4xww9-Iv7tZzxHjm9Klsr2WgUQcVM,25997 +sympy/physics/quantum/represent.py,sha256=dBY9RLD21V0WuIBbigagjaLSUCiloDzitv8Wl6fU9dE,18729 +sympy/physics/quantum/sho1d.py,sha256=ZroR_FjxmjOmDcd0Fm04vWKTGCpvLaEu4NiuplKm708,20867 +sympy/physics/quantum/shor.py,sha256=DVwPxLAPSr8t3F3aXJIPe4o5XSuQiE6a6eA6OYmdZFw,5504 +sympy/physics/quantum/spin.py,sha256=l6168fTEZEC4FeaMS8iLqo3b4YMvUlPTZvq9N8iSV08,72986 +sympy/physics/quantum/state.py,sha256=618YaDGAtpUeuWCVYs6lDuO3mD87_f8KIUGUXaN9L4s,29971 +sympy/physics/quantum/tensorproduct.py,sha256=PvpDSuLku1IdIV3T2dK3fVv6AdMNVcc5eR8n34AB3OQ,12599 +sympy/physics/quantum/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/quantum/tests/__pycache__/__init__.cpython-39.pyc,, 
+sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_boson.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_cg.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_constants.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_density.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_gate.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_grover.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_kind.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_operator.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_piab.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_printing.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_qft.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_represent.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_shor.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_spin.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_state.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_trace.cpython-39.pyc,, +sympy/physics/quantum/tests/__pycache__/test_transforms.cpython-39.pyc,, +sympy/physics/quantum/tests/test_anticommutator.py,sha256=ckWHKwQFiAMWcDaYSa_26vi_GIsvs32_0O62I5lGsr8,1304 +sympy/physics/quantum/tests/test_boson.py,sha256=BZjdrZ-F1QhyhDqfK4Zc1VEFBJi1PeiPjMpfBcHekfo,1676 +sympy/physics/quantum/tests/test_cartesian.py,sha256=4szaB4w555oDUkhaRtBF8mOfIxiODNCqOlDHEVcUFOo,4417 +sympy/physics/quantum/tests/test_cg.py,sha256=BF-2ybLhoAYOb0wfWFlgnobMzH20zTYJvW5Z7v46SYI,9159 +sympy/physics/quantum/tests/test_circuitplot.py,sha256=c3v9wUzLHUH-eBVGj6_broVhHkioNwpaaApTDAJEflU,2096 +sympy/physics/quantum/tests/test_circuitutils.py,sha256=GrJAWRQVH_l8EIHrj1ve2jtxske72IriQ3lo94fqrVQ,13187 +sympy/physics/quantum/tests/test_commutator.py,sha256=keBstGDpNITFRr06uVFrka_Lje56g6oFoJQEpZXmnYw,2727 +sympy/physics/quantum/tests/test_constants.py,sha256=KBmYPIF49Sq34lbzbFCZRYWSyIdhnR3AK3q-VbU6grU,338 
+sympy/physics/quantum/tests/test_dagger.py,sha256=pi2lhL5EaE8FuIuwM9gOce2NpDCP4S1JptqI-mx8FGM,2632 +sympy/physics/quantum/tests/test_density.py,sha256=EyxiEgyc0nDSweJwI0JUwta7gZ81TVHCl7YDEosTrvI,9718 +sympy/physics/quantum/tests/test_fermion.py,sha256=RK-J3nV1UO_9R5UyrBIp_qfWX-5iZ152aoyEllKWSIc,1636 +sympy/physics/quantum/tests/test_gate.py,sha256=7oBX1HoWnrYtHjABRoqv_wQDB9B829E99fdcJzaqawM,12496 +sympy/physics/quantum/tests/test_grover.py,sha256=uze62AG6H4x2MYJJA-EY3NtkqwvrDIQ2kONuvIRQiZ4,3640 +sympy/physics/quantum/tests/test_hilbert.py,sha256=IGP6rc2-b3we9dRDbpRniFAhQwp_TYtMfFzxusAprx0,2643 +sympy/physics/quantum/tests/test_identitysearch.py,sha256=3YGrXCsFLhLtN5MRyT5ZF8ELrSdkvDKTv6xKM4i2ims,17745 +sympy/physics/quantum/tests/test_innerproduct.py,sha256=37tT8p6MhHjAYeoay1Zyv7gCs-DeZQi4VdwUH2IffDE,1483 +sympy/physics/quantum/tests/test_kind.py,sha256=MCc0jobaf-fZmi99ir2VpZM1Lah5w9ZMxqygVQZfleU,2515 +sympy/physics/quantum/tests/test_matrixutils.py,sha256=wXc2jYLYaR07MgbVBpUU0cEsy2YQhQygbHV6M55u_ss,4115 +sympy/physics/quantum/tests/test_operator.py,sha256=kmNGQzjtdsko6PlUNcTT0kjOYX73LOMzHGuncS2SjSw,8417 +sympy/physics/quantum/tests/test_operatorordering.py,sha256=SFvJfrBxreMgMB3PEpXGcTvO_113Pi1O-Jco-A9_aVI,2003 +sympy/physics/quantum/tests/test_operatorset.py,sha256=DNfBeYBa_58kSG7PM5Ilo6xnzek8lSiAGX01uMFRYqI,2628 +sympy/physics/quantum/tests/test_pauli.py,sha256=Bhsx_gj5cpYv4BhVJRQohxlKk_rcp4jHtSRlTP-m_xs,4940 +sympy/physics/quantum/tests/test_piab.py,sha256=8ndnzyIsjF4AOu_9k6Yqap_1XUDTbiGnv7onJdrZBWA,1086 +sympy/physics/quantum/tests/test_printing.py,sha256=Qg34sHUAgKeh7DhHNp3KQX0_bhfdkMPCKJUvL08ufz8,30332 +sympy/physics/quantum/tests/test_qapply.py,sha256=ol5IJuuPand4r-P2cBv_Wycv2wnkMsW3dlJxDAp_PSs,6086 +sympy/physics/quantum/tests/test_qasm.py,sha256=ZvMjiheWBceSmIM9LHOL5fiFUl6HsUo8puqdzywrhkc,2976 +sympy/physics/quantum/tests/test_qexpr.py,sha256=0a80r2d31yxuJn11CcemrMn855o2ge5CH8j73ZCwuKI,1830 +sympy/physics/quantum/tests/test_qft.py,sha256=v-sGTaW9S-gcGTDAUPvjwd1kINF6rlI_u5Sf-Gso0r8,1931 +sympy/physics/quantum/tests/test_qubit.py,sha256=rj5RNjLWjZ8514fQwEsVp0wuUlI1SKAIomL0wZ-9PqA,9271 +sympy/physics/quantum/tests/test_represent.py,sha256=lEwzpL0fGxDGkojZ4_WoBAtCcA7aq2-S-i0Z0QrnTXg,5177 +sympy/physics/quantum/tests/test_sho1d.py,sha256=1QCoRXwhcK-P1oUrREgAroPHJPzPgWEBSrHomjS3meA,6893 +sympy/physics/quantum/tests/test_shor.py,sha256=3a3GCg6V5_mlJ2bltoXinGMGvlSxpq7GluapD_3SZaQ,666 +sympy/physics/quantum/tests/test_spin.py,sha256=p6F0cfQkPYwXlzE29mTL-dpsEYZmFU0EfAwm2MUA9xE,345698 +sympy/physics/quantum/tests/test_state.py,sha256=pjNwjEQN1zMXzC6NbL60ZLzpqBfZD5RscjuyiFuhTaE,6748 +sympy/physics/quantum/tests/test_tensorproduct.py,sha256=CL8xJyUl7plFU7TYsHHkLDYt9PtS4SM_H-SG4lKM5aE,5314 +sympy/physics/quantum/tests/test_trace.py,sha256=dbpTXcJArWRR_Hh5JTuy2GJIfgjVo6zS20o5mdVEGH4,3057 +sympy/physics/quantum/tests/test_transforms.py,sha256=w1fLTJO07CWPOQp5jFVU83MatsaoNBkxOI_YbM6I1w8,2346 +sympy/physics/quantum/trace.py,sha256=2ZqN9IEsz3LKHTLV8ZDwTK0sM5PfwL0p2sYet0N7Gis,6397 +sympy/physics/quantum/transforms.py,sha256=XLnWmyoWW8oXJ8I7dUIF7cMxFt9xr736HfNstD_tnK0,10929 +sympy/physics/secondquant.py,sha256=1XDsbnUpFC7zmJ4VC9DXEsK_smSqNMaDfsi6slIkws0,90974 +sympy/physics/sho.py,sha256=K8P9FAdZr6UfQKYZO9TlhDUqUd3YsMekXCsKy2HhaY0,2480 +sympy/physics/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_clebsch_gordan.cpython-39.pyc,, 
+sympy/physics/tests/__pycache__/test_hydrogen.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_paulialgebra.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_physics_matrices.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_pring.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_qho_1d.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_secondquant.cpython-39.pyc,, +sympy/physics/tests/__pycache__/test_sho.cpython-39.pyc,, +sympy/physics/tests/test_clebsch_gordan.py,sha256=gHMZnYkigoVWeDV7bwD4YARplb_L1AOMppzcLxSmWnI,9921 +sympy/physics/tests/test_hydrogen.py,sha256=kohRIR6JojE_GWYnlzLsMMgdhoKd8whazs0mq7cCTQc,4987 +sympy/physics/tests/test_paulialgebra.py,sha256=tyshEMsLNPR4iYzoAbPGZRZ-e_8t7GDP_xyjRyhepeQ,1477 +sympy/physics/tests/test_physics_matrices.py,sha256=Dha8iQRhzxLcl7TKSA6QP0pnEcBoqtj_Ob6tx01SMwI,2948 +sympy/physics/tests/test_pring.py,sha256=XScQQO9RhRrlqSII_ZyyOUpE-zs-7wphSFCZq2OuFnE,1261 +sympy/physics/tests/test_qho_1d.py,sha256=LD9WU-Y5lW7bVM7MyCkSGW9MU2FZhVjMB5Zk848_q1M,1775 +sympy/physics/tests/test_secondquant.py,sha256=vI9EuCMd1PMYNbZH2bHW9xHPYS5DS92QOmAk1kC12cQ,49401 +sympy/physics/tests/test_sho.py,sha256=aIs1f3eo6hb4ErRU8xrr_h_yhTmRx-fQgv9n27SfsLM,693 +sympy/physics/units/__init__.py,sha256=DVvWy9qNRm742NFGcBpybFY20ZK3BU7DWNbLMTXYiFo,12386 +sympy/physics/units/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/units/__pycache__/dimensions.cpython-39.pyc,, +sympy/physics/units/__pycache__/prefixes.cpython-39.pyc,, +sympy/physics/units/__pycache__/quantities.cpython-39.pyc,, +sympy/physics/units/__pycache__/unitsystem.cpython-39.pyc,, +sympy/physics/units/__pycache__/util.cpython-39.pyc,, +sympy/physics/units/definitions/__init__.py,sha256=F3RyZc1AjM2Ch5b27Tt-VYdZ1HAIWvhgtQQQTfMiN6w,7470 +sympy/physics/units/definitions/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/units/definitions/__pycache__/dimension_definitions.cpython-39.pyc,, +sympy/physics/units/definitions/__pycache__/unit_definitions.cpython-39.pyc,, +sympy/physics/units/definitions/dimension_definitions.py,sha256=QmPLCy7GvLx1z_YEZ5KFl8QUXg4rcuu3PuNI5wLpPdY,1633 +sympy/physics/units/definitions/unit_definitions.py,sha256=05wpHmAtyQvuJBeuzWm3cDQ6UYviNtsi4kVc0hv8VHw,14680 +sympy/physics/units/dimensions.py,sha256=GMuoSF1vqKDs_e3JJ7Y0lpFrneXMZLZk_qIC5d6Q5UU,20899 +sympy/physics/units/prefixes.py,sha256=_q2f8gA-kckBG7TutTFQazTf15PCZqNnaTR1gKXRfsk,6260 +sympy/physics/units/quantities.py,sha256=r5E231CULmsSEM7Rh7zfcTPuR85_X0CwRCVU_nDsek0,4671 +sympy/physics/units/systems/__init__.py,sha256=jJuvdc15c83yl11IuvhyjijwOZ9m1JGgZOgKwKv2e2o,244 +sympy/physics/units/systems/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/units/systems/__pycache__/cgs.cpython-39.pyc,, +sympy/physics/units/systems/__pycache__/length_weight_time.cpython-39.pyc,, +sympy/physics/units/systems/__pycache__/mks.cpython-39.pyc,, +sympy/physics/units/systems/__pycache__/mksa.cpython-39.pyc,, +sympy/physics/units/systems/__pycache__/natural.cpython-39.pyc,, +sympy/physics/units/systems/__pycache__/si.cpython-39.pyc,, +sympy/physics/units/systems/cgs.py,sha256=gXbX8uuZo7lcYIENA-CpAnyS9WVQy-vRisxlQm-198A,3702 +sympy/physics/units/systems/length_weight_time.py,sha256=DXIDSWdhjfxGLA0ldOziWhwQjzTAs7-VQTNCHzDvCgY,7004 +sympy/physics/units/systems/mks.py,sha256=Z3eX9yWK9BdvEosCROK2qRKtKFYOjtQ50Jk6vFT7AQY,1546 +sympy/physics/units/systems/mksa.py,sha256=U8cSI-maIuLJRvpKLBuZA8V19LDRYVc2I40Rao-wvjk,2002 +sympy/physics/units/systems/natural.py,sha256=43Odvmtxdpbz8UcW_xoRE9ArJVVdF7dgdAN2ByDAXx4,909 
+sympy/physics/units/systems/si.py,sha256=YBPUuovW3-JBDZYuStXXRaC8cfzE3En3K5MjNy5pLJk,14478 +sympy/physics/units/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/units/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_dimensions.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_dimensionsystem.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_prefixes.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_quantities.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_unit_system_cgs_gauss.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_unitsystem.cpython-39.pyc,, +sympy/physics/units/tests/__pycache__/test_util.cpython-39.pyc,, +sympy/physics/units/tests/test_dimensions.py,sha256=lzkgGfEXMHxB8Izv7nRTN2uOEPh65LXPYaG8Kr5H05o,6122 +sympy/physics/units/tests/test_dimensionsystem.py,sha256=s2_2RAJwOaPOTvyIiAO9SYap374ytZqWbatWkLCnbSU,2717 +sympy/physics/units/tests/test_prefixes.py,sha256=Y3vlIReWGu8bwwZTrNGZSVoWYjhzgUZC33CDeyIvw48,2238 +sympy/physics/units/tests/test_quantities.py,sha256=XPuM6ul7XrUHvQE7F8rvpoCaxT9N6TM2gX99qUM4gTA,19758 +sympy/physics/units/tests/test_unit_system_cgs_gauss.py,sha256=JepTWt8yGdtv5dQ2AKUKb9fxpuYqLWOp0oOmzov9vfY,3173 +sympy/physics/units/tests/test_unitsystem.py,sha256=1Xh78_8hbv-yP4ICWI_dUrOnk3cimlvP_VhO-EXOa7Q,3254 +sympy/physics/units/tests/test_util.py,sha256=76PaLp_Cd9BAiry6VcWUD4Hrr68D6lTOkScWcLyhmL0,9355 +sympy/physics/units/unitsystem.py,sha256=mLuXftt3003OLk3knUGdYb1XfGbQa6vp-pHcenL11-o,7593 +sympy/physics/units/util.py,sha256=GU43MjXBRl5fm0CrU1fhFsZi9Q7wPG0NSJVJb1Oc_Y8,9885 +sympy/physics/vector/__init__.py,sha256=jZmrNB6ZfY7NOP8nx8GWcfI2Ixb2mv7lXuGHn63kyOw,985 +sympy/physics/vector/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/vector/__pycache__/dyadic.cpython-39.pyc,, +sympy/physics/vector/__pycache__/fieldfunctions.cpython-39.pyc,, +sympy/physics/vector/__pycache__/frame.cpython-39.pyc,, +sympy/physics/vector/__pycache__/functions.cpython-39.pyc,, +sympy/physics/vector/__pycache__/point.cpython-39.pyc,, +sympy/physics/vector/__pycache__/printing.cpython-39.pyc,, +sympy/physics/vector/__pycache__/vector.cpython-39.pyc,, +sympy/physics/vector/dyadic.py,sha256=CuAkSBgnatfnZQ2gBazZQC1DBVXbksiFWWqxu2w3C2U,18042 +sympy/physics/vector/fieldfunctions.py,sha256=H3S7oeLHaInZns9ko04BPOcSSOK2M5hBK71jB813ExY,8610 +sympy/physics/vector/frame.py,sha256=J1sY-qaAamcZFKP26x54Q9TWITXBgjmE3v5PiEaFri4,57260 +sympy/physics/vector/functions.py,sha256=b4dHZOxbvS77SMcSlCV2GTaOzWj0TQiUtDaO_dhj3Uw,24810 +sympy/physics/vector/point.py,sha256=W1H6Jd-ZAIHY8yKO3YWj-vdEp7qI2yoglzk5HuT-y2E,20569 +sympy/physics/vector/printing.py,sha256=a1N-wziCnt4gHtY9luqe-CDW9aAtpZ0FcvWwQ0hMEEo,11790 +sympy/physics/vector/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/physics/vector/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_dyadic.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_fieldfunctions.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_frame.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_functions.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_output.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_point.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_printing.cpython-39.pyc,, +sympy/physics/vector/tests/__pycache__/test_vector.cpython-39.pyc,, 
+sympy/physics/vector/tests/test_dyadic.py,sha256=vCSLxo0pynrHrCCo7S4925lgRhoSm_Mr5YhIh_1bYcw,4265 +sympy/physics/vector/tests/test_fieldfunctions.py,sha256=FUjh18QzB6dXSau9iHutb36o28faSa7T9sB0icpja-M,5825 +sympy/physics/vector/tests/test_frame.py,sha256=S5xn5Q9dMj_ACIS9ff7ACTBS7sFFovi9DIov-crUsko,29683 +sympy/physics/vector/tests/test_functions.py,sha256=qHGT0RR-vkAp1oF30TFNgaeuOGvXdZnK2aKZuRgrZHg,21127 +sympy/physics/vector/tests/test_output.py,sha256=hgqlE-_zN_EPE_gIZ_v2uXB1y2auo39hReWpvFUm2QQ,2612 +sympy/physics/vector/tests/test_point.py,sha256=5CTzT3a-igd33auAra4uusm0PYUc_whPtV7KAoZ4g5w,12373 +sympy/physics/vector/tests/test_printing.py,sha256=qVBjz4f3TtDrduUYLNDrvlrzBVMBDqLo27JWsFHdX18,10967 +sympy/physics/vector/tests/test_vector.py,sha256=bqU1ltS6UGbSo74KXMtvP1mOpqKQ6XSV19Wjw2QoNFc,10259 +sympy/physics/vector/vector.py,sha256=R7P63QPZee0lfHHVDU7Uq91ZJs5ewjVK5HVj9HHNeQo,24874 +sympy/physics/wigner.py,sha256=L5QyzD4yH4SyTk7-WCCd07XK0_Pd9nz7sHwo2zOWV7I,39588 +sympy/plotting/__init__.py,sha256=hAdOjai8-laj79rLJ2HZbiW1okXlz0p1ck-CoeNU6m8,526 +sympy/plotting/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/__pycache__/experimental_lambdify.cpython-39.pyc,, +sympy/plotting/__pycache__/plot.cpython-39.pyc,, +sympy/plotting/__pycache__/plot_implicit.cpython-39.pyc,, +sympy/plotting/__pycache__/plotgrid.cpython-39.pyc,, +sympy/plotting/__pycache__/series.cpython-39.pyc,, +sympy/plotting/__pycache__/textplot.cpython-39.pyc,, +sympy/plotting/__pycache__/utils.cpython-39.pyc,, +sympy/plotting/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/plotting/backends/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/backends/__pycache__/base_backend.cpython-39.pyc,, +sympy/plotting/backends/base_backend.py,sha256=aof5cs8mWG6fl87DadR5sZBjY8ssMelOwKFWi27-FO0,14911 +sympy/plotting/backends/matplotlibbackend/__init__.py,sha256=Mzsz43gkanid12G8qnOcwJNbbvqf6eGppn-IZyJCMto,162 +sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-39.pyc,, +sympy/plotting/backends/matplotlibbackend/matplotlib.py,sha256=9efu7ST_D3M-_fepZCPTcy8TdYEIX0PCkyocDGVfbPE,12548 +sympy/plotting/backends/textbackend/__init__.py,sha256=nnV_C7JJ_kwGcQe2C-tpnK635W8vTuIf5Grvvmur0rQ,92 +sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/backends/textbackend/__pycache__/text.cpython-39.pyc,, +sympy/plotting/backends/textbackend/text.py,sha256=1ukArjwwQWUED9SB-1dmkB6YL5EcJ2rUosUf_NcBpXs,803 +sympy/plotting/experimental_lambdify.py,sha256=xcBhlvZ2h20aI1MpUN6qAEpO075Dv132AWbQJ7l3Wzg,22828 +sympy/plotting/intervalmath/__init__.py,sha256=fQV7sLZ9NHpZO5XGl2ZfqX56x-mdq-sYhtWEKLngHlU,479 +sympy/plotting/intervalmath/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-39.pyc,, +sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-39.pyc,, +sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-39.pyc,, +sympy/plotting/intervalmath/interval_arithmetic.py,sha256=jv5YolNs6pOawIhuSsTBVwgkgmdOFwPrGN_1KtjfcIs,15570 +sympy/plotting/intervalmath/interval_membership.py,sha256=1VpO1T7UjvPxcMySC5GhZl8-VM_DxIirSWC3ZGmxIAY,2385 +sympy/plotting/intervalmath/lib_interval.py,sha256=WY1qRtyub4MDJaZizw6cXQI5NMEIXBO9UEWPEI80aW8,14809 +sympy/plotting/intervalmath/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-39.pyc,, 
+sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-39.pyc,, +sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-39.pyc,, +sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-39.pyc,, +sympy/plotting/intervalmath/tests/test_interval_functions.py,sha256=gdIo5z54tIbG8hDaGd3I8rBDP67oetMZWWdM-uvt1ec,9862 +sympy/plotting/intervalmath/tests/test_interval_membership.py,sha256=D1KjcrLxAwOmDEUqA-8TCqkFWGtmeerR9KwmzS7tyjk,4216 +sympy/plotting/intervalmath/tests/test_intervalmath.py,sha256=ndBMczrs6xYMN5RGnyCL9yq7pNUxrXHTSU1mdUsp5tU,9034 +sympy/plotting/plot.py,sha256=umvG47eIqOdEUo5qki3jDkS6LYKnTWFTZ5qzxZZBgAw,40882 +sympy/plotting/plot_implicit.py,sha256=1xIIr4eV1gIU4SkQ2k54PBZGUgYMCgsPttY-9uZ2eZM,7330 +sympy/plotting/plotgrid.py,sha256=QZmbxY-fcgPuseYLnKVDFsoH6m57Cosvy4W7l4jIqWw,6115 +sympy/plotting/pygletplot/__init__.py,sha256=DM7GURQbdSfcddHz23MxOShatBFc26tP_sd3G8pGCQE,3732 +sympy/plotting/pygletplot/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/managed_window.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_object.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/plot_window.cpython-39.pyc,, +sympy/plotting/pygletplot/__pycache__/util.cpython-39.pyc,, +sympy/plotting/pygletplot/color_scheme.py,sha256=NgPUamkldygfrIPj0LvC_1AzhscVtg18FSudElvFYB8,12522 +sympy/plotting/pygletplot/managed_window.py,sha256=N7AKtM7ELfIJLie6zvI-J6-OQRBnMZu6AL1USz7hFEk,3072 +sympy/plotting/pygletplot/plot.py,sha256=s-5AJB0KelHs9WGoFIVIdYrOoMXfdpnM5-G2cF8xzDQ,13352 +sympy/plotting/pygletplot/plot_axes.py,sha256=Q9YN8W0Hd1PeflHLvOvSZ-hxeLU4Kq3nUFLYDC0x0E8,8655 +sympy/plotting/pygletplot/plot_camera.py,sha256=myYtKbv1ov_ltgR34hf8BR76t3AwTSu4QFUY5YY-e1E,3928 +sympy/plotting/pygletplot/plot_controller.py,sha256=MroJJSPCbBDT8gGs_GdqpV_KHsllMNJpxx0MU3vKJV8,6941 +sympy/plotting/pygletplot/plot_curve.py,sha256=YwKA2lYC7IwCOQJaOVnww8AAG4P36cArgbC1iLV9OFI,2838 +sympy/plotting/pygletplot/plot_interval.py,sha256=doqr2wxnrED4MJDlkxQ07GFvaagX36HUb77ly_vIuKQ,5431 +sympy/plotting/pygletplot/plot_mode.py,sha256=Djq-ewVms_JoSriDpolDhhtttBJQdJO8BD4E0nyOWcQ,14156 +sympy/plotting/pygletplot/plot_mode_base.py,sha256=3z3WjeN7TTslHJevhr3X_7HRHPgUleYSngu6285lR6k,11502 +sympy/plotting/pygletplot/plot_modes.py,sha256=gKzJShz6OXa6EHKar8SuHWrELVznxg_s2d5IBQkkeYE,5352 +sympy/plotting/pygletplot/plot_object.py,sha256=qGtzcKup4It1CqZ2jxA7FnorCua4S9I-B_7I3SHBjcQ,330 +sympy/plotting/pygletplot/plot_rotation.py,sha256=KGMCTyaSh5_UQD9FKGCIpoI3-CbbOhXxTcjAc1o5-PU,1445 +sympy/plotting/pygletplot/plot_surface.py,sha256=C0q9tzDmxzC1IpWiNKY4llzcopx6dhotGOLpK1N9m3s,3803 
+sympy/plotting/pygletplot/plot_window.py,sha256=5boC2Fkmk46-gWGqWzdTkPmTMNHHOpA0CnB9q946Hwc,4643 +sympy/plotting/pygletplot/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-39.pyc,, +sympy/plotting/pygletplot/tests/test_plotting.py,sha256=NisjR-yuBRJfQvjcb20skTR3yid2U3MhKHW6sy8RE10,2720 +sympy/plotting/pygletplot/util.py,sha256=mzQQgDDbp04B03KyJrossLp8Yq72RJzjp-3ArfjbMH8,4621 +sympy/plotting/series.py,sha256=hPkMicAi29RjoJkFpMCt1CRn0xdBVYU1wLniigHgCaM,96568 +sympy/plotting/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/plotting/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-39.pyc,, +sympy/plotting/tests/__pycache__/test_plot.cpython-39.pyc,, +sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-39.pyc,, +sympy/plotting/tests/__pycache__/test_series.cpython-39.pyc,, +sympy/plotting/tests/__pycache__/test_textplot.cpython-39.pyc,, +sympy/plotting/tests/__pycache__/test_utils.cpython-39.pyc,, +sympy/plotting/tests/test_experimental_lambdify.py,sha256=EYshdXA5tAGWolaDX-nHAolp7xIJN4Oqb1Uc1C1IhJI,3127 +sympy/plotting/tests/test_plot.py,sha256=mXiNdG-dHkK1iJ_moNlehZQU3i8mCtUKhVJdTxSxjdA,48217 +sympy/plotting/tests/test_plot_implicit.py,sha256=wurrO7ntsHb4tWYPVs5VogRrtcylevu0EceCSEwiWQg,5799 +sympy/plotting/tests/test_region_and.png,sha256=EV0Lm4HtQPk_6eIWtPY4TPcQk-O7tkpdZIuLmFjGRaA,6864 +sympy/plotting/tests/test_region_not.png,sha256=3O_9_nPW149FMULEcT5RqI2-k2H3nHELbfJADt2cO8k,7939 +sympy/plotting/tests/test_region_or.png,sha256=5Bug09vyog-Cu3mky7pbtFjew5bMvbpe0ZXWsgDKfy4,8809 +sympy/plotting/tests/test_region_xor.png,sha256=kucVWBA9A98OpcR4did5aLXUyoq4z0O4C3PM6dliBSw,10002 +sympy/plotting/tests/test_series.py,sha256=Nwa6__YMoWCiY0mYSlvJPPnlvtfOYLACHR1RbkopMc0,65638 +sympy/plotting/tests/test_textplot.py,sha256=VurTGeMjUfBLpLdoMqzJK9gbcShNb7f1OrAcRNyrtag,12761 +sympy/plotting/tests/test_utils.py,sha256=FkcYZAFT8TnHRIbkknx9NvA3-LgR0ua7WFyzQEPsVIE,3602 +sympy/plotting/textplot.py,sha256=Xc-bWvzVOq8QUn_BxDlJv9bryCpgzsgV73XoQQwVj0Q,5075 +sympy/plotting/utils.py,sha256=Cimno9MQGjcbe4rQ6TN86r4rM2t62CgARqPtgF0mhR8,12259 +sympy/polys/__init__.py,sha256=qXmDNr7noCUAtIs0rJJRHHFbHjy3uHIvXZDajcSePc4,5577 +sympy/polys/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/__pycache__/appellseqs.cpython-39.pyc,, +sympy/polys/__pycache__/compatibility.cpython-39.pyc,, +sympy/polys/__pycache__/constructor.cpython-39.pyc,, +sympy/polys/__pycache__/densearith.cpython-39.pyc,, +sympy/polys/__pycache__/densebasic.cpython-39.pyc,, +sympy/polys/__pycache__/densetools.cpython-39.pyc,, +sympy/polys/__pycache__/dispersion.cpython-39.pyc,, +sympy/polys/__pycache__/distributedmodules.cpython-39.pyc,, +sympy/polys/__pycache__/domainmatrix.cpython-39.pyc,, +sympy/polys/__pycache__/euclidtools.cpython-39.pyc,, +sympy/polys/__pycache__/factortools.cpython-39.pyc,, +sympy/polys/__pycache__/fglmtools.cpython-39.pyc,, +sympy/polys/__pycache__/fields.cpython-39.pyc,, +sympy/polys/__pycache__/galoistools.cpython-39.pyc,, +sympy/polys/__pycache__/groebnertools.cpython-39.pyc,, +sympy/polys/__pycache__/heuristicgcd.cpython-39.pyc,, +sympy/polys/__pycache__/modulargcd.cpython-39.pyc,, +sympy/polys/__pycache__/monomials.cpython-39.pyc,, +sympy/polys/__pycache__/multivariate_resultants.cpython-39.pyc,, +sympy/polys/__pycache__/orderings.cpython-39.pyc,, 
+sympy/polys/__pycache__/orthopolys.cpython-39.pyc,, +sympy/polys/__pycache__/partfrac.cpython-39.pyc,, +sympy/polys/__pycache__/polyclasses.cpython-39.pyc,, +sympy/polys/__pycache__/polyconfig.cpython-39.pyc,, +sympy/polys/__pycache__/polyerrors.cpython-39.pyc,, +sympy/polys/__pycache__/polyfuncs.cpython-39.pyc,, +sympy/polys/__pycache__/polymatrix.cpython-39.pyc,, +sympy/polys/__pycache__/polyoptions.cpython-39.pyc,, +sympy/polys/__pycache__/polyquinticconst.cpython-39.pyc,, +sympy/polys/__pycache__/polyroots.cpython-39.pyc,, +sympy/polys/__pycache__/polytools.cpython-39.pyc,, +sympy/polys/__pycache__/polyutils.cpython-39.pyc,, +sympy/polys/__pycache__/puiseux.cpython-39.pyc,, +sympy/polys/__pycache__/rationaltools.cpython-39.pyc,, +sympy/polys/__pycache__/ring_series.cpython-39.pyc,, +sympy/polys/__pycache__/rings.cpython-39.pyc,, +sympy/polys/__pycache__/rootisolation.cpython-39.pyc,, +sympy/polys/__pycache__/rootoftools.cpython-39.pyc,, +sympy/polys/__pycache__/solvers.cpython-39.pyc,, +sympy/polys/__pycache__/specialpolys.cpython-39.pyc,, +sympy/polys/__pycache__/sqfreetools.cpython-39.pyc,, +sympy/polys/__pycache__/subresultants_qq_zz.cpython-39.pyc,, +sympy/polys/agca/__init__.py,sha256=fahpWoG_0LgoqOXBnDBJS16Jj1fE1_VKG7edM3qZ2HE,130 +sympy/polys/agca/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/agca/__pycache__/extensions.cpython-39.pyc,, +sympy/polys/agca/__pycache__/homomorphisms.cpython-39.pyc,, +sympy/polys/agca/__pycache__/ideals.cpython-39.pyc,, +sympy/polys/agca/__pycache__/modules.cpython-39.pyc,, +sympy/polys/agca/extensions.py,sha256=YmtFs9C0s-4DNMXFtdX1hYVNlby18mAJzhJ5Aqickrw,9388 +sympy/polys/agca/homomorphisms.py,sha256=gaMNV96pKUuYHZ8Bd7QOs27J1IbbJgkEjyWcTLe8GFI,21937 +sympy/polys/agca/ideals.py,sha256=S6rBl3H-hdeI44ZbELwjjt1rKlrhK11AKb8Aas-OtCE,11073 +sympy/polys/agca/modules.py,sha256=PeA138FHsCfUn8vkdrjPa92xmzq7Qqk0PrdYwVoNqf0,47240 +sympy/polys/agca/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/polys/agca/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/agca/tests/__pycache__/test_extensions.cpython-39.pyc,, +sympy/polys/agca/tests/__pycache__/test_homomorphisms.cpython-39.pyc,, +sympy/polys/agca/tests/__pycache__/test_ideals.cpython-39.pyc,, +sympy/polys/agca/tests/__pycache__/test_modules.cpython-39.pyc,, +sympy/polys/agca/tests/test_extensions.py,sha256=i3IHQNXQByFMCvjjyd_hwwJSCiUj0z1rRwS9WFK2AFc,6455 +sympy/polys/agca/tests/test_homomorphisms.py,sha256=m0hFmcTzvZ8sZbbnWeENwzKyufpE9zWwZR-WCI4kdpU,4224 +sympy/polys/agca/tests/test_ideals.py,sha256=w76qXO-_HN6LQbV7l3h7gJZsM-DZ2io2X-kPWiHYRNw,3788 +sympy/polys/agca/tests/test_modules.py,sha256=HdfmcxdEVucEbtfmzVq8i_1wGojT5b5DE5VIfbTMx3k,13552 +sympy/polys/appellseqs.py,sha256=hWeDKsKnJuAuPN_5IU6m1okurAq9xMt3LQgMehcvBKQ,8305 +sympy/polys/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/polys/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/benchmarks/__pycache__/bench_galoispolys.cpython-39.pyc,, +sympy/polys/benchmarks/__pycache__/bench_groebnertools.cpython-39.pyc,, +sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-39.pyc,, +sympy/polys/benchmarks/bench_galoispolys.py,sha256=8RtN9ZQga2oxscVPPkMGB29Dz8UbskMS2szYtqZ69u0,1502 +sympy/polys/benchmarks/bench_groebnertools.py,sha256=YqGDCzewRszCye_GnneDXMRNB38ORSpVu_Jn0ELIySo,803 +sympy/polys/benchmarks/bench_solvers.py,sha256=gLrZguh6pE0E4_vM2GeOS5bHnrcSUQXqD0Qz9tItfmo,446778 +sympy/polys/compatibility.py,sha256=ti_gJSz8zB1sf-oIyfXAw4xkDGkQaZs_34RSC-6HldY,58373 
+sympy/polys/constructor.py,sha256=rv8hQgE8-P9QXUtOQmlEiyA4XD-fb1Qzd8bvH3UHQls,11373 +sympy/polys/densearith.py,sha256=6wB6DJWKNB0RA7Tdg7jMgRIERjZdtAp32OQhIQ6wmV4,34114 +sympy/polys/densebasic.py,sha256=Kag6cqlxHhHUMYUXpr9GHLWh4E1Mnw2Esoh6zx9r14E,35954 +sympy/polys/densetools.py,sha256=A0Z1E16eVx577zw84m8KZgT5o6dbDXXonW8TqfISJZY,29746 +sympy/polys/dispersion.py,sha256=s6GIYnGA6U9jhGP7YXQQS8G3byG4-kPbr55BR6p-iz4,5740 +sympy/polys/distributedmodules.py,sha256=t8pLIgDQs_dMecGXwybVYoLavofEy2DXhFS8N5gj5SU,21827 +sympy/polys/domainmatrix.py,sha256=FmNqklNFQR1WrQYtP2r7jypw2IQadNKGP14EaUaxUqI,310 +sympy/polys/domains/__init__.py,sha256=T6qPNkU1EJ6D5BnvyJSXJv4zeJ5MUT5RLsovMkkXS9E,1872 +sympy/polys/domains/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/domains/__pycache__/algebraicfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/characteristiczero.cpython-39.pyc,, +sympy/polys/domains/__pycache__/complexfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/compositedomain.cpython-39.pyc,, +sympy/polys/domains/__pycache__/domain.cpython-39.pyc,, +sympy/polys/domains/__pycache__/domainelement.cpython-39.pyc,, +sympy/polys/domains/__pycache__/expressiondomain.cpython-39.pyc,, +sympy/polys/domains/__pycache__/expressionrawdomain.cpython-39.pyc,, +sympy/polys/domains/__pycache__/field.cpython-39.pyc,, +sympy/polys/domains/__pycache__/finitefield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/fractionfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/gaussiandomains.cpython-39.pyc,, +sympy/polys/domains/__pycache__/gmpyfinitefield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/gmpyintegerring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/groundtypes.cpython-39.pyc,, +sympy/polys/domains/__pycache__/integerring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/modularinteger.cpython-39.pyc,, +sympy/polys/domains/__pycache__/mpelements.cpython-39.pyc,, +sympy/polys/domains/__pycache__/old_fractionfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/old_polynomialring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/polynomialring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/pythonfinitefield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/pythonintegerring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/pythonrational.cpython-39.pyc,, +sympy/polys/domains/__pycache__/pythonrationalfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/quotientring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/rationalfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/realfield.cpython-39.pyc,, +sympy/polys/domains/__pycache__/ring.cpython-39.pyc,, +sympy/polys/domains/__pycache__/simpledomain.cpython-39.pyc,, +sympy/polys/domains/algebraicfield.py,sha256=JXK_o6BdQbYT_0gOLALNOLsmqgfV46DRFJb7cANJgAM,23051 +sympy/polys/domains/characteristiczero.py,sha256=vHYRUXPrfJzDF8wrd1KSFqG8WzwfITP_eweA-SHPVYA,382 +sympy/polys/domains/complexfield.py,sha256=VbhX6yigmU2OLhof8qEpPp1XZhzzkfXfcuZevx1LZsk,6159 +sympy/polys/domains/compositedomain.py,sha256=DAo2ISA9XdOnYzFu8azuPIQAT9fyVwSM8Pe425vmvww,1642 +sympy/polys/domains/domain.py,sha256=aj9HOp98G2m25wRzSDj4lWzoRLGyXDsuevzScdHjV6o,40850 +sympy/polys/domains/domainelement.py,sha256=IrG-Mzv_VlCAmE-hmJVH_d77TrsfyaGGfJVmU8FFvlY,860 +sympy/polys/domains/expressiondomain.py,sha256=AB7Mnd6kOLaS_yG4lMXZTpVPUuAHGcrd6RAjjFSVxNc,7593 +sympy/polys/domains/expressionrawdomain.py,sha256=cXarD2jXi97FGNiqNiDqQlX0g764EW2M1PEbrveImnY,1448 
+sympy/polys/domains/field.py,sha256=fR0l6wZIv02UMJ5rioEP910jGUViVx38EdwtJuJX5vg,2959 +sympy/polys/domains/finitefield.py,sha256=1DHkfrCz08KXLEz7in4IK71B8udHPZZAWv50-jV5_vQ,10654 +sympy/polys/domains/fractionfield.py,sha256=BoqUCbVggwaoqGCVy_r42MIVHuAV-P3gONFVZIlWojU,5848 +sympy/polys/domains/gaussiandomains.py,sha256=YrcHg4QSPgIRW-Ft8J84q1zwmU-DVJuAscWXnVCQ9CI,19598 +sympy/polys/domains/gmpyfinitefield.py,sha256=WgSLnALNOVVKH931WpVT28ZWDilsrTDG8DyMee2xR94,437 +sympy/polys/domains/gmpyintegerring.py,sha256=qJ7w8K40cfzhztZtOuWlIL2DXa642xJcKIIxoAOlcSs,3037 +sympy/polys/domains/gmpyrationalfield.py,sha256=dZjrfcWaUA-BHUtutzLOWPlOSNLYzBqSFeukER6L_bA,3178 +sympy/polys/domains/groundtypes.py,sha256=hAla27w5ekoJR_8c-1Yo8vrEgIIggPc615tfe1udc9A,2102 +sympy/polys/domains/integerring.py,sha256=4oy49xTi8hV6qh8CWUAoBcgn2aJqgqwyh7bZBsjGfwI,7442 +sympy/polys/domains/modularinteger.py,sha256=k6gskb0eypXdrKJRxR3l_75mVjmICGnaOy7FdRMwG8E,6042 +sympy/polys/domains/mpelements.py,sha256=PREDWq_WA5er2zzftt5Lg0opnWzyFBsIkLKcyGXUVz0,5042 +sympy/polys/domains/old_fractionfield.py,sha256=TUVxyL2fS4QF3kgyW5EGfkl91ir3S1klu08UfZr3GuI,6226 +sympy/polys/domains/old_polynomialring.py,sha256=KQcH58oHnHzOpDdWojZiLlHDqrAiUd4OAaBIZigqpyc,15982 +sympy/polys/domains/polynomialring.py,sha256=OLPcwkyrbkt5pNZ342o3Ziv9TCrlpU3h18s9R-cxqXI,6223 +sympy/polys/domains/pythonfinitefield.py,sha256=lWp266ErnuUPuo7k8ju3z2S5IresFInpJAl4Ihsq0pI,453 +sympy/polys/domains/pythonintegerring.py,sha256=EH5s6nwaxmeaScLER_FfqPdhyuJnbjtBslHmgyOyEPs,2962 +sympy/polys/domains/pythonrational.py,sha256=M3VUGODh3MLElePjYtjt9b02ReMThw-XXpuQTkohgNs,548 +sympy/polys/domains/pythonrationalfield.py,sha256=x8BPkGKj0WPuwJzN2py5l9aAjHaY4djv65c4tzUTr3Y,2295 +sympy/polys/domains/quotientring.py,sha256=2NOUkkbFj4514qkDUszl6hl1EQuPikn3QEoQ1sDobGI,5911 +sympy/polys/domains/rationalfield.py,sha256=D-pA-iHDHbOi6fivqSrJnfmH2JTrheKJQ9ZFXXN5ftM,5982 +sympy/polys/domains/realfield.py,sha256=OdLfE_9OVxvQlnQiT4faJxuB6dZwnXdIt-o-5wGYroA,6456 +sympy/polys/domains/ring.py,sha256=p66U2X58acSHLHxOTU6aJZ0Umdcu1qiGIUDtV8iJCD0,3236 +sympy/polys/domains/simpledomain.py,sha256=_K-Zz8Opf505r3eHSrbPAlnGiGSjY_O4Cwa4OTeOSoY,369 +sympy/polys/domains/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/polys/domains/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/domains/tests/__pycache__/test_domains.cpython-39.pyc,, +sympy/polys/domains/tests/__pycache__/test_polynomialring.cpython-39.pyc,, +sympy/polys/domains/tests/__pycache__/test_quotientring.cpython-39.pyc,, +sympy/polys/domains/tests/test_domains.py,sha256=XJwTGPM-8W4Rn76-cec99rpUIDDjP-7HTFs7AKGNrEw,51028 +sympy/polys/domains/tests/test_polynomialring.py,sha256=8nNFKuQeicwiBk6pnqPLNk0cnI61iK8Dpl-TZifvORc,2895 +sympy/polys/domains/tests/test_quotientring.py,sha256=BYoq1CqI76RDSm0xQdp1v7Dv1n5sdcmes-b_y_AfW-0,1459 +sympy/polys/euclidtools.py,sha256=hAsNupzfR-07fTwFvTF2yDabyKMm_PLgFC8KX9c2CAo,41808 +sympy/polys/factortools.py,sha256=CupycipegP1IGCYPY9eY2T8f_6h4gmr6qId1A4UDJvk,43034 +sympy/polys/fglmtools.py,sha256=Z3VS_IKsZu1o3g22KPTIH1feNL8cTpieaNWvLuILmK4,4298 +sympy/polys/fields.py,sha256=RdSDkOls6n87yvHG1KXFnGng0JS3SAcUAlz4tTb1RZE,21250 +sympy/polys/galoistools.py,sha256=Z3ed_A0Nohj9RzuM3kom7-ay4MnzIHppwgV2QpONuTo,57238 +sympy/polys/groebnertools.py,sha256=NhK-XcFR9e4chDDJJ-diXb7XYuw9zcixFA_riomThPM,23342 +sympy/polys/heuristicgcd.py,sha256=rD3intgKCtAAMH3sqlgqbJL1XSq9QjfeG_MYzwCOek0,3732 +sympy/polys/matrices/__init__.py,sha256=ZaPJMi8l22d3F3rudS4NqzSt0xwxbs3uwnQwlhhR91o,397 
+sympy/polys/matrices/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/_dfm.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/_typing.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/ddm.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/dense.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/dfm.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/domainmatrix.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/domainscalar.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/eigen.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/exceptions.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/linsolve.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/lll.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/normalforms.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/rref.cpython-39.pyc,, +sympy/polys/matrices/__pycache__/sdm.cpython-39.pyc,, +sympy/polys/matrices/_dfm.py,sha256=_3p5Hb9e86uBG_r__Ee6BAZxj4KiRhcf2uN3XnjHxwQ,32971 +sympy/polys/matrices/_typing.py,sha256=Egp2nMOaq-oJCW0bgNDxy1Bx6evTl8eMTpf4mzIw1s4,407 +sympy/polys/matrices/ddm.py,sha256=eMKUy8X1kQTEG1l7Fx_L4LG9Jqs9l0KEwlYQwlcl6Do,32362 +sympy/polys/matrices/dense.py,sha256=RYdogVHeWZoeO94Oal1mf1sJvj-JmyS_baRi8-d9Rms,23217 +sympy/polys/matrices/dfm.py,sha256=Fj_uF4FskrwcBDNuRSV--AozCP2cYkUz-BMitcMlkRE,1241 +sympy/polys/matrices/domainmatrix.py,sha256=MvjWRzQ4lKzX2rCl8mRtpiIXSe641_x6OMrTr8iscPA,115928 +sympy/polys/matrices/domainscalar.py,sha256=sPXC-it46yun2r_tJkun4MIuVQQyQMQTVsTedId8Rj4,3778 +sympy/polys/matrices/eigen.py,sha256=glArxs9_rTrE_ssz2fd2gKCTsguSyEHwoeQP82tmIcM,3015 +sympy/polys/matrices/exceptions.py,sha256=ay3Lv21X3QqszysBN71xdr9KGQuC5kDBl90a2Sjx6pM,1351 +sympy/polys/matrices/linsolve.py,sha256=p86jlJP9h3CxZX92xk2sh1WAvKBkB5_YD55k1oYpmGI,7534 +sympy/polys/matrices/lll.py,sha256=p5q_rb5QvAuHsGRZAzP2Jf_h2vdC_Be4Q62lkaYxW-c,3550 +sympy/polys/matrices/normalforms.py,sha256=6MJfVbiwlZuCk-0fYiyOMuqkPdVpyzTNdDP2s6jkRw0,17122 +sympy/polys/matrices/rref.py,sha256=5YK_phB382K9JFVIydlH2IYpziirEDUnKzyl5BfS2_o,15532 +sympy/polys/matrices/sdm.py,sha256=bb2fIO42KjyJuEuMBkiDFn_qLJVjD4gsRgWZ6GOmTHw,64293 +sympy/polys/matrices/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/polys/matrices/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_dense.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_fflu.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_lll.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_rref.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-39.pyc,, +sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-39.pyc,, +sympy/polys/matrices/tests/test_ddm.py,sha256=BcnUntz3XR2Xyff02l4Mk2EsxwYlIBzM4lFSRo9oZT8,16699 +sympy/polys/matrices/tests/test_dense.py,sha256=ETC2h5yLKsPJ5GAgx_RUcWyR63iZx81pbtOZiWty1r0,9516 +sympy/polys/matrices/tests/test_domainmatrix.py,sha256=4QttFR5PYJg6Xn37VXdxmv1w9Xkm82fF-M-KRQIe0RU,50006 
+sympy/polys/matrices/tests/test_domainscalar.py,sha256=HEFEKX5tP6SZ83_91nvLmFqgHxbVdpqOP4ZwzfFbHnc,3740 +sympy/polys/matrices/tests/test_eigen.py,sha256=T1lYZeW-0NwDxDOG6ZJLr-OICfxY2wa0fVHV2V6EXSk,3200 +sympy/polys/matrices/tests/test_fflu.py,sha256=7En_75QemdtHvR7lUyZrcHiV9Dv2rswBuuKVOPWi1pk,8542 +sympy/polys/matrices/tests/test_inverse.py,sha256=5DT5qvfS3MlrgNgeHnG8GLjgLnmhsxj94yR3cbtmEO8,5247 +sympy/polys/matrices/tests/test_linsolve.py,sha256=ur1BFzlIQyBrO_-aL71lXsLC67nfec9G5o-mzW_TFY4,3373 +sympy/polys/matrices/tests/test_lll.py,sha256=RGYTDGbLfFvAMTESv-S1kqSWzwtOIjmguqXO3yGCjH4,6624 +sympy/polys/matrices/tests/test_normalforms.py,sha256=x1-vkjXq3LGblUSJFz7z73IVmaxqZbVzWoGj9XS-yXw,5106 +sympy/polys/matrices/tests/test_nullspace.py,sha256=eAEDPlnVkKfFM9jn1gztOUQTos1Sm9qwH10C-t9uLUE,5418 +sympy/polys/matrices/tests/test_rref.py,sha256=mWTIfKAUP3vkGKhffCrPHuC_0DdF-iH41cchlSN8Pqc,25982 +sympy/polys/matrices/tests/test_sdm.py,sha256=fSE3bQlDU-atzTFTgp4AgAS3QL-7rvb_61stj3QGBnU,13415 +sympy/polys/matrices/tests/test_xxm.py,sha256=X9sP_VfPi4bUR7V1ZLq6ZHXPOBUA7tE5-gQULang_Ag,29778 +sympy/polys/modulargcd.py,sha256=gOwojISiV8IncaEHZJZMHOC1odN7xRKN5L1Ls4szROI,58729 +sympy/polys/monomials.py,sha256=6cDCTVNNPDeee5bt0RQnGuYnsO_fefioHA5mEDKHSuA,18458 +sympy/polys/multivariate_resultants.py,sha256=6bwdF-lcUqtKoVDrnOKhm_PyPfo8w0Yyy_GtESatQFw,15248 +sympy/polys/numberfields/__init__.py,sha256=ZfhC9MyfGfGUz_DT_rXasB-M_P2zUiZXOJUNh_Gtm8c,538 +sympy/polys/numberfields/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/basis.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/exceptions.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/galois_resolvents.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/galoisgroups.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/minpoly.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/modules.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/primes.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/resolvent_lookup.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/subfield.cpython-39.pyc,, +sympy/polys/numberfields/__pycache__/utilities.cpython-39.pyc,, +sympy/polys/numberfields/basis.py,sha256=IPA6cSwz-53ClQwo-wkmRzfx9pRX4iBhiggdLMVSgJ0,8261 +sympy/polys/numberfields/exceptions.py,sha256=kHb-aB4eBzG3SsIpYtL5wLExDSb8lIOpNq0tk3guFIA,1594 +sympy/polys/numberfields/galois_resolvents.py,sha256=OL3u-G6sCwvZuBuuYQO0QpL-wWxtjxFaBjVQhtiQ_Z8,25006 +sympy/polys/numberfields/galoisgroups.py,sha256=c4s6z_mEUkoWKezFrfjKiU2tZnEnxR4m1LYaagicVno,20671 +sympy/polys/numberfields/minpoly.py,sha256=0F7QVSOk8WqViD50jKtT-94C6kQyuXPBCyNZPTf0ScA,27791 +sympy/polys/numberfields/modules.py,sha256=4MJykT6gtT_LC033LWsHT_CM7nEydLISAdQ0yA_bhkQ,69243 +sympy/polys/numberfields/primes.py,sha256=UXOkuMdnamVvHPQZxwilCbhdwNNNS7zQ4bwSFc5xGgk,23984 +sympy/polys/numberfields/resolvent_lookup.py,sha256=qfLNKOz_WjtXwpVlfzy8EkD4gw12epx9npE9HsjyIdg,40411 +sympy/polys/numberfields/subfield.py,sha256=Huh9iq2CDQN8p0fs2bSe-yEoB_V1OVr8rH58rJT4crk,16668 +sympy/polys/numberfields/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/polys/numberfields/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_basis.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_galoisgroups.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_minpoly.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_modules.cpython-39.pyc,, 
+sympy/polys/numberfields/tests/__pycache__/test_numbers.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_primes.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_subfield.cpython-39.pyc,, +sympy/polys/numberfields/tests/__pycache__/test_utilities.cpython-39.pyc,, +sympy/polys/numberfields/tests/test_basis.py,sha256=96BJ7e4oPDKXyvlRrUkiQxmHyjRGpOkAC7R3ln-jgNE,4580 +sympy/polys/numberfields/tests/test_galoisgroups.py,sha256=3LFuMbV92VBFlqqEqjh37oQvmG8cgZ0pFxDCXUoYRL4,5036 +sympy/polys/numberfields/tests/test_minpoly.py,sha256=HjlmPWX90Cb-XK5G1H0-uirOiHdgTu7R8qpcG5ZI5iw,23115 +sympy/polys/numberfields/tests/test_modules.py,sha256=GU4166j_hMlB22uWxxIjV_ON8RsyvpaN7Ly3eK8_m8Y,22926 +sympy/polys/numberfields/tests/test_numbers.py,sha256=M0vZIBnjPBHV4vFUnPBILaqiR_cgSuU50kFB-v7l1gA,5988 +sympy/polys/numberfields/tests/test_primes.py,sha256=JhcAkaQMgjkOSziQ2jZApJ8b8oviil5cUy0hfFqNmZg,9779 +sympy/polys/numberfields/tests/test_subfield.py,sha256=EpBexToc17YwLB4-n6nE_HaleU8AcFGi5OpDReW7q0k,12776 +sympy/polys/numberfields/tests/test_utilities.py,sha256=rQGEJWctcfzjUtMwRuyCHerSqEsoP5C3z1bnddJA17o,3661 +sympy/polys/numberfields/utilities.py,sha256=IHKmfafE9tMGammfdtkyR4_IHb0mON-OBwitE0Ss1R0,13087 +sympy/polys/orderings.py,sha256=IFieyj4LkFa7NDiGTZD3VwUY7mSN3GEjThKk0z5WJ1s,8500 +sympy/polys/orthopolys.py,sha256=PIJg_Kml8MxmJs5QPCjGmRLjjlE_4QKhsWhU7-an7x0,10181 +sympy/polys/partfrac.py,sha256=ptYpXsp8dYyyeiPnLU-rdQ-uv3-GAvDnN79E81C68y8,14695 +sympy/polys/polyclasses.py,sha256=Zc9sjiwVxsM6wM8h9KCFP2Pjlfl4XA3MjC6NlguQzqM,96529 +sympy/polys/polyconfig.py,sha256=mgfFpp9SU159tA_PM2o04WZyzMoWfOtWZugRcHnP42c,1598 +sympy/polys/polyerrors.py,sha256=xByI-fqIHVYsYRm63NmHXlSSRCwSI9vZUoO-1Mf5Wlk,4744 +sympy/polys/polyfuncs.py,sha256=OEZpdYeHQADBJYqMw8JAyN4sw-jsJ6lzVH6m-CCoK8g,8547 +sympy/polys/polymatrix.py,sha256=iNa-EyIrzv7ZeXMaP_PjojgNd29AnIotZE3NeF2te44,9771 +sympy/polys/polyoptions.py,sha256=yG1k4ZPN2ufwpZ4lI4Sn8DSIHvd7EwDUHUaz64HxlUM,21806 +sympy/polys/polyquinticconst.py,sha256=mYLFWSBq3H3Y0I8cx76Z_xauLx1YeViC4xF6yWsSTPQ,96035 +sympy/polys/polyroots.py,sha256=bJToZlR6WZNl9fNxkIJxiPt3XoDQ_K6Pb3UOCBPfsKM,37025 +sympy/polys/polytools.py,sha256=BUwK8wJApazWeKQ3t27yg-HTMnwiv9ScPiavk2n02Rk,212876 +sympy/polys/polyutils.py,sha256=M8V-D4NkfMlsJe4LeAZzljrM4BKp10Qvuuutd8kmp3E,17248 +sympy/polys/puiseux.py,sha256=wc_5w5RyXHfAwC9Dde-6ecPm9y1z5DAmfthtSZ53CqA,26518 +sympy/polys/rationaltools.py,sha256=8vbkg3nuBxbd4ztR7lOj2jTF9taKUKTPah_fD38Ex6c,2838 +sympy/polys/ring_series.py,sha256=7DijmljyAaQrupWB0JmaNwGm3wNTstnEsBln-o1uX1M,59877 +sympy/polys/rings.py,sha256=KVdA3C04G2H3YN81rM-8ORsUHqmktIrz0S0tfwLwN4g,86080 +sympy/polys/rootisolation.py,sha256=QKCYgD4kMNBtSLYkUim8MXHa7louMtlO0riwKTiXRRI,64341 +sympy/polys/rootoftools.py,sha256=bMcVUcIwqaM643UU1WTCNmPmDFRepZEq9Xfdr9mm63Q,43092 +sympy/polys/solvers.py,sha256=FareKqDVVC6P2I4yjudi_3CS3kx0SbmAz3k7zHkETgE,13520 +sympy/polys/specialpolys.py,sha256=B2vijl75zgUKUTY1HCqjB9BTDFf3FM8ugwkKGTB83XA,11038 +sympy/polys/sqfreetools.py,sha256=oRFlOLhcXz5OKjc9SQtOuEhMvacvkHyTSEqhk6Snh1Y,19876 +sympy/polys/subresultants_qq_zz.py,sha256=VFKTAe3iGMCimLkc9GeuwG5o4SB0x2T5kRwxe4A5BDo,88241 +sympy/polys/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/polys/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_appellseqs.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_constructor.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_densearith.cpython-39.pyc,, 
+sympy/polys/tests/__pycache__/test_densebasic.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_densetools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_dispersion.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_distributedmodules.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_euclidtools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_factortools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_fields.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_galoistools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_groebnertools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_heuristicgcd.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_hypothesis.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_injections.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_modulargcd.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_monomials.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_multivariate_resultants.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_orderings.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_orthopolys.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_partfrac.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polyclasses.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polyfuncs.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polymatrix.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polyoptions.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polyroots.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polytools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_polyutils.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_puiseux.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_pythonrational.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_rationaltools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_ring_series.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_rings.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_rootisolation.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_rootoftools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_solvers.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_specialpolys.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_sqfreetools.cpython-39.pyc,, +sympy/polys/tests/__pycache__/test_subresultants_qq_zz.cpython-39.pyc,, +sympy/polys/tests/test_appellseqs.py,sha256=YTERuRr30QtfxYR0erXvJG8D-INe9RaMFAF0ZM-H4Ks,3820 +sympy/polys/tests/test_constructor.py,sha256=t6JpLbIdy2hcHF4_s_Xiav0kVc4OOsH5EFg6IXOiwmI,7220 +sympy/polys/tests/test_densearith.py,sha256=ptFGN7ovPH3dqT5khr6Wm8CR6a1WOR7Fi8dj8SuBvJs,40827 +sympy/polys/tests/test_densebasic.py,sha256=FcjAfZngYRQSWJ__BrBPITAUYjMNkS_pOtGyuJv_Bs0,21590 +sympy/polys/tests/test_densetools.py,sha256=o7HZfOw4Xn71RpCDhk98srNipA3zXfikAP_lV0gShlM,26241 +sympy/polys/tests/test_dispersion.py,sha256=S3diTgz7WKH88ulGNT3ijiQ8Y53Rpul6jATBJMGyQdw,3182 +sympy/polys/tests/test_distributedmodules.py,sha256=dXmjhozX5Yzb7DsrtbdFTqAxi9Z1UZNJvGxj-vHM7cM,7639 +sympy/polys/tests/test_euclidtools.py,sha256=vEyj48eIjm6-KRQtThNfI4ic_VDNB6l7jMouxJAF9HE,19482 +sympy/polys/tests/test_factortools.py,sha256=K5_R-01Muc_45XOtolk_tQswE2l6tGrlOJ9zA_aNVYE,24903 +sympy/polys/tests/test_fields.py,sha256=DUvcax7WGtydkLDXq0aQ9Ksc6PkdV67NPZCvqTYN2w8,9869 +sympy/polys/tests/test_galoistools.py,sha256=0oN3eSWvV99juZSXco6Ek9n6s6BFOmOE4UIEhyZnQQs,28532 +sympy/polys/tests/test_groebnertools.py,sha256=ZWHBcCCOVNwDxuJWg1WPo0krTHx1m1wTPi2cOYPsAT4,18584 
+sympy/polys/tests/test_heuristicgcd.py,sha256=87Yc0on955VExFyOgJuxBZhsIMFz1Vq4vclIVkQ--cE,4297 +sympy/polys/tests/test_hypothesis.py,sha256=LdokFa3JrQ59_umY15x21QSScPjJoCl8RxHRfn1NUOc,1109 +sympy/polys/tests/test_injections.py,sha256=EONGggBUNWaVSwi817CzLBYJgkTehFq8-m-Qdqes984,1286 +sympy/polys/tests/test_modulargcd.py,sha256=2SjCEjCPkf1-u-Co5vuVg5sJmfBLSY3gkFLicLl4g6E,9043 +sympy/polys/tests/test_monomials.py,sha256=YcMBU1qe6g7EJojtV89dfJZmdXsXVTSly2J5cF6SD6o,11083 +sympy/polys/tests/test_multivariate_resultants.py,sha256=DJu8CcZ3xwx8njpjDeSOyhyxeqZYmhfb7dkSCU-ll7Y,9501 +sympy/polys/tests/test_orderings.py,sha256=bdsIsqJTFJCVyZNRMAGVDXVk79ldw9rmAGejS_lwKP0,4254 +sympy/polys/tests/test_orthopolys.py,sha256=Ob5Osxd8hwebEjK3dDMarZ2VDXzl3pdsEdz_-kMa80M,6560 +sympy/polys/tests/test_partfrac.py,sha256=VfWOx30fZu_avMbWqxtbLDcLi-9x6iuQmLfyatnFc3Y,7838 +sympy/polys/tests/test_polyclasses.py,sha256=EuoT3b9a9KRTQFe95MjGrQIqAhRgLG5qUo42FtkM-Js,14765 +sympy/polys/tests/test_polyfuncs.py,sha256=VbgCgCRE06dtSY9I9GSdPH9T52ETYYoxk4J3N1WBtd4,4520 +sympy/polys/tests/test_polymatrix.py,sha256=pl2VrN_d2XGOVHvvAnaNQzkdFTdQgjt9ePgo41soBRs,7353 +sympy/polys/tests/test_polyoptions.py,sha256=z9DUdt8K3lYkm4IyLH1Cv-TKe76HP-EyaRkZVsfWb6U,12416 +sympy/polys/tests/test_polyroots.py,sha256=m0rnjXMqX59XdFn9X4rRtZJjI39vBN-MAt9cpQszq-I,26803 +sympy/polys/tests/test_polytools.py,sha256=JfaZlcguRIh6ZVNpkIUc2sSRJX_Hq6aluW7HxjCDWO4,139212 +sympy/polys/tests/test_polyutils.py,sha256=Qs3QQl0WYmTnkYE2ovTxdLeu6DYnWO_OoUmLwNDZzSw,11547 +sympy/polys/tests/test_puiseux.py,sha256=4PnBdJuYDWku8l6Jg2Dsn7umg14ldYBdvHqtVaaELbE,7304 +sympy/polys/tests/test_pythonrational.py,sha256=vYMlOTuYvf-15P0nKTFm-uRrhUc-nCFEkqYFAPLxg08,4143 +sympy/polys/tests/test_rationaltools.py,sha256=wkvjzNP1IH-SdubNk5JJ7OWcY-zNF6z3t32kfp9Ncs0,2397 +sympy/polys/tests/test_ring_series.py,sha256=j88HGfiJ9ZL114UezdJtC5aWtox8RG0I2Qyq9gLnkCE,31697 +sympy/polys/tests/test_rings.py,sha256=q8A_CeerrzGqVf1Ol-O-Jc-R7q-2X119YSNMnds3Ccs,48735 +sympy/polys/tests/test_rootisolation.py,sha256=x-n-T-Con-8phelNa05BPszkC_UCW1C0yAOwz658I60,32724 +sympy/polys/tests/test_rootoftools.py,sha256=1cv7M6EQ3mK7NhcwKqbIvour1l_QYtW_zr6vyZ3rzs0,23414 +sympy/polys/tests/test_solvers.py,sha256=kKyEvmqt91HHuAV42DS26B8fPDFQ2Ghtb_Nf3-1eWyI,13644 +sympy/polys/tests/test_specialpolys.py,sha256=v1i0RMNlixxe2EbwqoE-UlV7a2KY-md4T7thhUhWSx0,5031 +sympy/polys/tests/test_sqfreetools.py,sha256=m8ipJE28YC9eOrsVVAqxduCsw4iOW-s_nw76y0E8wUs,4886 +sympy/polys/tests/test_subresultants_qq_zz.py,sha256=bSFOjbtyIClbJHLDvvzgmOgoZyC_5Vdpe7vtW9VGU_s,13219 +sympy/printing/__init__.py,sha256=Hh-z1idh92pk17XjdXkjYwb8cJ5coEr4ZJQGcpE3d9I,2215 +sympy/printing/__pycache__/__init__.cpython-39.pyc,, +sympy/printing/__pycache__/aesaracode.cpython-39.pyc,, +sympy/printing/__pycache__/c.cpython-39.pyc,, +sympy/printing/__pycache__/codeprinter.cpython-39.pyc,, +sympy/printing/__pycache__/conventions.cpython-39.pyc,, +sympy/printing/__pycache__/cxx.cpython-39.pyc,, +sympy/printing/__pycache__/defaults.cpython-39.pyc,, +sympy/printing/__pycache__/dot.cpython-39.pyc,, +sympy/printing/__pycache__/fortran.cpython-39.pyc,, +sympy/printing/__pycache__/glsl.cpython-39.pyc,, +sympy/printing/__pycache__/gtk.cpython-39.pyc,, +sympy/printing/__pycache__/jscode.cpython-39.pyc,, +sympy/printing/__pycache__/julia.cpython-39.pyc,, +sympy/printing/__pycache__/lambdarepr.cpython-39.pyc,, +sympy/printing/__pycache__/latex.cpython-39.pyc,, +sympy/printing/__pycache__/llvmjitcode.cpython-39.pyc,, +sympy/printing/__pycache__/maple.cpython-39.pyc,, 
+sympy/printing/__pycache__/mathematica.cpython-39.pyc,, +sympy/printing/__pycache__/mathml.cpython-39.pyc,, +sympy/printing/__pycache__/numpy.cpython-39.pyc,, +sympy/printing/__pycache__/octave.cpython-39.pyc,, +sympy/printing/__pycache__/precedence.cpython-39.pyc,, +sympy/printing/__pycache__/preview.cpython-39.pyc,, +sympy/printing/__pycache__/printer.cpython-39.pyc,, +sympy/printing/__pycache__/pycode.cpython-39.pyc,, +sympy/printing/__pycache__/python.cpython-39.pyc,, +sympy/printing/__pycache__/pytorch.cpython-39.pyc,, +sympy/printing/__pycache__/rcode.cpython-39.pyc,, +sympy/printing/__pycache__/repr.cpython-39.pyc,, +sympy/printing/__pycache__/rust.cpython-39.pyc,, +sympy/printing/__pycache__/smtlib.cpython-39.pyc,, +sympy/printing/__pycache__/str.cpython-39.pyc,, +sympy/printing/__pycache__/tableform.cpython-39.pyc,, +sympy/printing/__pycache__/tensorflow.cpython-39.pyc,, +sympy/printing/__pycache__/theanocode.cpython-39.pyc,, +sympy/printing/__pycache__/tree.cpython-39.pyc,, +sympy/printing/aesaracode.py,sha256=_iHvF9avR9MGiOnxLIwog2M5rUwhLRt-bdsbcJHeO_s,18921 +sympy/printing/c.py,sha256=qX8EkxVT5Z2Hncz8DJh8iar_jBuVcRDEiIt1d6jcerE,27011 +sympy/printing/codeprinter.py,sha256=QUAWM6H45YBkd7fX6jXXO3j-tfVaHhzxbXn7OnSm0Es,42371 +sympy/printing/conventions.py,sha256=hvN1VMZE5krMj8FL4AAuxAO0xCwhAugBFWrFQ0355AY,2586 +sympy/printing/cxx.py,sha256=vCQyzT-1eNLLDJy4NBwCp5X5OCoPOZ9icsx8YifaPsc,6123 +sympy/printing/defaults.py,sha256=YitLfIRfFH8ltNd18Y6YtBgq5H2te0wFKlHuIO4cvo8,135 +sympy/printing/dot.py,sha256=W0J798ZxBdlJercffBGnNDTp7J2tMdIYQkE_KIiyi3s,8274 +sympy/printing/fortran.py,sha256=zHVQMphFeeIGl-1x29Hw5Lekxi5NWou-V3Kn4O0oHQQ,28635 +sympy/printing/glsl.py,sha256=wsjPcpyzk92m4I2_re0cn3JwEjPmRlmlVWX9owhzM5w,20310 +sympy/printing/gtk.py,sha256=ptnwYxJr5ox3LG4TCDbRIgxsCikaVvEzWBaqIpITUXc,466 +sympy/printing/jscode.py,sha256=BORODo9Aio62FKh1HYlluG-IxtJC26_-IWrxAiEtmr8,11981 +sympy/printing/julia.py,sha256=V_Qbld0MOebIkT4VnmyP6VdIbgrPaqllgjkv7eSq8bk,23541 +sympy/printing/lambdarepr.py,sha256=BCx4eSdG8MQ8ZSUV1lWEd3CzbZ4IiMid-TTxPoV6FHU,8305 +sympy/printing/latex.py,sha256=jCqXe1XOWLNUcRNzsHOY3CMvbkPBKjf8tD4ii6PVKxU,123669 +sympy/printing/llvmjitcode.py,sha256=NjjHbgheH4j6SboFIDfrrIogkcrJRkW_kjlzM8DpDmQ,17127 +sympy/printing/maple.py,sha256=aG7t4RUEVojSFBVA14GDBzU702KtTcolTCelXozNiN0,10571 +sympy/printing/mathematica.py,sha256=to8LmNkhiP0ACh-Vun-hfaMnLFSEc0mF2OmZnQmihE8,12720 +sympy/printing/mathml.py,sha256=UT7g45oLTqvnrC1IUZeqywHTV6-LUClfuOF_U4AMlaI,76634 +sympy/printing/numpy.py,sha256=Kpq-ySkQSRKYCxtgQZEi_4lj3EmK7FgvbDJTzvB77Tc,21049 +sympy/printing/octave.py,sha256=1ofeBaIcoJfgzImezfPed9Z4dZ6Tk5gVbEQbr02TcLI,25469 +sympy/printing/precedence.py,sha256=hKIin20F4LJp2l5yhSQ-WAegTRHYQLhccABC1_ciI-U,5522 +sympy/printing/pretty/__init__.py,sha256=pJTe-DO4ctTlnjg1UvqyoeBY50B5znFjcGvivXRhM2U,344 +sympy/printing/pretty/__pycache__/__init__.cpython-39.pyc,, +sympy/printing/pretty/__pycache__/pretty.cpython-39.pyc,, +sympy/printing/pretty/__pycache__/pretty_symbology.cpython-39.pyc,, +sympy/printing/pretty/__pycache__/stringpict.cpython-39.pyc,, +sympy/printing/pretty/pretty.py,sha256=KmEXRb6XEtI5tiQYssueibKKrJVF1F1mwrF1uHGxmpM,105024 +sympy/printing/pretty/pretty_symbology.py,sha256=v4FyDz_03mbSGX7tGni8Zw2L2qKTRGvmCIL85ttXg0s,24112 +sympy/printing/pretty/stringpict.py,sha256=UbG55kC4ve8-Dby_v0cm8AXUdLcu4YDLXVnA8F_iV8Y,19158 +sympy/printing/pretty/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/printing/pretty/tests/__pycache__/__init__.cpython-39.pyc,, 
+sympy/printing/pretty/tests/__pycache__/test_pretty.cpython-39.pyc,, +sympy/printing/pretty/tests/test_pretty.py,sha256=K2Xt--ptdxWLuUmtYJnuyh81ZrGIDRqNM_iOUADAdjA,187716 +sympy/printing/preview.py,sha256=3KacgHEch3KKJZeKc7kvD-L4pfftI8Gm62fx9-Vipak,14089 +sympy/printing/printer.py,sha256=z1phdIHrrcU2cyD-1GaBXREmhkcVl8bxeIWZrGAlczg,15868 +sympy/printing/pycode.py,sha256=LsDnREZEJVw0zsy_AjRSkYER-acomwQiDicrWBmnoW8,27048 +sympy/printing/python.py,sha256=oxGJhNQUXRKjHSEgqroGAyB-A4ZgJQqA_o7C8uXA1U8,3347 +sympy/printing/pytorch.py,sha256=HXEAoLbgbuE-SYj0hvDHeJ0zgzOVDBkK5qWErLCJhNo,10620 +sympy/printing/rcode.py,sha256=uBazXZwSRd95UUQjNoD6qLlpun7Xwlm8cfLG4qYhy3Y,14101 +sympy/printing/repr.py,sha256=GLO6WFOunZ8k0Om0g1VbGW2Y5uOP0IVEo_pIGO8HoNw,11545 +sympy/printing/rust.py,sha256=5Dh_zFxXZRVzbfHGowRQd9moUjzJZ7G2PXjxqsmBBw0,20831 +sympy/printing/smtlib.py,sha256=u2rVmprPITJdLq79FTYm7Fl-osteadr1JyXeQXlZW68,21752 +sympy/printing/str.py,sha256=dTK_CdT-Mm_zRoOL-Mwb4ocBA__h2wBFAAg6VvAiebQ,33239 +sympy/printing/tableform.py,sha256=pz4hXDLZ1IlqU8eOo0XBX1xaALYMqe4qHPUnEVq0mAE,11797 +sympy/printing/tensorflow.py,sha256=QmWoLUKpefBYJZ-J0Ei0NHHrU7Livp8Td4trWTDFzy8,8161 +sympy/printing/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/printing/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_aesaracode.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_c.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_codeprinter.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_conventions.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_cupy.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_cxx.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_dot.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_fortran.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_glsl.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_gtk.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_jax.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_jscode.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_julia.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_lambdarepr.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_latex.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_llvmjit.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_maple.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_mathematica.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_mathml.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_numpy.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_octave.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_precedence.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_preview.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_pycode.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_python.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_rcode.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_repr.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_rust.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_smtlib.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_str.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_tableform.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_tensorflow.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_theanocode.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_torch.cpython-39.pyc,, +sympy/printing/tests/__pycache__/test_tree.cpython-39.pyc,, 
+sympy/printing/tests/test_aesaracode.py,sha256=dX-SKJHfx9m91NSlt7pQ5Gul0x7fig3XZsHCt2kvtRc,21340 +sympy/printing/tests/test_c.py,sha256=etIEgJwlGWxtH6VRYg3Zkjyw_ktYRYj_-mvEMQehmjI,31227 +sympy/printing/tests/test_codeprinter.py,sha256=rVimgQTdJchy7PsjdZupQmXfqi6FBme2BPwnz7Z5WeQ,2142 +sympy/printing/tests/test_conventions.py,sha256=yqPpU3F0WcbxImPBBAHd3YEZpkFGfcq_TLK4WN_gtP4,5257 +sympy/printing/tests/test_cupy.py,sha256=E2T4a82XITbWAssmRZD4zCQdMCrpaZFCAjGXthUp8cY,1882 +sympy/printing/tests/test_cxx.py,sha256=0XQPQdtDCgEKVfpuelEB-79f7d2fe9QPgEKCCYVAI_w,3183 +sympy/printing/tests/test_dot.py,sha256=TSAtgGIgK_JbY-RMbQgUvnAI87SJqeJOqzcLjAobhKM,4648 +sympy/printing/tests/test_fortran.py,sha256=d5uy1zp0IgaBPupzXpNuk169YZYygnodPWhNRcvUtAQ,35456 +sympy/printing/tests/test_glsl.py,sha256=cfog9fp_EOFm_piJwqUcSvAIJ78bRwkFjecwr3ocCak,28421 +sympy/printing/tests/test_gtk.py,sha256=94gp1xRlPrFiALQGuqHnmh9xKrMxR52RQVkN0MXbUdA,500 +sympy/printing/tests/test_jax.py,sha256=5sOhmFNadFAKT2FumDuk5UqgOc0e0cYkRnJTczClnG8,11062 +sympy/printing/tests/test_jscode.py,sha256=ObahZne9lQbBiXyJZLohjQGdHsG2CnWCFOB8KbFOAqQ,11369 +sympy/printing/tests/test_julia.py,sha256=7KSmbPi_0J2ZSL2fI_xLWqoR9Gy73LfNtzFANHS8EiM,14014 +sympy/printing/tests/test_lambdarepr.py,sha256=DMoIfRNmUkF0KUXoNnJ_skNPU42tJ2xlEvqCT2GxwaA,6947 +sympy/printing/tests/test_latex.py,sha256=xHG_gSlEsD-QkOIiHUyWE4zb0brJxwgwGs1jJBvvOA0,138061 +sympy/printing/tests/test_llvmjit.py,sha256=EGPeRisM60_TIVgnk7PTLSm5F-Aod_88zLjHPZwfyZ8,5344 +sympy/printing/tests/test_maple.py,sha256=dhcz2bqapB_qMyBLMgsgnvVRIcHQ-HmxvAxC_Y0Z3q8,13078 +sympy/printing/tests/test_mathematica.py,sha256=oNNTHZzHviNi_ckua2issBb05IS7GaYyluBJiCjhGoM,10954 +sympy/printing/tests/test_mathml.py,sha256=yVIJYHlFNLbt1qtbKrqjxB2P88dsD6AtS9EVsDi_vEw,99434 +sympy/printing/tests/test_numpy.py,sha256=JZwyiFUS7Vfiwn3j1ogF31PJLJRzH57dtEHq_ujV9wQ,11715 +sympy/printing/tests/test_octave.py,sha256=WOJpsuGhMYLw2ikjwdTV6kVac6-I3xkjZ7qvUVheFbU,18545 +sympy/printing/tests/test_precedence.py,sha256=Za6JCTVvRGZhNY2dqiA8XpUpR5YGck7Op2cXiV78dZA,4067 +sympy/printing/tests/test_preview.py,sha256=dSVxiGqdNR6gbF40V4J2tGhQ-T4RDvSyGypHvYcPDYM,988 +sympy/printing/tests/test_pycode.py,sha256=txwQ1jKkQNZSOXoKw9w2n-_IeIdd-QUjh7WYvOgi1rA,18363 +sympy/printing/tests/test_python.py,sha256=HN7JkzQcKSnB6718i7kaEJZ5pYMqu56z1mSmHQGzY4k,8128 +sympy/printing/tests/test_rcode.py,sha256=TXyl449eCO4J-P3DQb3w9FWvMqOdXmfuiXgyBg_inQQ,13781 +sympy/printing/tests/test_repr.py,sha256=5XmDdIDlQlCWckWq8H95Fw82h-oDxrRpMWZePb6hHa4,12980 +sympy/printing/tests/test_rust.py,sha256=ClZ30PUmZj1ZoBWL2zSt4SsWQS0WB8FrFSQE-w3bxjU,11988 +sympy/printing/tests/test_smtlib.py,sha256=S-42OTOUGkZ-k17qCQ36DYgT-tyOeYqDCvtXzrL2Kzo,22553 +sympy/printing/tests/test_str.py,sha256=QhRuH-hwpXRcPTmKClxqBv8HQdlpNjR4dMXf_XJ3jCs,43523 +sympy/printing/tests/test_tableform.py,sha256=Ff5l1QL2HxN32WS_TdFhUAVqzop8YoWY3Uz1TThvVIM,5692 +sympy/printing/tests/test_tensorflow.py,sha256=-56L5IdYszZy9CGNxS8Ix8UaFStZudFuKQlyyz6Rzzs,17100 +sympy/printing/tests/test_theanocode.py,sha256=E36Fj72HxMK0e1pKTkoTpv9wI4UvwHdVufo-JA6dYq0,21394 +sympy/printing/tests/test_torch.py,sha256=yOtB4yl1DjbTbb8I_t2bdW8fAafVC5G8ubeIJpVLPus,16842 +sympy/printing/tests/test_tree.py,sha256=_8PGAhWMQ_A0f2DQLdDeMrpxY19889P5Ih9H41RZn8s,6080 +sympy/printing/theanocode.py,sha256=d1Aznwfqz6lFR680RGrIf5ht2AH0Ei7pz-7yGZt-CLY,19094 +sympy/printing/tree.py,sha256=tQNjLV3_c3QUUce068AbzAuCiUIJb47l_7F_TYxxFUQ,3868 +sympy/release.py,sha256=oitamisYLp3lrFI3uGv-5ZCJIeYOT9gZLcszUmVNhw8,23 
+sympy/sandbox/__init__.py,sha256=IaEVOYHaZ97OHEuto1UGthFuO35c0uvAZFZU23YyEaU,189 +sympy/sandbox/__pycache__/__init__.cpython-39.pyc,, +sympy/sandbox/__pycache__/indexed_integrals.cpython-39.pyc,, +sympy/sandbox/indexed_integrals.py,sha256=svh4xDIa8nGpDeH4TeRb49gG8miMvXpCzEarbor58EE,2141 +sympy/sandbox/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/sandbox/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-39.pyc,, +sympy/sandbox/tests/test_indexed_integrals.py,sha256=UK2E2wg9EMwda4Vwpzyj3rmXs6ni33HqcbyaqAww6ww,1179 +sympy/series/__init__.py,sha256=DYG9oisjzYeS55dIUpQpbAFcoDz7Q81fZJw36PRGu14,766 +sympy/series/__pycache__/__init__.cpython-39.pyc,, +sympy/series/__pycache__/acceleration.cpython-39.pyc,, +sympy/series/__pycache__/approximants.cpython-39.pyc,, +sympy/series/__pycache__/aseries.cpython-39.pyc,, +sympy/series/__pycache__/formal.cpython-39.pyc,, +sympy/series/__pycache__/fourier.cpython-39.pyc,, +sympy/series/__pycache__/gruntz.cpython-39.pyc,, +sympy/series/__pycache__/kauers.cpython-39.pyc,, +sympy/series/__pycache__/limits.cpython-39.pyc,, +sympy/series/__pycache__/limitseq.cpython-39.pyc,, +sympy/series/__pycache__/order.cpython-39.pyc,, +sympy/series/__pycache__/residues.cpython-39.pyc,, +sympy/series/__pycache__/sequences.cpython-39.pyc,, +sympy/series/__pycache__/series.cpython-39.pyc,, +sympy/series/__pycache__/series_class.cpython-39.pyc,, +sympy/series/acceleration.py,sha256=Phe8jpZ4uiS4vYAzh8PiUoHKVUjyu5K9MoAnZX3pV_w,3365 +sympy/series/approximants.py,sha256=tE-hHuoW62QJHDA3WhRlXaTkokCAODs1vXgjirhOYiQ,3181 +sympy/series/aseries.py,sha256=cHVGRQaza4ayqI6ji6OHNkdQEMV7Bko4f4vug2buEQY,255 +sympy/series/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/series/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/series/benchmarks/__pycache__/bench_limit.cpython-39.pyc,, +sympy/series/benchmarks/__pycache__/bench_order.cpython-39.pyc,, +sympy/series/benchmarks/bench_limit.py,sha256=2PtdeeJtD6qyEvt9HFNvyTnMM8phFZRjscgnb4fHndU,173 +sympy/series/benchmarks/bench_order.py,sha256=iC8sQJ0lLlTgiXltAyLzSCQ-3490cf-c6NFiIU44JSk,207 +sympy/series/formal.py,sha256=9HJp9fcNX-MxcIfTtMptdKEyAw6cQQ3LyyzMyEt-9aw,51668 +sympy/series/fourier.py,sha256=H3zDg_S5D0Qeewp-o87_t_30TIg3rad83bzAoz9yp_E,22941 +sympy/series/gruntz.py,sha256=fhHr8w8SeMd-6GKjHmlwajd4FUmzeYo97AA3C_q0lKg,23497 +sympy/series/kauers.py,sha256=PzD0MATMNjLjPi9GW5GQGL6Uqc2UT-uPwnzhi7TkJH8,1720 +sympy/series/limits.py,sha256=hoSFWORWR-DodSf6Po93uK_09JUePhNvI60YsH6V4CI,13307 +sympy/series/limitseq.py,sha256=WM1Lh3RXhSZM1gQaJrhWnUtYEgJunLujIEw1gmtVhYw,7752 +sympy/series/order.py,sha256=1JjdsH2zvP8X1rf75qpfIB-iI70fkcH9CmJsXsRqSxY,19558 +sympy/series/residues.py,sha256=k46s_fFfIHdJZqfst-B_-X1R-SAWs_rR9MQH7a9JLtg,2213 +sympy/series/sequences.py,sha256=WLubAZr5AevpllgI4b6WmIfXRX88QNVCpMGBofGkcqo,35543 +sympy/series/series.py,sha256=PveJ6LibsrDX-mymTbV7teNggnsX0cfaoGMLAP0KAIo,1863 +sympy/series/series_class.py,sha256=033NJ5Re8AS4eq-chmfct3-Lz2vBqdFqXtnrbxswTx0,2918 +sympy/series/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/series/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_approximants.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_aseries.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_demidovich.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_formal.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_fourier.cpython-39.pyc,, 
+sympy/series/tests/__pycache__/test_gruntz.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_kauers.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_limits.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_limitseq.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_lseries.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_nseries.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_order.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_residues.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_sequences.cpython-39.pyc,, +sympy/series/tests/__pycache__/test_series.cpython-39.pyc,, +sympy/series/tests/test_approximants.py,sha256=KViHMW1dPXn7xaPYhtTQ9L_WtLLkoIic6yfFnwZ8Q70,1012 +sympy/series/tests/test_aseries.py,sha256=5yWntog5ogRZrmjF3ZhY4e-d2mHYC1Ocs7IXRRFQrtI,2377 +sympy/series/tests/test_demidovich.py,sha256=JGYacqJMEqHS6oT2AYs9d7iutIEb32PkJs9EJqOHxcQ,4947 +sympy/series/tests/test_formal.py,sha256=TpMws7tgQHaokLrzEKZdlOcn6f04WIZhUqreaOOJ0kU,22496 +sympy/series/tests/test_fourier.py,sha256=-RnlhZUZJuctdE3pEtVMeGzCfyWzR1TFeojDhpJqyWo,5908 +sympy/series/tests/test_gruntz.py,sha256=3yPq8ksMGM7x4VQIfsFkk3JIgHlj-g1kJJBwIL76uPg,16445 +sympy/series/tests/test_kauers.py,sha256=Z85FhfXOOVki0HNGeK5BEBZOpkuB6SnKK3FqfK1-aLQ,1102 +sympy/series/tests/test_limits.py,sha256=YuwcBy8Atf-UGP8Li2sFU0guV_ZsjHqx-IdPwI6M3EA,48638 +sympy/series/tests/test_limitseq.py,sha256=QjEF99sYEDqfY7ULz1qjQTo6e0lIRUCflEOBgiDYRVA,5691 +sympy/series/tests/test_lseries.py,sha256=GlQvlBlD9wh02PPBP6zU83wmhurvGUFTuCRp44B4uI4,1875 +sympy/series/tests/test_nseries.py,sha256=uzhzYswSOe9Gh_nWKeO69tvGPMLd-9tqk4HBYX8JIm4,18284 +sympy/series/tests/test_order.py,sha256=AUlRkYKCUIs4DyT6g94ZGmjzT2fqHutPXOUGhpz4mb0,17901 +sympy/series/tests/test_residues.py,sha256=pT9xzPqtmfKGSbLLAxgDVZLTSy3TOxyfq3thTJs2VLw,3178 +sympy/series/tests/test_sequences.py,sha256=Oyq32yQZnGNQDS2uJ3by3bZ-y4G9c9BFfdQTcVuW2RM,11161 +sympy/series/tests/test_series.py,sha256=yPqSIfTytdiIYlYR9yvK7uS4IwF1fctpDiiktEnwvec,17006 +sympy/sets/__init__.py,sha256=3vjCm4v2esbpsVPY0ROwTXMETxns_66bG4FCIFZ96oM,1026 +sympy/sets/__pycache__/__init__.cpython-39.pyc,, +sympy/sets/__pycache__/conditionset.cpython-39.pyc,, +sympy/sets/__pycache__/contains.cpython-39.pyc,, +sympy/sets/__pycache__/fancysets.cpython-39.pyc,, +sympy/sets/__pycache__/ordinals.cpython-39.pyc,, +sympy/sets/__pycache__/powerset.cpython-39.pyc,, +sympy/sets/__pycache__/setexpr.cpython-39.pyc,, +sympy/sets/__pycache__/sets.cpython-39.pyc,, +sympy/sets/conditionset.py,sha256=mhodBVrMqJ6W5H8CuaFhO3FO9VdJifOEPjjFmV9lT2I,7792 +sympy/sets/contains.py,sha256=AlsfOc_6V0TH_6G7XTPIXhDywgtg3ECdjCTfSOmPC-E,1829 +sympy/sets/fancysets.py,sha256=tZ4wVasc937Yr0VkQuLVHe75Falcr72Nrkhk89178QQ,48187 +sympy/sets/handlers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/sets/handlers/__pycache__/__init__.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/add.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/comparison.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/functions.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/intersection.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/issubset.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/mul.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/power.cpython-39.pyc,, +sympy/sets/handlers/__pycache__/union.cpython-39.pyc,, +sympy/sets/handlers/add.py,sha256=_ucFvxuDv9wsmKxGkCDUERtYk3I_tQxjZjY3ZkroWs0,1863 +sympy/sets/handlers/comparison.py,sha256=WfT_vLrOkvPqRg2mf7gziVs_6cLg0kOTEFv-Nb1zIvo,1601 
+sympy/sets/handlers/functions.py,sha256=jYSFqFNH6mXbKFPgvIAIGY8BhbLPo1dAvcNg4MxmCaI,8381 +sympy/sets/handlers/intersection.py,sha256=oYPmrx3FAkGLVWT3EjSimeMfsGqPsqVnpJm5UxQzuCM,17514 +sympy/sets/handlers/issubset.py,sha256=azka_5eOaUro3r3v72PmET0oY8-aaoJkzVEK7kuqXCA,4739 +sympy/sets/handlers/mul.py,sha256=XFbkOw4PDQumaOEUlHeQLvjhIom0f3iniSYv_Kau-xw,1842 +sympy/sets/handlers/power.py,sha256=84N3dIus7r09XV7PF_RiEpFRw1y5tOGD34WKzSM9F-4,3186 +sympy/sets/handlers/union.py,sha256=lrAdydqExnALUjM0dnoM-7JAZqtbgLb46Y2GGmFtQdw,4225 +sympy/sets/ordinals.py,sha256=GSyaBq7BHJC3pvgoCDoUKZQ0IE2VXyHtx6_g5OS64W4,7641 +sympy/sets/powerset.py,sha256=vIGnSYKngEPEt6V-6beDOXAOY9ugDLJ8fXOx5H9JJck,2913 +sympy/sets/setexpr.py,sha256=jMOQigDscLTrFPXvHqo1ODVRG9BqC4yn38Ej4m6WPa0,3019 +sympy/sets/sets.py,sha256=-iuLnPWLunCa0oOErHELO9EkLZXs6L2AFWQgkCFbJcA,81202 +sympy/sets/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/sets/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_conditionset.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_contains.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_fancysets.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_ordinals.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_powerset.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_setexpr.cpython-39.pyc,, +sympy/sets/tests/__pycache__/test_sets.cpython-39.pyc,, +sympy/sets/tests/test_conditionset.py,sha256=UeutWuBCT5PZicNkkl9E94pREnU5CvLGWUfmRS29PcU,11370 +sympy/sets/tests/test_contains.py,sha256=ec6WRzriwV9nurz3jS9IXEqtfL1pZedtJFp--fSsBvY,1561 +sympy/sets/tests/test_fancysets.py,sha256=oiuu5PKNrDeDO-NtgsE4CpkwQaS7JNgCYuSyc1ykKKE,51938 +sympy/sets/tests/test_ordinals.py,sha256=L4DYc6ByQMDwJGFzJC3YhfSrVk5auW7pf4QYpJ5xY7w,2637 +sympy/sets/tests/test_powerset.py,sha256=nFvDGlhAf0wG-pZnPkgJjfwDHrTwdro3MYIinwyxn94,4805 +sympy/sets/tests/test_setexpr.py,sha256=E--SjYVzrmau0EbD8g4NTqp6aLD8qHzIuI7sAfuWxpY,14797 +sympy/sets/tests/test_sets.py,sha256=oMtqcr9gr68itfWH4j18POID8tD0hAZUV8HCz7-KKj0,68544 +sympy/simplify/__init__.py,sha256=MH1vkwHq0J5tNm7ss8V6v-mjrDGUXwfOsariIwfi38c,1274 +sympy/simplify/__pycache__/__init__.cpython-39.pyc,, +sympy/simplify/__pycache__/_cse_diff.cpython-39.pyc,, +sympy/simplify/__pycache__/combsimp.cpython-39.pyc,, +sympy/simplify/__pycache__/cse_main.cpython-39.pyc,, +sympy/simplify/__pycache__/cse_opts.cpython-39.pyc,, +sympy/simplify/__pycache__/epathtools.cpython-39.pyc,, +sympy/simplify/__pycache__/fu.cpython-39.pyc,, +sympy/simplify/__pycache__/gammasimp.cpython-39.pyc,, +sympy/simplify/__pycache__/hyperexpand.cpython-39.pyc,, +sympy/simplify/__pycache__/hyperexpand_doc.cpython-39.pyc,, +sympy/simplify/__pycache__/powsimp.cpython-39.pyc,, +sympy/simplify/__pycache__/radsimp.cpython-39.pyc,, +sympy/simplify/__pycache__/ratsimp.cpython-39.pyc,, +sympy/simplify/__pycache__/simplify.cpython-39.pyc,, +sympy/simplify/__pycache__/sqrtdenest.cpython-39.pyc,, +sympy/simplify/__pycache__/traversaltools.cpython-39.pyc,, +sympy/simplify/__pycache__/trigsimp.cpython-39.pyc,, +sympy/simplify/_cse_diff.py,sha256=Nq1gXn27o1v83t__SQUuacnOohhywa39ucPaR2jD518,10479 +sympy/simplify/combsimp.py,sha256=UzgG6eC_eIzieMS3j0oQFJnS5XkZ1gxNqOSzBqJcIFQ,3604 +sympy/simplify/cse_main.py,sha256=tcUmXXFGIZyoX4PvBEmZ-YYm-CBLXoAyGLHverxP23o,31310 +sympy/simplify/cse_opts.py,sha256=ZTCaOdOrgtifWxQmFzyngrLq9uwzByBdiSS5mE-DDoE,1618 +sympy/simplify/epathtools.py,sha256=fV3oeD2J3LJRQ4YoTV6gq3yFT_NSZi-7xxKnrX-2KJA,10122 
+sympy/simplify/fu.py,sha256=exevmk1p4YgEAO-MmTgh7cG7cyVMMF8p7vZ1ymwwKiI,62208 +sympy/simplify/gammasimp.py,sha256=yzfkD_q1Zxu1eMCZuprmA3VzZXhecX3K0xdZlUdYlcY,18485 +sympy/simplify/hyperexpand.py,sha256=sTWd1-VeLr0NTr-cm6nftjmO_gVPqa1P9ud_rvSZS54,84420 +sympy/simplify/hyperexpand_doc.py,sha256=E8AD0mj8ULtelDSUkmJKJY7kYm5fVfCL4QH_DX65qEw,521 +sympy/simplify/powsimp.py,sha256=aCkX2zxgOm1S9oWc260KrzhSwcJw8p3JEMZ-vkmUapM,26754 +sympy/simplify/radsimp.py,sha256=VP83Ufl_9DUtmeHCRXSmOHqdg9zDsPoQJ_7eUyCiwpI,41649 +sympy/simplify/ratsimp.py,sha256=vTu0t0k1FDFTofl4mwcK9NuTMNZ20IsL8fIyFJRFNkg,7682 +sympy/simplify/simplify.py,sha256=3rK7cCfkh1svliMVsCuTv9_aQDz3nmZvA_HRAvmNkSo,73330 +sympy/simplify/sqrtdenest.py,sha256=MsTyPUcEhiMf51fpk6e5grlqudU-LMAHOuatPCxo5uA,21600 +sympy/simplify/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/simplify/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_combsimp.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_cse.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_cse_diff.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_epathtools.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_fu.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_function.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_gammasimp.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_powsimp.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_radsimp.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_ratsimp.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_rewrite.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_simplify.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-39.pyc,, +sympy/simplify/tests/__pycache__/test_trigsimp.cpython-39.pyc,, +sympy/simplify/tests/test_combsimp.py,sha256=O95WSxCvo2fDQs-UlarAcSf0_8M3PuTR76lhREDoNA8,2958 +sympy/simplify/tests/test_cse.py,sha256=JOjH5G6GPfeaKpKdeeKL1vvHgbt8TGxTkqmH2e1ovY8,25516 +sympy/simplify/tests/test_cse_diff.py,sha256=Y_zYZpldwjUoWiTTpFMlw4PXkqNp9r9Zwrw46gkmNoY,7275 +sympy/simplify/tests/test_epathtools.py,sha256=ugsQlfuK6POiixdeit63QovsVAlG5JyCaPlPp0j35LE,3525 +sympy/simplify/tests/test_fu.py,sha256=X2NNsynwefg2aH5GhyxQItL80fKPZ9md6nydMEedWbQ,19410 +sympy/simplify/tests/test_function.py,sha256=gzdcSFObuDzVFJDdAgmERtZJvG38WNSmclPAdG8OaPQ,2199 +sympy/simplify/tests/test_gammasimp.py,sha256=32cPRmtG-_Mz9g02lmmn-PWDD3J_Ku6sxLxIUU7WqxE,5320 +sympy/simplify/tests/test_hyperexpand.py,sha256=y3lxd97UZ1BkrrZe0r2l6MM8zNABTE_L4KtxFwmdzeQ,41043 +sympy/simplify/tests/test_powsimp.py,sha256=FWI_K8nozwYzhrhY1Q0CZkf6n88hFR0jjuYRvQnMius,14375 +sympy/simplify/tests/test_radsimp.py,sha256=Gf6Tf7-lb5G8gqF6hJA6kPrZNFES6fCz7_-ls7TFUfY,19128 +sympy/simplify/tests/test_ratsimp.py,sha256=bv-K60A7m2on-U9szzaeYO7Hlp1EK5P0HvBBO59oyas,2198 +sympy/simplify/tests/test_rewrite.py,sha256=LZj4V6a95GJj1o3NlKRoHMk7sWGPASFlw24nsm4z43k,1127 +sympy/simplify/tests/test_simplify.py,sha256=QkiaVVj0nqHXvu1lUUYZfYZ6_Ods6DfiBwFiqAJKSoo,41928 +sympy/simplify/tests/test_sqrtdenest.py,sha256=4zRtDQVGpKRRBYSAnEF5pSM0AR_fAMumONu2Ocb3tqg,7470 +sympy/simplify/tests/test_trigsimp.py,sha256=vG5PDTDNOuFypT7H9DSMjIollPqkKdNhWv5FBj6vFnE,19949 +sympy/simplify/traversaltools.py,sha256=pn_t9Yrk_SL1X0vl-zVR6yZaxkY25D4MwTBv4ywnD1Y,409 +sympy/simplify/trigsimp.py,sha256=t36QKVf_0yqVCNo1hUBlBp5locL0sMH1AlsAwQCzpLQ,46857 
+sympy/solvers/__init__.py,sha256=nb1LceF2MwvULXsCN8VUq7X2tVgZOPsmWYHGFszIdW4,2333 +sympy/solvers/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/__pycache__/bivariate.cpython-39.pyc,, +sympy/solvers/__pycache__/decompogen.cpython-39.pyc,, +sympy/solvers/__pycache__/deutils.cpython-39.pyc,, +sympy/solvers/__pycache__/inequalities.cpython-39.pyc,, +sympy/solvers/__pycache__/pde.cpython-39.pyc,, +sympy/solvers/__pycache__/polysys.cpython-39.pyc,, +sympy/solvers/__pycache__/recurr.cpython-39.pyc,, +sympy/solvers/__pycache__/simplex.cpython-39.pyc,, +sympy/solvers/__pycache__/solvers.cpython-39.pyc,, +sympy/solvers/__pycache__/solveset.cpython-39.pyc,, +sympy/solvers/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/solvers/benchmarks/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-39.pyc,, +sympy/solvers/benchmarks/bench_solvers.py,sha256=ZVK2TIW0XjWRDBex054ymmVlSBQw-RIBhEL1wS2ZAmU,288 +sympy/solvers/bivariate.py,sha256=Ku4Ttid3yU3CIc4IPfJ7bedLtXVaGKvFOonEIlZrXXw,17870 +sympy/solvers/decompogen.py,sha256=dWQla7hp7A4RqI2a0qRNQLWNPEuur68lD3dVTyktdBU,3757 +sympy/solvers/deutils.py,sha256=jukqDw6JGZfl36HRFjZIX_A0_iarmTawI0Up7z5tQDc,10315 +sympy/solvers/diophantine/__init__.py,sha256=I1p3uj3kFQv20cbsZ34K5rNCx1_pDS7JwHUCFstpBgs,128 +sympy/solvers/diophantine/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/diophantine/__pycache__/diophantine.cpython-39.pyc,, +sympy/solvers/diophantine/diophantine.py,sha256=n74XrMIW6ZwrAfaVVks7nwB01GClI0SYdvTe5K90Jnc,122482 +sympy/solvers/diophantine/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-39.pyc,, +sympy/solvers/diophantine/tests/test_diophantine.py,sha256=N7ZCyNVfSV0JzAKbnc_8YbUFq5ZkiDLnO6MxxMxCG90,43335 +sympy/solvers/inequalities.py,sha256=kz6HzAytg4V2eqSKwGzbJuYyFfgj6bFSUgk1i15zVp0,33189 +sympy/solvers/ode/__init__.py,sha256=I7RKwCcaoerflUm5i3ZDJgBIOnkhBjb83BCHcVcFqfM,468 +sympy/solvers/ode/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/hypergeometric.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/lie_group.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/ode.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/riccati.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/single.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/subscheck.cpython-39.pyc,, +sympy/solvers/ode/__pycache__/systems.cpython-39.pyc,, +sympy/solvers/ode/hypergeometric.py,sha256=t9KN5AdB7yF-rRoatEBzUYBcrwvIrKm83KBDa5BJYqY,10048 +sympy/solvers/ode/lie_group.py,sha256=C2HoOVtDbIT2ryx4TzEvQU46Yqc8LurTBBtjv5m2-BQ,39175 +sympy/solvers/ode/nonhomogeneous.py,sha256=Wi6UxjYNAPcfJsEzKHNvIJ8zhJ7-ZOkqccCR93MpUuM,17875 +sympy/solvers/ode/ode.py,sha256=9UQ0bqcrFOsGuSxGfRJpd2CllIWfiUYZPC2-A7Crefo,145379 +sympy/solvers/ode/riccati.py,sha256=3ZrkCy3ufXCyCx8gJfAuxyCOgbOLEHtvTh2VgJhA7Mk,30736 +sympy/solvers/ode/single.py,sha256=yBpPVFK3LsCf07dDb9_x6tzhp5d4ln_xIpMZpP2Ryrc,109465 +sympy/solvers/ode/subscheck.py,sha256=CIPca_qTxL9z5oaD2e2NrgME0eVQgF9PabZndcVqHZM,16130 +sympy/solvers/ode/systems.py,sha256=Kok2AMO4YEKX38fi2LrxLm5zpTcNXz398b7OBvrUezU,71467 +sympy/solvers/ode/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/solvers/ode/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-39.pyc,, 
+sympy/solvers/ode/tests/__pycache__/test_ode.cpython-39.pyc,, +sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-39.pyc,, +sympy/solvers/ode/tests/__pycache__/test_single.cpython-39.pyc,, +sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-39.pyc,, +sympy/solvers/ode/tests/__pycache__/test_systems.cpython-39.pyc,, +sympy/solvers/ode/tests/test_lie_group.py,sha256=vg1yy_-a5x1Xm2IcVkEi5cD2uA5wE5gjqpfBwkV1vZc,5319 +sympy/solvers/ode/tests/test_ode.py,sha256=bEcTi9zA5TooKOE9duFtd6TAVRvLPrO92arX7fF85-Q,48508 +sympy/solvers/ode/tests/test_riccati.py,sha256=a2_pXzCFb9WDtT_kdxYkfLd-PG5xvs4rQn_vpFk-O9s,29348 +sympy/solvers/ode/tests/test_single.py,sha256=sEvhXws920Oe_IuvmBIYM5iPViR6I3njZNgJlc8sMC4,100200 +sympy/solvers/ode/tests/test_subscheck.py,sha256=Gzwc9h9n6zlNOhJ8Qh6fQDeB8ghaRmgv3ktBAfPJx-U,12468 +sympy/solvers/ode/tests/test_systems.py,sha256=PeIEHOx8-eZQNuqOfhjTvGEeFSVriLlHShPmy84mde4,129087 +sympy/solvers/pde.py,sha256=V2jxgmmepng2ofFFDVOsBHWqf-IUQCgIz4lXzhBUeyg,34913 +sympy/solvers/polysys.py,sha256=4CCK3FNhl-K6B4bYF_goSrAq1JOMubl6KsmaB_5R2f4,27168 +sympy/solvers/recurr.py,sha256=DyssZuOyemoC6J1cWq635O7zkg1WLHrR7KGoM-gNy0g,25389 +sympy/solvers/simplex.py,sha256=peRv44aDM30U_0wk8xZv3t6M_cpm_a9hBALiyr7zndY,35276 +sympy/solvers/solvers.py,sha256=xsSlQaj0XaunexmhfQZZI_uiRXHETmaUab9hgZrmmbU,138454 +sympy/solvers/solveset.py,sha256=6VxsiKcn4fwdJGYXA8k1si5zvYs6vNrSTtLcmSKpjMk,151683 +sympy/solvers/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/solvers/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_constantsimp.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_decompogen.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_inequalities.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_numeric.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_pde.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_polysys.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_recurr.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_simplex.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_solvers.cpython-39.pyc,, +sympy/solvers/tests/__pycache__/test_solveset.cpython-39.pyc,, +sympy/solvers/tests/test_constantsimp.py,sha256=9Feugsg9jD2BwQiG4EFpb9fORyst6JdBmZqq2GaOgH8,8707 +sympy/solvers/tests/test_decompogen.py,sha256=7GUsDQQZtYbZIK0p0UxsOuNEJxEt4IHeOSsem_k-k0U,2943 +sympy/solvers/tests/test_inequalities.py,sha256=whg3vGXEYxeIHNQS4yBeB9VQpoYWnfw5NBS4xLiqDJ8,21025 +sympy/solvers/tests/test_numeric.py,sha256=pKLBJuf4lGCewf5DgBjvH5T9phf6ilkyW1gB2R-5SzA,4734 +sympy/solvers/tests/test_pde.py,sha256=UGP3uWjF8pKQgfPifmdfvS5URVmzSg6m2NkS7LGzmio,9257 +sympy/solvers/tests/test_polysys.py,sha256=bK2ckIhQvbJCc6_qsUOAuAEyTc5CR8J2yiqN4MFOgpM,15266 +sympy/solvers/tests/test_recurr.py,sha256=-OeghSg16GFN70y_RUXC6CF6VU_b7NXaKDbejtRSocg,11418 +sympy/solvers/tests/test_simplex.py,sha256=pG7j7KxXrlK8K_fF04XH_mvMzwEOopSBpJZiagKjyNs,9037 +sympy/solvers/tests/test_solvers.py,sha256=eXM6AT2upxlRC7S2l1YxUSS6jJNM42PxzYCTz4C4vmc,107765 +sympy/solvers/tests/test_solveset.py,sha256=lOQ4hG7Dg6yuDgFriKIlvWdzEumehgL08Srsh4JwDHY,148480 +sympy/stats/__init__.py,sha256=9pY3OMdvAIeJ_Q1ZqynOZ8hHWFFrrxOx0qBHKkGfHCg,8487 +sympy/stats/__pycache__/__init__.cpython-39.pyc,, +sympy/stats/__pycache__/compound_rv.cpython-39.pyc,, +sympy/stats/__pycache__/crv.cpython-39.pyc,, +sympy/stats/__pycache__/crv_types.cpython-39.pyc,, +sympy/stats/__pycache__/drv.cpython-39.pyc,, +sympy/stats/__pycache__/drv_types.cpython-39.pyc,, 
+sympy/stats/__pycache__/error_prop.cpython-39.pyc,, +sympy/stats/__pycache__/frv.cpython-39.pyc,, +sympy/stats/__pycache__/frv_types.cpython-39.pyc,, +sympy/stats/__pycache__/joint_rv.cpython-39.pyc,, +sympy/stats/__pycache__/joint_rv_types.cpython-39.pyc,, +sympy/stats/__pycache__/matrix_distributions.cpython-39.pyc,, +sympy/stats/__pycache__/random_matrix.cpython-39.pyc,, +sympy/stats/__pycache__/random_matrix_models.cpython-39.pyc,, +sympy/stats/__pycache__/rv.cpython-39.pyc,, +sympy/stats/__pycache__/rv_interface.cpython-39.pyc,, +sympy/stats/__pycache__/stochastic_process.cpython-39.pyc,, +sympy/stats/__pycache__/stochastic_process_types.cpython-39.pyc,, +sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-39.pyc,, +sympy/stats/__pycache__/symbolic_probability.cpython-39.pyc,, +sympy/stats/compound_rv.py,sha256=SO1KXJ0aHGbD5y9QA8o6qOHbio3ua8wyO2Rsh0Hnw48,7965 +sympy/stats/crv.py,sha256=jd8iemE41aW-byXgFDYKaMv2VOOUnyUtiMx_QAXI-n4,21028 +sympy/stats/crv_types.py,sha256=qXYxtvJPg4ORJI7AxIkRhwnJ6lcz3sVBWf4toQy9vvA,122371 +sympy/stats/drv.py,sha256=Zu8IK9yt69XJFKt-ShoJfZbvZ-CGMwvrKNiXvi6q_n4,11994 +sympy/stats/drv_types.py,sha256=2jL6QYa9lAlbvTxvx9BXdfwU6PRt0uzV3CVXRIeZ5k0,19865 +sympy/stats/error_prop.py,sha256=a-H6GZEidsiP_4-iNw7nSD99AMyN6DNHsSl0IUZGIAs,3315 +sympy/stats/frv.py,sha256=vZROeD6DSVGX0kfOL_yOds3pZgjSiXFX-bMZtSUkVMA,16874 +sympy/stats/frv_types.py,sha256=Xkc3MU5dFLajPajIStW_lV4qiuiHg86rBo45zX_oGLk,23348 +sympy/stats/joint_rv.py,sha256=DcixlO2Ml4gnwMmZk2VTegiHVq88DkLdQlOTQ57SQtc,15963 +sympy/stats/joint_rv_types.py,sha256=PUatR4WcPHmAHadt8iRh5xYh5NJigzYh-EoAMR5blDw,30575 +sympy/stats/matrix_distributions.py,sha256=3OricwEMM_NU8b2lJxoiSTml7kvqrNQ6IUIn9Xy_DsY,21953 +sympy/stats/random_matrix.py,sha256=NmzLC5JMDWI2TvH8tY6go8lYyHmqcZ-B7sSIO7z7oAk,1028 +sympy/stats/random_matrix_models.py,sha256=7i5XAUYxt-ekmP5KDMaytUlmCvxglEspoWbswSf82tE,15328 +sympy/stats/rv.py,sha256=rmjdVpsvJrqtMgnreuT_7lU5S_V2nkvlrpUvA7zCWLM,54579 +sympy/stats/rv_interface.py,sha256=0ZSplwmkEtio_UTlh7bqSK_1aILgvk2JTfZ-CkOcfAc,13935 +sympy/stats/sampling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/stats/sampling/__pycache__/__init__.cpython-39.pyc,, +sympy/stats/sampling/__pycache__/sample_numpy.cpython-39.pyc,, +sympy/stats/sampling/__pycache__/sample_pymc.cpython-39.pyc,, +sympy/stats/sampling/__pycache__/sample_scipy.cpython-39.pyc,, +sympy/stats/sampling/sample_numpy.py,sha256=B4ZC7ZBrSD6ICQT468rOy-xrOgQDuecsHa0zJesAeYE,4229 +sympy/stats/sampling/sample_pymc.py,sha256=9g-n04aXSFc6F7FJ5zTYtHHL6W8-26g1nrgtamJc3Hw,2995 +sympy/stats/sampling/sample_scipy.py,sha256=ysqpDy8bp1RMH0g5FFgMmp2SQuXGFkcSH7JDZEpiZ8w,6329 +sympy/stats/sampling/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/stats/sampling/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-39.pyc,, +sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-39.pyc,, +sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-39.pyc,, +sympy/stats/sampling/tests/test_sample_continuous_rv.py,sha256=Gh8hFN1hFFsthEv9wP2ZdgghQfaEnE8n7HlmyXXhN1E,5708 +sympy/stats/sampling/tests/test_sample_discrete_rv.py,sha256=Z3hqjHEFaIoB0hrhpy7YYL_7Dtqzmy5YMWAYrlqPu38,3360 +sympy/stats/sampling/tests/test_sample_finite_rv.py,sha256=dWwrFePw8eX2rBheAXi1AVxr_gqBD63VZKfW81hNoQc,3061 +sympy/stats/stochastic_process.py,sha256=pDz0rbKXTiaNmMmmz70dP3F_KWL_XhoCKFHYBNt1QeU,2312 
+sympy/stats/stochastic_process_types.py,sha256=YEUA-saTh8WF6Pn7GcgrcyclRCb_SC13YnpgOIpu7-Q,88553 +sympy/stats/symbolic_multivariate_probability.py,sha256=R6Co7XCcxLoOtTqC6ZSnGuylZNUBrC5AD0DrJr2jE1A,10450 +sympy/stats/symbolic_probability.py,sha256=HkWUiH9EM_DQ1aM1FxNmBm_8xtqeviIsyUpqFl2RYnI,23266 +sympy/stats/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/stats/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_compound_rv.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_continuous_rv.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_discrete_rv.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_error_prop.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_finite_rv.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_joint_rv.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_mix.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_random_matrix.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_rv.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_stochastic_process.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-39.pyc,, +sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-39.pyc,, +sympy/stats/tests/test_compound_rv.py,sha256=2927chbHTThA34Ki-ji319QT7ajQ1ueC640Mga-18ZA,6263 +sympy/stats/tests/test_continuous_rv.py,sha256=BcBkiyX7e1QiwS6xnD-qLzXnijvqbsebTzXEn3IfGyE,56140 +sympy/stats/tests/test_discrete_rv.py,sha256=bLqMtKpH5fDjuP7uuRAtAjakbahuFz9x3dY-PPKFsGk,11230 +sympy/stats/tests/test_error_prop.py,sha256=xKAkw3F5XJ72xiDREI7PkyReWNVW_89CD_mjOY_diDY,1933 +sympy/stats/tests/test_finite_rv.py,sha256=JHYgY4snFF5t9qcnQfKaN5zaGsO7_SuNR7Tq234W4No,20413 +sympy/stats/tests/test_joint_rv.py,sha256=W28rCRYczv5Jax7k-bj7OveT-y-AP4q-kRR0-LNaWX0,18653 +sympy/stats/tests/test_matrix_distributions.py,sha256=9daJUiSGaLq34TeZfB-xPqC8xz6vECGrm0DdBZaQPyY,8857 +sympy/stats/tests/test_mix.py,sha256=Cplnw06Ki96Y_4fx6Bu7lUXjxoIfX7tNJasm9SOz5wQ,3991 +sympy/stats/tests/test_random_matrix.py,sha256=CiD1hV25MGHwTfHGaoaehGD3iJ4lqNYi-ZiwReO6CVk,5842 +sympy/stats/tests/test_rv.py,sha256=Bp7UwffIMO7oc8UnFV11yYGcXUjSa0NhsuOgQaNRMt8,12959 +sympy/stats/tests/test_stochastic_process.py,sha256=i-VCOZrjpJtvyTBm9xgniTCkk_iUYIuFnkiyxzkn6Ig,39323 +sympy/stats/tests/test_symbolic_multivariate.py,sha256=G3AgbRbt0DQ-p0DYXYDjbx4e4f5FIgd31F34e0NO2n8,5580 +sympy/stats/tests/test_symbolic_probability.py,sha256=k5trScMiwSgl9dzJt30BV-t0KuYcyD-s9HtT2-hVhQ0,9398 +sympy/strategies/__init__.py,sha256=XaTAPqDoi6527juvR8LLN1mv6ZcslDrGloTTBMjJzxA,1402 +sympy/strategies/__pycache__/__init__.cpython-39.pyc,, +sympy/strategies/__pycache__/core.cpython-39.pyc,, +sympy/strategies/__pycache__/rl.cpython-39.pyc,, +sympy/strategies/__pycache__/tools.cpython-39.pyc,, +sympy/strategies/__pycache__/traverse.cpython-39.pyc,, +sympy/strategies/__pycache__/tree.cpython-39.pyc,, +sympy/strategies/__pycache__/util.cpython-39.pyc,, +sympy/strategies/branch/__init__.py,sha256=xxbMwR2LzLcQWsH9ss8ddE99VHFJTY-cYiR6xhO3tj0,356 +sympy/strategies/branch/__pycache__/__init__.cpython-39.pyc,, +sympy/strategies/branch/__pycache__/core.cpython-39.pyc,, +sympy/strategies/branch/__pycache__/tools.cpython-39.pyc,, +sympy/strategies/branch/__pycache__/traverse.cpython-39.pyc,, +sympy/strategies/branch/core.py,sha256=QiXSa7uhvmUBTLyUwBQHrYkWlOceKh5p4kVD90VnCKM,2759 +sympy/strategies/branch/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+sympy/strategies/branch/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/strategies/branch/tests/__pycache__/test_core.cpython-39.pyc,, +sympy/strategies/branch/tests/__pycache__/test_tools.cpython-39.pyc,, +sympy/strategies/branch/tests/__pycache__/test_traverse.cpython-39.pyc,, +sympy/strategies/branch/tests/test_core.py,sha256=23KQWJxC_2T1arwMAkt9pY1ZtG59avlxTZcVTn81UPI,2246 +sympy/strategies/branch/tests/test_tools.py,sha256=4BDkqVqrTlsivQ0PldQr6PjVZsAikc39tSxGAQA3ir8,942 +sympy/strategies/branch/tests/test_traverse.py,sha256=6rikMnZdamSzww1sSiM-aQwqa4lQrpM-DpOU9XCbiOQ,1322 +sympy/strategies/branch/tools.py,sha256=tvv3IjmQGNYbo-slCbbDf_rylZd537wvLcpdBtT-bbY,357 +sympy/strategies/branch/traverse.py,sha256=7iBViQdNpKu-AHoFED7_C9KBSyYcQBfLGopEJQbNtvk,799 +sympy/strategies/core.py,sha256=nsH6LZgyc_aslv4Na5XvJMEizC6uSzscRlVW91k1pu4,3956 +sympy/strategies/rl.py,sha256=-M5lfrtcREkvbW4YfoiB19wfYyMLKz6SmL887ymKeoI,4404 +sympy/strategies/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/strategies/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/strategies/tests/__pycache__/test_core.cpython-39.pyc,, +sympy/strategies/tests/__pycache__/test_rl.cpython-39.pyc,, +sympy/strategies/tests/__pycache__/test_tools.cpython-39.pyc,, +sympy/strategies/tests/__pycache__/test_traverse.cpython-39.pyc,, +sympy/strategies/tests/__pycache__/test_tree.cpython-39.pyc,, +sympy/strategies/tests/test_core.py,sha256=42XHlv1hN1S1QPEf2r9pddZ2EQL6o4FEPQvfo-UmXcw,2152 +sympy/strategies/tests/test_rl.py,sha256=wm0L6pdvddBgRcwhpiSk-nCgyzVGickfnOCkmHWS0j4,1949 +sympy/strategies/tests/test_tools.py,sha256=UdMojFIn3f1b2x2iRGv1Wfnwdso-Kl57GTyjCU_DjzQ,875 +sympy/strategies/tests/test_traverse.py,sha256=jWuZhYEt-F18_rxEMhn6OgGQ1GNs-dM_GFZ2F5nHs2I,2082 +sympy/strategies/tests/test_tree.py,sha256=9NL948rt6i9tYU6CQz9VNxE6l1begQs-MxP2euzE3Sc,2400 +sympy/strategies/tools.py,sha256=ERASzEP2SP-EcJ8p-4XyREYB15q3t81x1cyamJ-M880,1368 +sympy/strategies/traverse.py,sha256=DhPnBJ5Rw_xzhGiBtSciTyV-H2zhlxgjYVjrNH-gLyk,1183 +sympy/strategies/tree.py,sha256=ggnP9l3NIpJsssBMVKr4-yM_m8uCkrkm191ZC6MfZjc,3770 +sympy/strategies/util.py,sha256=2fbR813IY4IYco5mBoGJLu5z88OhXmwuIxgOO9IvZO4,361 +sympy/tensor/__init__.py,sha256=VMNXCRSayigQT6a3cvf5M_M-wdV-KSil_JbAmHcuUQc,870 +sympy/tensor/__pycache__/__init__.cpython-39.pyc,, +sympy/tensor/__pycache__/functions.cpython-39.pyc,, +sympy/tensor/__pycache__/index_methods.cpython-39.pyc,, +sympy/tensor/__pycache__/indexed.cpython-39.pyc,, +sympy/tensor/__pycache__/tensor.cpython-39.pyc,, +sympy/tensor/__pycache__/toperators.cpython-39.pyc,, +sympy/tensor/array/__init__.py,sha256=lTT1EwV5tb3WAvmmS_mIjhCSWSLiB0NNPW4n9_3fu0k,8244 +sympy/tensor/array/__pycache__/__init__.cpython-39.pyc,, +sympy/tensor/array/__pycache__/array_comprehension.cpython-39.pyc,, +sympy/tensor/array/__pycache__/array_derivatives.cpython-39.pyc,, +sympy/tensor/array/__pycache__/arrayop.cpython-39.pyc,, +sympy/tensor/array/__pycache__/dense_ndim_array.cpython-39.pyc,, +sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-39.pyc,, +sympy/tensor/array/__pycache__/ndim_array.cpython-39.pyc,, +sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-39.pyc,, +sympy/tensor/array/array_comprehension.py,sha256=01PTIbkAGaq0CDcaI_2KsaMnYm1nxQ8sFAiHHcc__gw,12262 +sympy/tensor/array/array_derivatives.py,sha256=c-gYeA_qpXOY3aexyz7psSqmTVIGVBrcGDvSkW5dZV0,4796 +sympy/tensor/array/arrayop.py,sha256=KvFGYWcYvChWkThtVAotlaSrcfjHogAxvpxWvp6dSgo,18397 
+sympy/tensor/array/dense_ndim_array.py,sha256=Ie8qVMJyp2Tsq7aVhmZpPX8X-KTlF9uaxkQfTzCZ9z8,6433 +sympy/tensor/array/expressions/__init__.py,sha256=OUMJjZY7HtWJL0ygqkdWC8LdCqibJZhHCfYeXu-eB4E,7045 +sympy/tensor/array/expressions/__pycache__/__init__.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/array_expressions.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-39.pyc,, +sympy/tensor/array/expressions/__pycache__/utils.cpython-39.pyc,, +sympy/tensor/array/expressions/array_expressions.py,sha256=28TI8cwwQ4Q-QCI9X_j6X-lG0kR6amjAe9-h9JKb5EU,76961 +sympy/tensor/array/expressions/arrayexpr_derivatives.py,sha256=W9-bY2LL83lLSNHXItzqjOgvf-HIDbUXPoVw8uOymcg,6249 +sympy/tensor/array/expressions/conv_array_to_indexed.py,sha256=BIwlQr7RKC8bZN3mR8ICC5TYOC9uasYcV0Zc1VNKmiE,445 +sympy/tensor/array/expressions/conv_array_to_matrix.py,sha256=85YZBTZI4o9dJtKDJXXug_lJVLG8dT_22AT7l7DKoyE,416 +sympy/tensor/array/expressions/conv_indexed_to_array.py,sha256=EyW52TplBxIx25mUDvI_5Tzc8LD6Mnp6XNW9wIw9pH4,254 +sympy/tensor/array/expressions/conv_matrix_to_array.py,sha256=XYyqt0NsQSrgNpEkr8xTGeUhR7ZYeNljVFfVEF1K7vA,250 +sympy/tensor/array/expressions/from_array_to_indexed.py,sha256=3YIcsAzWVWQRJYQS90uPvSl2dM7ZqLV_qt7E9-uYU28,3936 +sympy/tensor/array/expressions/from_array_to_matrix.py,sha256=4HUbCfR2BHd54Ijws9DCcSs6aAJ-k8RwT195NTQImis,41435 +sympy/tensor/array/expressions/from_indexed_to_array.py,sha256=RUcKemmrwuK5RFRr19YSPVMCOkZfLAWlbbB56u8Wi0g,11187 +sympy/tensor/array/expressions/from_matrix_to_array.py,sha256=yIY1RupF9-FVV3jZLsqWxZ1ckoE1-HkQyM8cQIm4_Gs,3929 +sympy/tensor/array/expressions/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-39.pyc,, +sympy/tensor/array/expressions/tests/test_array_expressions.py,sha256=QUAdxQ9TvBpDEAZoJpLSWwbqjmuflPe3xBRP30lFZr0,31262 +sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py,sha256=lpC4ly6MJLDRBcVt3GcP3H6ke9bI-o3VULw0xyF5QbY,2470 
+sympy/tensor/array/expressions/tests/test_as_explicit.py,sha256=nOjFKXCqYNu2O7Szc1TD1x1bsUchPRAG3nGlNGEd1Yg,2568 +sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py,sha256=6yNxGXH6BX5607FTjMkwR2t9wNVlEhV8JMSh4UIWux8,2500 +sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py,sha256=2vkSep9CPKYrQQS0u8Ayn_sc7yek1zwzjjCWK5cfYe8,29311 +sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py,sha256=RVEG_qUsXiBH9gHtWp2-9pMC4J2aLc4iUdzBFM0QyTw,8615 +sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py,sha256=G2g5E0l-FABwYyQowbKKvLcEI8NViJXaYLW3eUEcvjw,4595 +sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py,sha256=DG8IoUtxCy2acWjUHUUKu4bRsTxXbeFLFjKMLA2GdLY,1216 +sympy/tensor/array/expressions/utils.py,sha256=Rn58boHHUEoBZFtinDpruLWFBkNBwgkVQ4c9m7Nym1o,3939 +sympy/tensor/array/mutable_ndim_array.py,sha256=M0PTt8IOIcVXqQPWe2N50sm4Eq2bodRXV4Vkd08crXk,277 +sympy/tensor/array/ndim_array.py,sha256=bZvA21fSTMc2sQLJN4AovQspddzmnINBWFPr-TeGmm0,19138 +sympy/tensor/array/sparse_ndim_array.py,sha256=4nD_Hg-JdC_1mYQTohmKFfL5M1Ugdq0fpnDUILkTtq8,6387 +sympy/tensor/array/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/tensor/array/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-39.pyc,, +sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-39.pyc,, +sympy/tensor/array/tests/test_array_comprehension.py,sha256=e8MsWbvwmr-HxTfaM7i8HoIVofa8InLzF9PTCIVvzjU,4529 +sympy/tensor/array/tests/test_array_derivatives.py,sha256=hS10Bkb3F2kWoFoxk2ucr21p0r5DfwHDYPI8HcmAl0o,1601 +sympy/tensor/array/tests/test_arrayop.py,sha256=WahGcUnArsAo9eaMqGT7_AjKons0WgFzLOWTtNvnSEI,25844 +sympy/tensor/array/tests/test_immutable_ndim_array.py,sha256=9ji_14szn-qoL6DQ5muzIFNaXefT7n55PFigXoFwk50,15823 +sympy/tensor/array/tests/test_mutable_ndim_array.py,sha256=rFFa0o0AJYgPNnpqijl91Vb9EW2kgHGQc6cu9f1fIvY,13070 +sympy/tensor/array/tests/test_ndim_array.py,sha256=KH-9LAME3ldVIu5n7Vd_Xr36dN4frCdiF9qZdBWETu0,2232 +sympy/tensor/array/tests/test_ndim_array_conversions.py,sha256=CUGDCbCcslACy3Ngq-zoig9JnO4yHTw3IPcKy0FnRpw,648 +sympy/tensor/functions.py,sha256=FNZ2M0HGVaASJ_1AkPu1vF5NHNLmBXKLsj0Q5wKMX90,4168 +sympy/tensor/index_methods.py,sha256=dcX9kNKLHi_XXkFHBPS-fcM-PaeYKkX80jmzxC0siiQ,15434 +sympy/tensor/indexed.py,sha256=3qiQehMvXSDbr_RDXg6l_aEu6dHZXEfXrMoXCJj7Cqo,24515 +sympy/tensor/tensor.py,sha256=EQ9qJjtvij4gG101jB1eOuTtaSmvm1Da3IHVCBn1Flo,181094 +sympy/tensor/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/tensor/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_functions.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_index_methods.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_indexed.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_printing.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_tensor.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_tensor_element.cpython-39.pyc,, +sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-39.pyc,, 
+sympy/tensor/tests/test_functions.py,sha256=rBBHjJIUA2oR83UgEJ_GIASDWfTZXDzOllmcO90XYDU,1552 +sympy/tensor/tests/test_index_methods.py,sha256=Pu951z4yYYMOXBKcNteH63hTAxmNX8702nSQH_pciFE,7112 +sympy/tensor/tests/test_indexed.py,sha256=EB2-t0gYkaTOvrCkhyW2cEL_MVi4xBUZxbf35NA-puI,16376 +sympy/tensor/tests/test_printing.py,sha256=sUx_rChNTWFKPNwVl296QXO-d4-yemDJnkEHFislsmc,424 +sympy/tensor/tests/test_tensor.py,sha256=XQiLkOtyLLFKX_lWEqpCr2lXxQ6LK_hShj0rz7RrWpk,81236 +sympy/tensor/tests/test_tensor_element.py,sha256=1dF96FtqUGaJzethw23vJIj3H5KdxsU1Xyd4DU54EB4,908 +sympy/tensor/tests/test_tensor_operators.py,sha256=n4md5Dv7QoVSMR4cHbkEKSFDxgTTpEyxMiCAqoqyJyk,17966 +sympy/tensor/toperators.py,sha256=fniTUpdYz0OvtNnFgrHINedX86FxVcxfKj9l_l1p9Rw,8840 +sympy/testing/__init__.py,sha256=IN9aHvoksdZywonAYE0cExFnuPQs9z1E4P742aYZNnE,169 +sympy/testing/__pycache__/__init__.cpython-39.pyc,, +sympy/testing/__pycache__/matrices.cpython-39.pyc,, +sympy/testing/__pycache__/pytest.cpython-39.pyc,, +sympy/testing/__pycache__/quality_unicode.cpython-39.pyc,, +sympy/testing/__pycache__/randtest.cpython-39.pyc,, +sympy/testing/__pycache__/runtests.cpython-39.pyc,, +sympy/testing/__pycache__/runtests_pytest.cpython-39.pyc,, +sympy/testing/__pycache__/tmpfiles.cpython-39.pyc,, +sympy/testing/matrices.py,sha256=VWBPdjIUYNHE7fdbYcmQwQTYcIWpOP9tFn9A0rGCBmE,216 +sympy/testing/pytest.py,sha256=xiKL67rD4sKYVZ1Xqci4arPJQzrm3xWfg-m6I7eWNBM,13363 +sympy/testing/quality_unicode.py,sha256=quKFUpEUPftVGNpsj71WpRoeKZUzc8sgLjODRmHPf-w,3482 +sympy/testing/randtest.py,sha256=IKDFAm8b72Z1OkT7vpgnZjaW5LsSU_wf6g35sCkq9I0,562 +sympy/testing/runtests.py,sha256=dD9_0xtvG0wMgQtdvxDi7ePc2CL2ybNRqHKgHjJomzE,89921 +sympy/testing/runtests_pytest.py,sha256=KSqhzaRis7sx2AwctWsP6VjCbDcg6C5UCApp_LUNRuc,17801 +sympy/testing/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/testing/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/testing/tests/__pycache__/diagnose_imports.cpython-39.pyc,, +sympy/testing/tests/__pycache__/test_code_quality.cpython-39.pyc,, +sympy/testing/tests/__pycache__/test_deprecated.cpython-39.pyc,, +sympy/testing/tests/__pycache__/test_module_imports.cpython-39.pyc,, +sympy/testing/tests/__pycache__/test_pytest.cpython-39.pyc,, +sympy/testing/tests/__pycache__/test_runtests_pytest.cpython-39.pyc,, +sympy/testing/tests/diagnose_imports.py,sha256=VQAc0WLHgQgksu6kcQLmEVW4O98v7YgjLMzCLNZ6JfE,9508 +sympy/testing/tests/test_code_quality.py,sha256=Yzb1CAIAGTO29Ay2fsPn0-OqstAoy9FMiAMIed6iVZE,19214 +sympy/testing/tests/test_deprecated.py,sha256=wQZHs4wDNuK4flaKKLsJW6XRMtrVjMv_5rUP3WspgPA,183 +sympy/testing/tests/test_module_imports.py,sha256=5w6F6JW6K7lgpbB4X9Tj0Vw8AcNVlfaSuvbwKXJKD6c,1459 +sympy/testing/tests/test_pytest.py,sha256=iKO10Tvua1Xem6a22IWH4SDrpFfr-bM-rXx039Ua7YA,6778 +sympy/testing/tests/test_runtests_pytest.py,sha256=C4bo-SOx8BsxvT4M_0OPUiwSEXShaMqpYRoedMmbR1g,6083 +sympy/testing/tmpfiles.py,sha256=bF8ktKC9lDhS65gahB9hOewsZ378UkhLgq3QHiqWYXU,1042 +sympy/this.py,sha256=XfOkN5EIM2RuDxSm_q6k_R_WtkIoSy6PXWKp3aAXvoc,550 +sympy/unify/__init__.py,sha256=Upa9h7SSr9W1PXo0WkNESsGsMZ85rcWkeruBtkAi3Fg,293 +sympy/unify/__pycache__/__init__.cpython-39.pyc,, +sympy/unify/__pycache__/core.cpython-39.pyc,, +sympy/unify/__pycache__/rewrite.cpython-39.pyc,, +sympy/unify/__pycache__/usympy.cpython-39.pyc,, +sympy/unify/core.py,sha256=-BCNPPMdfZuhhIWqyn9pYJoO8yFPGDX78Hn2551ABuE,7037 +sympy/unify/rewrite.py,sha256=Emr8Uoum3gxKpMDqFHJIjx3xChArUIN6XIy6NPfCS8I,1798 
+sympy/unify/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/unify/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/unify/tests/__pycache__/test_rewrite.cpython-39.pyc,, +sympy/unify/tests/__pycache__/test_sympy.cpython-39.pyc,, +sympy/unify/tests/__pycache__/test_unify.cpython-39.pyc,, +sympy/unify/tests/test_rewrite.py,sha256=BgA8zmdz9Nw-Xbu4-w3UABeWypqLvmy9VzL744EmYtE,2002 +sympy/unify/tests/test_sympy.py,sha256=UCItZJNAx9dG5F7O27pyXUF1-e6aOwkZ-cVdB6SZFZc,5922 +sympy/unify/tests/test_unify.py,sha256=4TlgchV6NWuBekJx9RGlMjx3-UwonzgIYXDytb7sBRU,3029 +sympy/unify/usympy.py,sha256=6Kxx96FXSdqXimLseVK_FkYwy2vqWhNnxMVPMRShvy4,3964 +sympy/utilities/__init__.py,sha256=nbQhzII8dw5zd4hQJ2SUyriK5dOrqf-bbjy10XKQXPw,840 +sympy/utilities/__pycache__/__init__.cpython-39.pyc,, +sympy/utilities/__pycache__/autowrap.cpython-39.pyc,, +sympy/utilities/__pycache__/codegen.cpython-39.pyc,, +sympy/utilities/__pycache__/decorator.cpython-39.pyc,, +sympy/utilities/__pycache__/enumerative.cpython-39.pyc,, +sympy/utilities/__pycache__/exceptions.cpython-39.pyc,, +sympy/utilities/__pycache__/iterables.cpython-39.pyc,, +sympy/utilities/__pycache__/lambdify.cpython-39.pyc,, +sympy/utilities/__pycache__/magic.cpython-39.pyc,, +sympy/utilities/__pycache__/matchpy_connector.cpython-39.pyc,, +sympy/utilities/__pycache__/memoization.cpython-39.pyc,, +sympy/utilities/__pycache__/misc.cpython-39.pyc,, +sympy/utilities/__pycache__/pkgdata.cpython-39.pyc,, +sympy/utilities/__pycache__/pytest.cpython-39.pyc,, +sympy/utilities/__pycache__/randtest.cpython-39.pyc,, +sympy/utilities/__pycache__/runtests.cpython-39.pyc,, +sympy/utilities/__pycache__/source.cpython-39.pyc,, +sympy/utilities/__pycache__/timeutils.cpython-39.pyc,, +sympy/utilities/__pycache__/tmpfiles.cpython-39.pyc,, +sympy/utilities/_compilation/__init__.py,sha256=uYUDPbwrMTbGEMVuago32EN_ix8fsi5M0SvcLOtwMOk,751 +sympy/utilities/_compilation/__pycache__/__init__.cpython-39.pyc,, +sympy/utilities/_compilation/__pycache__/availability.cpython-39.pyc,, +sympy/utilities/_compilation/__pycache__/compilation.cpython-39.pyc,, +sympy/utilities/_compilation/__pycache__/runners.cpython-39.pyc,, +sympy/utilities/_compilation/__pycache__/util.cpython-39.pyc,, +sympy/utilities/_compilation/availability.py,sha256=ybxp3mboH5772JHTWKBN1D-cs6QxATQiaL4zJVV4RE0,2884 +sympy/utilities/_compilation/compilation.py,sha256=qkfDtSgSeRJ-97E5hN1asUqoDXS4nnNgYoHCA8WAmHA,22156 +sympy/utilities/_compilation/runners.py,sha256=KAinYF0m55dRdDFO5pvHBlvTxDY0ldoImqTOwX7VmL4,10237 +sympy/utilities/_compilation/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-39.pyc,, +sympy/utilities/_compilation/tests/test_compilation.py,sha256=DmdxCPlWaqgtNZPWzMVPwqRuqqJcgWBWxXfOEMZ21D0,3189 +sympy/utilities/_compilation/util.py,sha256=otSIKhM5YSYzbUNYycagG2xSj6owlQVrwIVl0rv7Owc,8611 +sympy/utilities/autowrap.py,sha256=l8noWRKmA-B814S92RK3O64YWAkARnMfqY1AQqi2S8g,42561 +sympy/utilities/codegen.py,sha256=1ieFSlxGgd9utMSPzN5NraTfoSSzlUM9V_BnGYMuMc8,81670 +sympy/utilities/decorator.py,sha256=NaZBNCDdMhG4_3esEOLZCbVQcjltvi3lrMIRCqewfR4,11184 +sympy/utilities/enumerative.py,sha256=ah4Td5KhqGdPR7EQHcUDOsYLDH64Sh2z6xm5gDpJSFg,43585 +sympy/utilities/exceptions.py,sha256=RoKY7jDIq6OsZbNSCyneWbVQb1Cw2MtOuioJlCKmBec,10570 +sympy/utilities/iterables.py,sha256=noIMRQmdSPfbg1LorRPHreozS265u4IHUmDkOvMB2po,91103 
+sympy/utilities/lambdify.py,sha256=lepieExI19QvZYPrGl_WZaFwta8YZBo6AgxDnwq01ow,57816 +sympy/utilities/magic.py,sha256=ofrwi1-xwMWb4VCQOEIwe4J1QAwxOscigDq26uSn3iY,400 +sympy/utilities/matchpy_connector.py,sha256=7dSDOPbN5Y_XW6bIGVNK3dJ-gVdTB_liJ8O5rIqd28c,11948 +sympy/utilities/mathml/__init__.py,sha256=74VhxNJlvCZTm2Jh3t69N8QZRklTQjqMvqFMLVwCKQk,3388 +sympy/utilities/mathml/__pycache__/__init__.cpython-39.pyc,, +sympy/utilities/mathml/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/utilities/mathml/data/__pycache__/__init__.cpython-39.pyc,, +sympy/utilities/mathml/data/mmlctop.xsl,sha256=fi3CTNyg-mSscOGYBXLJv8veE_ItR_YTFMJ4jmjp6aE,114444 +sympy/utilities/mathml/data/mmltex.xsl,sha256=haX7emZOfD6_nbn5BjK93F-C85mSS8KogAbIBsW1aBA,137304 +sympy/utilities/mathml/data/simple_mmlctop.xsl,sha256=OM-Vge1satH-MAYwWhraeXcorn1KGtuqBK-PDddaOrk,114433 +sympy/utilities/memoization.py,sha256=jD6RjVMZkGpNZYaJ9481vTiqvmwyu1IKDpsF5PYIvf4,1838 +sympy/utilities/misc.py,sha256=C6C3Em9YUjqg4gUhVdu97K0Wt_gRxuhNHS0TW6MlrxQ,16000 +sympy/utilities/pkgdata.py,sha256=BiBonyObCsS6EnLUII_1kDtchDN7Tur0T4GMzoUo06M,935 +sympy/utilities/pytest.py,sha256=ifo5fkOQ94HKXjy_U__mLFSohNRGP2ft8Z4P5gAi4yw,440 +sympy/utilities/randtest.py,sha256=Qlr4MMe0ade6jUH6pIEE0-v4g4qGswJhKpdaQFS_GDw,435 +sympy/utilities/runtests.py,sha256=Js3GGuzxQDZw5Tr5Kge0DCYx9gi3eya4DZxfwib6e3Y,451 +sympy/utilities/source.py,sha256=ShIXRNtplSEfZNi5VDYD3yi6305eRz4TmchEOEvcicw,1127 +sympy/utilities/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/utilities/tests/__pycache__/__init__.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_autowrap.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_codegen.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_codegen_julia.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_codegen_octave.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_codegen_rust.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_decorator.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_deprecated.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_enumerative.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_exceptions.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_iterables.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_lambdify.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_matchpy_connector.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_mathml.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_misc.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_pickling.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_source.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_timeutils.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_wester.cpython-39.pyc,, +sympy/utilities/tests/__pycache__/test_xxe.cpython-39.pyc,, +sympy/utilities/tests/test_autowrap.py,sha256=PA6zaRVrIWyQHq6bGKezqNB6YWTfcnnB9VH2vzaEfaQ,14865 +sympy/utilities/tests/test_codegen.py,sha256=pjfhZTa5LI9gdrakDfLtBBrIxgWjWORR_NBDW9J8Cq8,56130 +sympy/utilities/tests/test_codegen_julia.py,sha256=QYqSluVhmxq_BAagEKBb6eJIseFFAweMjpuLgNPMUWI,18540 +sympy/utilities/tests/test_codegen_octave.py,sha256=GUQIj6YjsMsj8ZLmVngWa9ovoyqfwvikWXdeS7m_c7I,17830 +sympy/utilities/tests/test_codegen_rust.py,sha256=seq9o51dsa09MmcqwKRkGiurbWLs1uAIjaz7xGOKNE4,12326 +sympy/utilities/tests/test_decorator.py,sha256=VYUvzUrVI7I7MK0YZxLLEmEu4pV5dqaB1CLEJ8Ocav4,3705 
+sympy/utilities/tests/test_deprecated.py,sha256=LRrZ2UxuXnK6Jwxl8vT0EdLT-q-7jLkTC69U9JjuYYU,489 +sympy/utilities/tests/test_enumerative.py,sha256=nw-1r0MgxbzYkScmxv7HUAIAQycg8pUVOH_slDRLg7E,6097 +sympy/utilities/tests/test_exceptions.py,sha256=OKRa2yuHMtnVcnisu-xcaedi2RKsH9QrgU9exgoOK30,716 +sympy/utilities/tests/test_iterables.py,sha256=xYWyoDtVkYfnIOdA-5yUYZsxo4iTsExaNpVmjMuwpTc,35288 +sympy/utilities/tests/test_lambdify.py,sha256=emTSDwHdGprENO3ugK86k1L6m8Y3IamPj1oUNIKaYrQ,73531 +sympy/utilities/tests/test_matchpy_connector.py,sha256=mBrAev2Hxe4jg_1ny3ZaGIfh2xvWbr6BDRVjB9owJFM,4850 +sympy/utilities/tests/test_mathml.py,sha256=-6z1MRYEH4eYQi2_wt8zmdjwtt5Cn483zqsvD-o_r70,836 +sympy/utilities/tests/test_misc.py,sha256=xZh6ux9o2iZpWvfX_ivp4lD3WgJVeP5k6ipppNGHDsw,4643 +sympy/utilities/tests/test_pickling.py,sha256=ZXGzB_gQ9juws_N_n2vf6QPVojdfP4Mbl7uqdtkPb0E,23600 +sympy/utilities/tests/test_source.py,sha256=ObjrJxZFVhLgXjVmFHUy7bti9UPPgOh5Cptw8lHW9mM,289 +sympy/utilities/tests/test_timeutils.py,sha256=sCRC6BCSho1e9n4clke3QXHx4a3qYLru-bddS_sEmFA,337 +sympy/utilities/tests/test_wester.py,sha256=S79UNsGhEW5z6hDriQB8OY_Vhaz1j0Q7yv31HynyDfk,94866 +sympy/utilities/tests/test_xxe.py,sha256=xk1j0Dd96wsGYKRNDzXTW0hTQejGCfiZcEhYcYiqojg,66 +sympy/utilities/timeutils.py,sha256=PNHyvuEM66UEXLF8ye6PMilcEDK75B9wab3ai8a9cyo,1941 +sympy/utilities/tmpfiles.py,sha256=tQP3g7LTrnXTGqUgGNoIZtjug6O6lq6EYLrKPJcUGEU,450 +sympy/vector/__init__.py,sha256=bgKutVqUPWuvHWBpiEm6XWE6iGfAPJw3pt19NRn2n10,1969 +sympy/vector/__pycache__/__init__.cpython-39.pyc,, +sympy/vector/__pycache__/basisdependent.cpython-39.pyc,, +sympy/vector/__pycache__/coordsysrect.cpython-39.pyc,, +sympy/vector/__pycache__/deloperator.cpython-39.pyc,, +sympy/vector/__pycache__/dyadic.cpython-39.pyc,, +sympy/vector/__pycache__/functions.cpython-39.pyc,, +sympy/vector/__pycache__/implicitregion.cpython-39.pyc,, +sympy/vector/__pycache__/integrals.cpython-39.pyc,, +sympy/vector/__pycache__/kind.cpython-39.pyc,, +sympy/vector/__pycache__/operators.cpython-39.pyc,, +sympy/vector/__pycache__/orienters.cpython-39.pyc,, +sympy/vector/__pycache__/parametricregion.cpython-39.pyc,, +sympy/vector/__pycache__/point.cpython-39.pyc,, +sympy/vector/__pycache__/scalar.cpython-39.pyc,, +sympy/vector/__pycache__/vector.cpython-39.pyc,, +sympy/vector/basisdependent.py,sha256=P7LogUck6NWjEzAOH3dYr4A8M5QCOPaCW75lmiQpBqU,11857 +sympy/vector/coordsysrect.py,sha256=0BKdr8AG8BUJnShIIyyCvkU5OVdjtZsKcBuGgdbvdHw,36833 +sympy/vector/deloperator.py,sha256=4BJNjmI342HkVRmeQkqauqvibKsf2HOuzknQTfQMkpg,3191 +sympy/vector/dyadic.py,sha256=IOyrgONyGDHPtG0RINcMgetAVMSOmYI5a99s-OwXBTA,8571 +sympy/vector/functions.py,sha256=rmEBaV0XpZLkCPcJFojHS9CNHirV55v5sDwNqTPwqGM,15454 +sympy/vector/implicitregion.py,sha256=4-M_-shL4EJ9nfjAZ_xg00IXZVQghcCZ3bv8QGbkopI,16158 +sympy/vector/integrals.py,sha256=x8DrvKXPznE05JgnZ7I3IWLWrvFl9SEghGaFmHrBaE4,6837 +sympy/vector/kind.py,sha256=U6qxciZY7-ACzBSpGM0iSTI5tqflpPBvm4_J9ujyxyE,1812 +sympy/vector/operators.py,sha256=dNC7mNXZb3KxX5GyjFNDTLTkzSnYjOohZtBXuQ1keZE,9579 +sympy/vector/orienters.py,sha256=EtWNWfOvAuy_wipam9SA7_muKSrsP-43UPRCCz56sb0,11798 +sympy/vector/parametricregion.py,sha256=gAarunY8q3D8WVvryZEbJQ4g-nlXo4FnO-CbyEpVUXc,5932 +sympy/vector/point.py,sha256=-JFOibsQclOKC2m2qZRiGGB4oK9ieVZRd2ll5g4k3bE,4488 +sympy/vector/scalar.py,sha256=WAb25BC0TEOZ-_bEtL1gdySuM1adJzK47kJYTrC4Y4w,2024 +sympy/vector/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sympy/vector/tests/__pycache__/__init__.cpython-39.pyc,, 
+sympy/vector/tests/__pycache__/test_coordsysrect.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_dyadic.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_field_functions.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_functions.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_implicitregion.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_integrals.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_operators.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_parametricregion.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_printing.cpython-39.pyc,, +sympy/vector/tests/__pycache__/test_vector.cpython-39.pyc,, +sympy/vector/tests/test_coordsysrect.py,sha256=q9n9OIG_CpD4KQN20dzwRZIXoMv7VSgp8fHmVnkZfr0,19595 +sympy/vector/tests/test_dyadic.py,sha256=f1R-BL_63VBbc0XgEX_LYzV_3OupYd4hp5RzRk6dAbI,4949 +sympy/vector/tests/test_field_functions.py,sha256=v9l8Ex8K2MsPGxqAPhpEgu6WAo6wS6qvdWLKQMxgE4A,14094 +sympy/vector/tests/test_functions.py,sha256=Bs2sekdDJyw_wrUpG7vZQGH0y0S4C4AbxGSpeU_8C2s,8050 +sympy/vector/tests/test_implicitregion.py,sha256=wVilD5H-MhHiW58QT6P5U7uT79JdKHm9D7JgZoi6BE4,4028 +sympy/vector/tests/test_integrals.py,sha256=iDq3Wpr8I8EXAGkXqMpA-P9i1zEnV7lU8WR6xBOhEK0,5087 +sympy/vector/tests/test_operators.py,sha256=KexUWvc_Nwp2HWrEbhxiO7MeaFxYlckrp__Tkwg-wmU,1613 +sympy/vector/tests/test_parametricregion.py,sha256=OfKapF9A_g9X6JxgYc0UfxIhwXzRERzaj-EijQCJONw,4009 +sympy/vector/tests/test_printing.py,sha256=3BeW55iQ4qXdfDTFqptE2ufJPJIBOzdfIYVx84n_EwA,7708 +sympy/vector/tests/test_vector.py,sha256=rJpMlUWiezrY6drHUNdJc8_FDYMwGCG4lBy8zS4XSZk,10238 +sympy/vector/vector.py,sha256=aOSfyfRvkatzWHWuulY6n0_iqElI5mntAUjNPptpBiM,20322 diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/WHEEL b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..8acb95590701b87bf84eec079cf4e3989f63b098 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (79.0.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/entry_points.txt b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..42a12f960335556fcee728e5754c346447d4e89c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +isympy = isympy:main diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/licenses/AUTHORS b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/licenses/AUTHORS new file mode 100644 index 0000000000000000000000000000000000000000..1062f3bf743c0a37bcde05b1afc1ba5e98a95822 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/licenses/AUTHORS @@ -0,0 +1,1379 @@ +All people who contributed to SymPy by sending at least a patch or +more (in the order of the date of their first contribution), except +those who explicitly didn't want to be mentioned. People with a * next +to their names are not found in the metadata of the git history. This +file is generated automatically by running `./bin/authors_update.py`. + +There are a total of 1371 authors. + +Ondřej Čertík +Fabian Pedregosa +Jurjen N.E. Bos +Mateusz Paprocki +*Marc-Etienne M.Leveille +Brian Jorgensen +Jason Gedge +Robert Schwarz +Pearu Peterson +Fredrik Johansson +Chris Wu +*Ulrich Hecht +Goutham Lakshminarayan +David Lawrence +Jaroslaw Tworek +David Marek +Bernhard R. 
Link +Andrej Tokarčík +Or Dvory +Saroj Adhikari +Pauli Virtanen +Robert Kern +James Aspnes +Nimish Telang +Abderrahim Kitouni +Pan Peng +Friedrich Hagedorn +Elrond der Elbenfuerst +Rizgar Mella +Felix Kaiser +Roberto Nobrega +David Roberts +Sebastian Krämer +Vinzent Steinberg +Riccardo Gori +Case Van Horsen +Stepan Roucka +Ali Raza Syed +Stefano Maggiolo +Robert Cimrman +Bastian Weber +Sebastian Krause +Sebastian Kreft +*Dan +Alan Bromborsky +Boris Timokhin +Robert +Andy R. Terrel +Hubert Tsang +Konrad Meyer +Henrik Johansson +Priit Laes +Freddie Witherden +Brian E. Granger +Andrew Straw +Kaifeng Zhu +Ted Horst +Andrew Docherty +Akshay Srinivasan +Aaron Meurer +Barry Wardell +Tomasz Buchert +Vinay Kumar +Johann Cohen-Tanugi +Jochen Voss +Luke Peterson +Chris Smith +Thomas Sidoti +Florian Mickler +Nicolas Pourcelot +Ben Goodrich +Toon Verstraelen +Ronan Lamy +James Abbatiello +Ryan Krauss +Bill Flynn +Kevin Goodsell +Jorn Baayen +Eh Tan +Renato Coutinho +Oscar Benjamin +Øyvind Jensen +Julio Idichekop Filho +Łukasz Pankowski +*Chu-Ching Huang +Fernando Perez +Raffaele De Feo +Christian Muise +Matt Curry +Kazuo Thow +Christian Schubert +Jezreel Ng +James Pearson +Matthew Brett +Addison Cugini +Nicholas J.S. Kinar +Harold Erbin +Thomas Dixon +Cristóvão Sousa +Andre de Fortier Smit +Mark Dewing +Alexey U. Gudchenko +Gary Kerr +Sherjil Ozair +Oleksandr Gituliar +Sean Vig +Prafullkumar P. Tale +Vladimir Perić +Tom Bachmann +Yuri Karadzhov +Vladimir Lagunov +Matthew Rocklin +Saptarshi Mandal +Gilbert Gede +Anatolii Koval +Tomo Lazovich +Pavel Fedotov +Jack McCaffery +Jeremias Yehdegho +Kibeom Kim +Gregory Ksionda +Tomáš Bambas +Raymond Wong +Luca Weihs +Shai 'Deshe' Wyborski +Thomas Wiecki +Óscar Nájera +Mario Pernici +Benjamin McDonald +Sam Magura +Stefan Krastanov +Bradley Froehle +Min Ragan-Kelley +Emma Hogan +Nikhil Sarda +Julien Rioux +Roberto Colistete, Jr. +Raoul Bourquin +Gert-Ludwig Ingold +Srinivas Vasudevan +Jason Moore +Miha Marolt +Tim Lahey +Luis Garcia +Matt Rajca +David Li +Alexandr Gudulin +Bilal Akhtar +Grzegorz Świrski +Matt Habel +David Ju +Nichita Utiu +Nikolay Lazarov +Steve Anton +Imran Ahmed Manzoor +Ljubiša Moćić <3rdslasher@gmail.com> +Piotr Korgul +Jim Zhang +Sam Sleight +tborisova +Chancellor Arkantos +Stepan Simsa +Tobias Lenz +Siddhanathan Shanmugam +Tiffany Zhu +Tristan Hume +Alexey Subach +Joan Creus +Geoffry Song +Puneeth Chaganti +Marcin Kostrzewa <> +Natalia Nawara +vishal +Shruti Mangipudi +Davy Mao +Swapnil Agarwal +Dhia Kennouche +jerryma1121 +Joachim Durchholz +Martin Povišer +Siddhant Jain +Kevin Hunter +Michael Mayorov +Nathan Alison +Christian Bühler +Carsten Knoll +Bharath M R +Matthias Toews +Sergiu Ivanov +Jorge E. Cardona +Sanket Agarwal +Manoj Babu K. +Sai Nikhil +Aleksandar Makelov +Sachin Irukula +Raphael Michel +Ashwini Oruganti +Andreas Klöckner +Prateek Papriwal +Arpit Goyal +Angadh Nanjangud +Comer Duncan +Jens H. 
Nielsen +Joseph Dougherty +Elliot Marshall +Guru Devanla +George Waksman +Alexandr Popov +Tarun Gaba +Takafumi Arakaki +Saurabh Jha +Rom le Clair +Angus Griffith <16sn6uv@gmail.com> +Timothy Reluga +Brian Stephanik +Alexander Eberspächer +Sachin Joglekar +Tyler Pirtle +Vasily Povalyaev +Colleen Lee +Matthew Hoff +Niklas Thörne +Huijun Mai +Marek Šuppa +Ramana Venkata +Prasoon Shukla +Stefen Yin +Thomas Hisch +Madeleine Ball +Mary Clark +Rishabh Dixit +Manoj Kumar +Akshit Agarwal +CJ Carey +Patrick Lacasse +Ananya H +Tarang Patel +Christopher Dembia +Benjamin Fishbein +Sean Ge +Amit Jamadagni +Ankit Agrawal +Björn Dahlgren +Christophe Saint-Jean +Demian Wassermann +Khagesh Patel +Stephen Loo +hm +Patrick Poitras +Katja Sophie Hotz +Varun Joshi +Chetna Gupta +Thilina Rathnayake +Max Hutchinson +Shravas K Rao +Matthew Tadd +Alexander Hirzel +Randy Heydon +Oliver Lee +Seshagiri Prabhu +Pradyumna +Erik Welch +Eric Nelson +Roland Puntaier +Chris Conley +Tim Swast +Dmitry Batkovich +Francesco Bonazzi +Yuriy Demidov +Rick Muller +Manish Gill +Markus Müller +Amit Saha +Jeremy +QuaBoo +Stefan van der Walt +David Joyner +Lars Buitinck +Alkiviadis G. Akritas +Vinit Ravishankar +Michael Boyle +Heiner Kirchhoffer +Pablo Puente +James Fiedler +Harsh Gupta +Tuomas Airaksinen +Paul Strickland +James Goppert +rathmann +Avichal Dayal +Paul Scott +Shipra Banga +Pramod Ch +Akshay +Buck Shlegeris +Jonathan Miller +Edward Schembor +Rajath Shashidhara +Zamrath Nizam +Aditya Shah +Rajat Aggarwal +Sambuddha Basu +Zeel Shah +Abhinav Chanda +Jim Crist +Sudhanshu Mishra +Anurag Sharma +Soumya Dipta Biswas +Sushant Hiray +Ben Lucato +Kunal Arora +Henry Gebhardt +Dammina Sahabandu +Manish Shukla +Ralph Bean +richierichrawr +John Connor +Juan Luis Cano Rodríguez +Sahil Shekhawat +Kundan Kumar +Stas Kelvich +sevaader +Dhruvesh Vijay Parikh +Venkatesh Halli +Lennart Fricke +Vlad Seghete +Shashank Agarwal +carstimon +Pierre Haessig +Maciej Baranski +Benjamin Gudehus +Faisal Anees +Mark Shoulson +Robert Johansson +Kalevi Suominen +Kaushik Varanasi +Fawaz Alazemi +Ambar Mehrotra +David P. Sanders +Peter Brady +John V. Siratt +Sarwar Chahal +Nathan Woods +Colin B. Macdonald +Marcus Näslund +Clemens Novak +Mridul Seth +Craig A. Stoudt +Raj +Mihai A. Ionescu +immerrr +Chai Wah Wu +Leonid Blouvshtein +Peleg Michaeli +ck Lux +zsc347 +Hamish Dickson +Michael Gallaspy +Roman Inflianskas +Duane Nykamp +Ted Dokos +Sunny Aggarwal +Victor Brebenar +Akshat Jain +Shivam Vats +Longqi Wang +Juan Felipe Osorio +Ray Cathcart +Lukas Zorich +Eric Miller +Cody Herbst +Nishith Shah +Amit Kumar +Yury G. 
Kudryashov +Guillaume Gay +Mihir Wadwekar +Tuan Manh Lai +Asish Panda +Darshan Chaudhary +Alec Kalinin +Ralf Stephan +Aaditya Nair +Jayesh Lahori +Harshil Goel +Luv Agarwal +Jason Ly +Lokesh Sharma +Sartaj Singh +Chris Swierczewski +Konstantin Togoi +Param Singh +Sumith Kulal +Juha Remes +Philippe Bouafia +Peter Schmidt +Jiaxing Liang +Lucas Jones +Gregory Ashton +Jennifer White +Renato Orsino +Alistair Lynn +Govind Sahai +Adam Bloomston +Kyle McDaniel +Nguyen Truong Duy +Alex Lindsay +Mathew Chong +Jason Siefken +Gaurav Dhingra +Gao, Xiang +Kevin Ventullo +mao8 +Isuru Fernando +Shivam Tyagi +Richard Otis +Rich LaSota +dustyrockpyle +Anton Akhmerov +Michael Zingale +Chak-Pong Chung +David T +Phil Ruffwind +Sebastian Koslowski +Kumar Krishna Agrawal +Dustin Gadal +João Moura +Yu Kobayashi +Shashank Kumar +Timothy Cyrus +Devyani Kota +Keval Shah +Dzhelil Rufat +Pastafarianist +Sourav Singh +Jacob Garber +Vinay Singh +GolimarOurHero +Prashant Tyagi +Matthew Davis +Tschijnmo TSCHAU +Alexander Bentkamp +Jack Kemp +Kshitij Saraogi +Thomas Baruchel +Nicolás Guarín-Zapata +Jens Jørgen Mortensen +Sampad Kumar Saha +Eva Charlotte Mayer +Laura Domine +Justin Blythe +Meghana Madhyastha +Tanu Hari Dixit +Shekhar Prasad Rajak +Aqnouch Mohammed +Arafat Dad Khan +Boris Atamanovskiy +Sam Tygier +Jai Luthra +Guo Xingjian +Sandeep Veethu +Archit Verma +Shubham Tibra +Ashutosh Saboo +Michael S. Hansen +Anish Shah +Guillaume Jacquenot +Bhautik Mavani +Michał Radwański +Jerry Li +Pablo Zubieta +Shivam Agarwal +Chaitanya Sai Alaparthi +Arihant Parsoya +Ruslan Pisarev +Akash Trehan +Nishant Nikhil +Vladimir Poluhsin +Akshay Nagar +James Brandon Milam +Abhinav Agarwal +Rishabh Daal +Sanya Khurana +Aman Deep +Aravind Reddy +Abhishek Verma +Matthew Parnell +Thomas Hickman +Akshay Siramdas +YiDing Jiang +Jatin Yadav +Matthew Thomas +Rehas Sachdeva +Michael Mueller +Srajan Garg +Prabhjot Singh +Haruki Moriguchi +Tom Gijselinck +Nitin Chaudhary +Alex Argunov +Nathan Musoke +Abhishek Garg +Dana Jacobsen +Vasiliy Dommes +Phillip Berndt +Haimo Zhang +Anthony Scopatz +bluebrook +Leonid Kovalev +Josh Burkart +Dimitra Konomi +Christina Zografou +Fiach Antaw +Langston Barrett +Krit Karan +G. D. McBain +Prempal Singh +Gabriel Orisaka +Matthias Bussonnier +rahuldan +Colin Marquardt +Andrew Taber +Yash Reddy +Peter Stangl +elvis-sik +Nikos Karagiannakis +Jainul Vaghasia +Dennis Meckel +Harshil Meena +Micky +Nick Curtis +Michele Zaffalon +Martha Giannoudovardi +Devang Kulshreshtha +Steph Papanik +Mohammad Sadeq Dousti +Arif Ahmed +Abdullah Javed Nesar +Lakshya Agrawal +shruti +Rohit Rango +Hong Xu +Ivan Petuhov +Alsheh +Marcel Stimberg +Alexey Pakhocmhik +Tommy Olofsson +Zulfikar +Blair Azzopardi +Danny Hermes +Sergey Pestov +Mohit Chandra +Karthik Chintapalli +Marcin Briański +andreo +Flamy Owl +Yicong Guo +Varun Garg +Rishabh Madan +Aditya Kapoor +Karan Sharma +Vedant Rathore +Johan Blåbäck +Pranjal Tale +Jason Tokayer +Raghav Jajodia +Rajat Thakur +Dhruv Bhanushali +Anjul Kumar Tyagi +Barun Parruck +Bao Chau +Tanay Agrawal +Ranjith Kumar +Shikhar Makhija +Yathartha Joshi +Valeriia Gladkova +Sagar Bharadwaj +Daniel Mahler +Ka Yi +Rishat Iskhakov +Szymon Mieszczak +Sachin Agarwal +Priyank Patel +Satya Prakash Dwibedi +tools4origins +Nico Schlömer +Fermi Paradox +Ekansh Purohit +Vedarth Sharma +Peeyush Kushwaha +Jayjayyy +Christopher J. 
Wright +Jakub Wilk +Mauro Garavello +Chris Tefer +Shikhar Jaiswal +Chiu-Hsiang Hsu +Carlos Cordoba +Fabian Ball +Yerniyaz +Christiano Anderson +Robin Neatherway +Thomas Hunt +Theodore Han +Duc-Minh Phan +Lejla Metohajrova +Samyak Jain +Aditya Rohan +Vincent Delecroix +Michael Sparapany +Harsh Jain +Nathan Goldbaum +latot +Kenneth Lyons +Stan Schymanski +David Daly +Ayush Shridhar +Javed Nissar +Jiri Kuncar +vedantc98 +Rupesh Harode +Rob Zinkov +James Harrop +James Taylor +Ishan Joshi +Marco Mancini +Boris Ettinger +Micah Fitch +Daniel Wennberg +ylemkimon +Akash Vaish +Peter Enenkel +Waldir Pimenta +Jithin D. George +Lev Chelyadinov +Lucas Wiman +Rhea Parekh +James Cotton +Robert Pollak +anca-mc +Sourav Ghosh +Jonathan Allan +Nikhil Pappu +Ethan Ward +Cezary Marczak +dps7ud +Nilabja Bhattacharya +Itay4 <31018228+Itay4@users.noreply.github.com> +Poom Chiarawongse +Yang Yang +Cavendish McKay +Bradley Gannon +B McG +Rob Drynkin +Seth Ebner +Akash Kundu +Mark Jeromin +Roberto Díaz Pérez +Gleb Siroki +Segev Finer +Alex Lubbock +Ayodeji Ige +Matthew Wardrop +Hugo van Kemenade +Austin Palmer +der-blaue-elefant +Filip Gokstorp +Yuki Matsuda +Aaron Miller +Salil Vishnu Kapur +Atharva Khare +Shubham Maheshwari +Pavel Tkachenko +Ashish Kumar Gaurav +Rajeev Singh +Keno Goertz +Lucas Gallindo +Himanshu +David Menéndez Hurtado +Amit Manchanda +Rohit Jain +Jonathan A. Gross +Unknown +Sayan Goswami +Subhash Saurabh +Rastislav Rabatin +Vishal +Jeremey Gluck +Akshat Maheshwari +symbolique +Saloni Jain +Arighna Chakrabarty +Abhigyan Khaund +Jashanpreet Singh +Saurabh Agarwal +luzpaz +P. Sai Prasanth +Nirmal Sarswat +Cristian Di Pietrantonio +Ravi charan +Nityananda Gohain +Cédric Travelletti +Nicholas Bollweg +Himanshu Ladia +Adwait Baokar +Mihail Tarigradschi +Saketh +rushyam +sfoo +Rahil Hastu +Zach Raines +Sidhant Nagpal +Gagandeep Singh +Rishav Chakraborty +Malkhan Singh +Joaquim Monserrat +Mayank Singh +Rémy Léone +Maxence Mayrand <35958639+maxencemayrand@users.noreply.github.com> +Nikoleta Glynatsi +helo9 +Ken Wakita +Carl Sandrock +Fredrik Eriksson +Ian Swire +Bulat +Ehren Metcalfe +Dmitry Savransky +Kiyohito Yamazaki +Caley Finn +Zhi-Qiang Zhou +Alexander Pozdneev +Wes Turner <50891+westurner@users.noreply.github.com> +JMSS-Unknown <31131631+JMSS-Unknown@users.noreply.github.com> +Arshdeep Singh +cym1 <16437732+cym1@users.noreply.github.com> +Stewart Wadsworth +Jared Lumpe +Avi Shrivastava +ramvenkat98 +Bilal Ahmed +Dimas Abreu Archanjo Dutra +Yatna Verma +S.Y. 
Lee +Miro Hrončok +Sudarshan Kamath +Ayushman Koul +Robert Dougherty-Bliss +Andrey Grozin +Bavish Kulur +Arun Singh +sirnicolaf <43586954+sirnicolaf@users.noreply.github.com> +Zachariah Etienne +Prayush Dawda <35144226+iamprayush@users.noreply.github.com> +2torus +Faisal Riyaz +Martin Roelfs +SirJohnFranklin +Anthony Sottile +ViacheslavP +Safiya03 +Alexander Dunlap +Rohit Sharma <31184621+rohitx007@users.noreply.github.com> +Jonathan Warner +Mohit Balwani +Marduk Bolaños +amsuhane +Matthias Geier +klaasvanaarsen <44929042+klaasvanaarsen@users.noreply.github.com> +Shubham Kumar Jha +rationa-kunal +Animesh Sinha +Gaurang Tandon <1gaurangtandon@gmail.com> +Matthew Craven +Daniel Ingram +Jogi Miglani +Takumasa Nakamura +Ritu Raj Singh +Rajiv Ranjan Singh +Vera Lozhkina +adhoc-king <46354827+adhoc-king@users.noreply.github.com> +Mikel Rouco +Oscar Gustafsson +damianos +Supreet Agrawal +shiksha11 +Martin Ueding +sharma-kunal +Divyanshu Thakur +Susumu Ishizuka +Samnan Rahee +Fredrik Andersson +Bhavya Srivastava +Alpesh Jamgade +Shubham Abhang +Vishesh Mangla +Nicko van Someren +dandiez <47832466+dandiez@users.noreply.github.com> +Frédéric Chapoton +jhanwar +Noumbissi valere Gille Geovan +Salmista-94 +Shivani Kohli +Parker Berry +Pragyan Mehrotra +Nabanita Dash +Gaetano Guerriero +Ankit Raj Pandey +Ritesh Kumar +kangzhiq <709563092@qq.com> +Jun Lin +Petr Kungurtsev +Anway De +znxftw +Denis Ivanenko +Orestis Vaggelis +Nikhil Maan +Abhinav Anand +Qingsha Shi +Juan Barbosa +Prionti Nasir +Bharat Raghunathan +arooshiverma +Christoph Gohle +Charalampos Tsiagkalis +Daniel Sears +Megan Ly +Sean P. Cornelius +Erik R. Gomez +Riccardo Magliocchetti +Henry Metlov +pekochun +Bendik Samseth +Vighnesh Shenoy +Versus Void +Denys Rybalka +Mark Dickinson +Rimi +rimibis <33387803+rimibis@users.noreply.github.com> +Steven Lee +Gilles Schintgen +Abhi58 +Tomasz Pytel +Aadit Kamat +Samesh +Velibor Zeli +Gabriel Bernardino +Joseph Redfern +Evelyn King +Miguel Marco +David Hagen +Hannah Kari +Soniya Nayak +Harsh Agarwal +Enric Florit +Yogesh Mishra +Denis Rykov +Ivan Tkachenko +Kenneth Emeka Odoh +Stephan Seitz +Yeshwanth N +Oscar Gerardo Lazo Arjona +Srinivasa Arun Yeragudipati +Kirtan Mali +TitanSnow +Pengning Chao <8857165+PengningChao@users.noreply.github.com> +Louis Abraham +Morten Olsen Lysgaard +Akash Nagaraj (akasnaga) +Akash Nagaraj +Lauren Glattly +Hou-Rui +George Korepanov +dranknight09 +aditisingh2362 +Gina +gregmedlock +Georgios Giapitzakis Tzintanos +Eric Wieser +Bradley Dowling <34559056+btdow@users.noreply.github.com> +Maria Marginean <33810762+mmargin@users.noreply.github.com> +Akash Agrawall +jgulian +Sourav Goyal +Zlatan Vasović +Alex Meiburg +Smit Lunagariya +Naman Gera +Julien Palard +Dhruv Mendiratta +erdOne <36414270+erdOne@users.noreply.github.com> +risubaba +abhinav28071999 <41710346+abhinav28071999@users.noreply.github.com> +Jisoo Song +Jaime R <38530589+Jaime02@users.noreply.github.com> +Vikrant Malik +Hardik Saini <43683678+Guardianofgotham@users.noreply.github.com> +Abhishek +Johannes Hartung +Milan Jolly +faizan2700 +mohit <39158356+mohitacecode@users.noreply.github.com> +Mohit Gupta +Psycho-Pirate +Chanakya-Ekbote +Rashmi Shehana +Jonty16117 +Anubhav Gupta +Michal Grňo +vezeli <37907135+vezeli@users.noreply.github.com> +Tim Gates +Sandeep Murthy +Neil +V1krant <46847915+V1krant@users.noreply.github.com> +alejandro +Riyan Dhiman +sbt4104 +Seth Troisi +Bhaskar Gupta +Smit Gajjar +rbl +Ilya Pchelintsev +Omar Wagih +prshnt19 +Johan Guzman +Vasileios Kalos +BasileiosKal 
<61801875+BasileiosKal@users.noreply.github.com> +Shubham Thorat <37049710+sbt4104@users.noreply.github.com> +Arpan Chattopadhyay +Ashutosh Hathidara +Moses Paul R +Saanidhya vats +tnzl +Vatsal Srivastava +Jean-Luc Herren +Dhruv Kothari +seadavis <45022599+seadavis@users.noreply.github.com> +kamimura +slacker404 +Jaime Resano +Ebrahim Byagowi +wuyudi +Akira Kyle +Calvin Jay Ross +Martin Thoma +Thomas A Caswell +Lagaras Stelios +Jerry James +Jan Kruse +Nathan Taylor +Vaishnav Damani +Mohit Shah +Mathias Louboutin +Marijan Smetko +Dave Witte Morris +soumi7 +Zhongshi +Wes Galbraith +KaustubhDamania +w495 +Akhil Rajput +Markus Mohrhard +Benjamin Wolba +彭于斌 <1931127624@qq.com> +Rudr Tiwari +Aaryan Dewan +Benedikt Placke +Sneha Goddu +goddus <39923708+goddus@users.noreply.github.com> +Shivang Dubey +Michael Greminger +Peter Cock +Willem Melching +Elias Basler +Brandon David +Abhay_Dhiman +Tasha Kim +Ayush Malik +Devesh Sawant +Wolfgang Stöcher +Sudeep Sidhu +foice +Ben Payne +Muskan Kumar <31043527+muskanvk@users.noreply.github.com> +noam simcha finkelstein +Garrett Folbe +Islam Mansour +Sayandip Halder +Shubham Agrawal +numbermaniac <5206120+numbermaniac@users.noreply.github.com> +Sakirul Alam +Mohammed Bilal +Chris du Plessis +Coder-RG +Ansh Mishra +Alex Malins +Lorenzo Contento +Naveen Sai +Shital Mule +Amanda Dsouza +Nijso Beishuizen +Harry Zheng +Felix Yan +Constantin Mateescu +Eva Tiwari +Aditya Kumar Sinha +Soumi Bardhan <51290447+Soumi7@users.noreply.github.com> +Kaustubh Chaudhari +Kristian Brünn +Neel Gorasiya +Akshat Sood <68052998+akshatsood2249@users.noreply.github.com> +Jose M. Gomez +Stefan Petrea +Praveen Sahu +Mark Bell +AlexCQY +Fabian Froehlich +Nikhil Gopalam +Kartik Sethi +Muhammed Abdul Quadir Owais +Harshit Yadav +Sidharth Mundhra +Suryam Arnav Kalra +Prince Gupta +Kunal Singh +Mayank Raj +Achal Jain <2achaljain@gmail.com> +Mario Maio +Aaron Stiff <69512633+AaronStiff@users.noreply.github.com> +Wyatt Peak +Bhaskar Joshi +Aditya Jindal +Vaibhav Bhat +Priyansh Rathi +Saket Kumar Singh +Yukai Chou +Qijia Liu +Paul Mandel +Nisarg Chaudhari <54911392+Nisarg-Chaudhari@users.noreply.github.com> +Dominik Stańczak +Rodrigo Luger +Marco Antônio Habitzreuter +Ayush Bisht +Akshansh Bhatt +Brandon T. Willard +Thomas Aarholt +Hiren Chalodiya +Roland Dixon +dimasvq +Sagar231 +Michael Chu +Abby Ng +Angad Sandhu <55819847+angadsinghsandhu@users.noreply.github.com> +Alexander Cockburn +Yaser AlOsh +Davide Sandonà +Jonathan Gutow +Nihir Agarwal +Lee Johnston +Zach Carmichael <20629897+craymichael@users.noreply.github.com> +Vijairam Ganesh Moorthy +Hanspeter Schmid +Ben Oostendorp +Nikita +Aman +Shashank KS +Aman Sharma +Anup Parikh +Lucy Mountain +Miguel Torres Costa +Rikard Nordgren +Arun sanganal <74652697+ArunSanganal@users.noreply.github.com> +Kamlesh Joshi <72374645+kamleshjoshi8102@users.noreply.github.com> +Joseph Rance <56409230+Joseph-Rance@users.noreply.github.com> +Huangduirong +Nils Schulte <47043622+Schnilz@users.noreply.github.com> +Matt Bogosian +Elisha Hollander +Aditya Ravuri +Mamidi Ratna Praneeth +Jeffrey Ryan +Jonathan Daniel <36337649+jond01@users.noreply.github.com> +Robin Richard +Gautam Menghani +Remco de Boer <29308176+redeboer@users.noreply.github.com> +Sebastian East +Evani Balasubramanyam +Rahil Parikh +Jason Ross +Joannah Nanjekye +Ayush Kumar +Kshitij +Daniel Hyams +alijosephine +Matthias Köppe +mohajain +Anibal M. 
Medina-Mardones +Travis Ens +Evgenia Karunus +Risiraj Dey +lastcodestanding +Andrey Lekar +Abbas Mohammed <42001049+iam-abbas@users.noreply.github.com> +Anutosh Bhat +Steve Kieffer +Paul Spiering +Pieter Gijsbers +Wang Ran (汪然) +naelsondouglas +Aman Thakur +S. Hanko +Dennis Sweeney +Gurpartap Singh +Hampus Malmberg +scimax +Nikhil Date +Kuldeep Borkar Jr +AkuBrain <76952313+Franck2111@users.noreply.github.com> +Leo Battle +Advait Pote +Anurag Bhat +Jeremy Monat +Diane Tchuindjo +Tom Fryers <61272761+TomFryers@users.noreply.github.com> +Zouhair +zzj <29055749+zjzh@users.noreply.github.com> +shubhayu09 +Siddhant Jain +Tirthankar Mazumder <63574588+wermos@users.noreply.github.com> +Sumit Kumar +Shivam Sagar +Gaurav Jain +Andrii Oriekhov +Luis Talavera +Arie Bovenberg +Carson McManus +Jack Schmidt <1107865+jackschmidt@users.noreply.github.com> +Riley Britten +Georges Khaznadar +Donald Wilson +Timo Stienstra +dispasha +Saksham Alok +Varenyam Bhardwaj +oittaa <8972248+oittaa@users.noreply.github.com> +Omkaar <79257339+Pysics@users.noreply.github.com> +Islem BOUZENIA +extraymond +Alexander Behrens +user202729 <25191436+user202729@users.noreply.github.com> +Pieter Eendebak +Zaz Brown +ritikBhandari +viocha <66580331+viocha@users.noreply.github.com> +Arthur Ryman +Xiang Wu +tttc3 +Seth Poulsen +cocolato +Anton Golovanov +Gareth Ma +Clément M.T. Robert +Glenn Horton-Smith +Karan +Stefan Behnle <84378403+behnle@users.noreply.github.com> +Shreyash Mishra <72146041+Shreyash-cyber@users.noreply.github.com> +Arthur Milchior +NotWearingPants <26556598+NotWearingPants@users.noreply.github.com> +Ishan Pandhare +Carlos García Montoro +Parcly Taxel +Saicharan +Kunal Sheth +Biswadeep Purkayastha <98874428+metabiswadeep@users.noreply.github.com> +Jyn Spring 琴春 +Phil LeMaitre +Chris Kerr +José Senart +Uwe L. Korn +ForeverHaibara <69423537+ForeverHaibara@users.noreply.github.com> +Yves Tumushimire +wookie184 +Costor +Klaus Rettinghaus +Sam Brockie +Abhishek Patidar <1e9abhi1e10@gmail.com> +Eric Demer +Pontus von Brömssen +Victor Immanuel +Evandro Bernardes +Michele Ceccacci +Ayush Aryan +Kishore Gopalakrishnan +Jan-Philipp Hoffmann +Daiki Takahashi +Sayan Mitra +Aman Kumar Shukla +Zoufiné Lauer-Baré +Charles Harris +Tejaswini Sanapathi +Devansh +Aaron Gokaslan +Daan Koning (he/him) +Steven Burns +Jay Patankar +Vivek Soni +Le Cong Minh Hieu +Sam Ritchie +Maciej Skórski +Tilo Reneau-Cardoso +Laurence Warne +Lukas Molleman +Konstantinos Riganas +Grace Su +Pedro Rosa +Abhinav Cillanki +Baiyuan Qiu <1061688677@qq.com> +Liwei Cai +Daniel Weindl +Isidora Araya +Seb Tiburzio +Victory Omole +Abhishek Chaudhary +Alexander Zhura +Shuai Zhou +Martin Manns +John Möller +zzc <1378113190@qq.com> +Pablo Galindo Salgado +Johannes Kasimir +Theodore Dias +Kaustubh <90597818+kaustubh-765@users.noreply.github.com> +Idan Pazi +Bobby Palmer +Saikat Das +Suman mondal +Taylan Sahin +Fabio Luporini +Oriel Malihi +Geetika Vadali +Matthias Rettl +Mikhail Remnev +philwillnyc <56197213+philwillnyc@users.noreply.github.com> +Raphael Lehner +Harry Mountain +Bhavik Sachdev +袁野 (Yuan Ye) +fazledyn-or +mohammedouahman +K. 
Kraus +Zac Hatfield-Dodds +platypus +codecruisader +James Whitehead +atharvParlikar +Ivan Petukhov +Augusto Borges +Han Wei Ang +Congxu Yang +Saicharan <62512681+saicharan2804@users.noreply.github.com> +Arnab Nandi +Harrison Oates <48871176+HarrisonOates@users.noreply.github.com> +Corey Cerovsek +Harsh Kasat +omahs <73983677+omahs@users.noreply.github.com> +Pascal Gitz +Ravindu-Hirimuthugoda +Sophia Pustova +George Pittock +Warren Jacinto +Sachin Singh +Zedmat <104870914+harshkasat@users.noreply.github.com> +Soumendra Ganguly +Samith Karunathilake <55777141+samithkavishke@users.noreply.github.com> +Viraj Vekaria +Shishir Kushwaha +Ankit Kumar Singh +Abhishek Kumar +Mohak Malviya +Matthias Liesenfeld <116307294+maliesen@users.noreply.github.com> +dodo +Mohamed Rezk +Tommaso Vaccari <05-gesto-follemente@icloud.com> +Alexis Schotte +Lauren Yim <31467609+cherryblossom000@users.noreply.github.com> +Prey Patel +Riccardo Di Girolamo +Abhishek kumar +Sam Lubelsky +Henrique Soares +Vladimir Sereda +Hwayeon Kang +Raj Sapale +Gerald Teschl +Richard Samuel <98638849+samuelard7@users.noreply.github.com> +HeeJae Chang +Nick Harder +Ethan DeGuire +Lorenz Winkler +Richard Rodenbusch +Zhenxu Zhu +Mark van Gelder +Mark van Gelder +Ishan Pandhare <91841626+Ishanned@users.noreply.github.com> +James A. Preiss +Emile Fourcini +Alberto Jiménez Ruiz +João Bravo +Dean Price +Edward Z. Yang +James Titus +Zhuoyuan Li +Hugo Kerstens +Jan Jancar +Andrew Mosson +Marek Madejski +Gonzalo Tornaría +Peter Stahlecker +Jean-François B <2589111+jfbu@users.noreply.github.com> +Zexuan Zhou (Bruce) +George Frolov +Corbet Elkins +Håkon Kvernmoen +Muhammad Maaz +Shishir Kushwaha <138311586+shishir-11@users.noreply.github.com> +Matt Wang +bharatAmeria <21001019007@jcboseust.ac.in> +Amir Ebrahimi +Steven Esquea +Rishabh Kamboj <111004091+VectorNd@users.noreply.github.com> +Aasim Ali +Ivan A. Melnikov +Borek Saheli +Guido Roncarolo +Quek Zi Yao +Roelof Rietbroek +MostafaGalal1 +Au Huishan +Kris Katterjohn +Shiyao Guo +Rushabh Mehta +Temiloluwa Yusuf ytemiloluwa@gmail.com ytemiloluwa +Davi Laerte +Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> +Harshit Gupta +Praveen Perumal +Kevin McWhirter +Prayag V +Lucas Kletzander +Pratyksh Gupta +Leonardo Mangani +Karan Anand +Gagan Mishra +Krishnav Bajoria +Matt Ord +Jatin Bhardwaj +Prashant Tandon +Paramjit Singh +João Rodrigues +Alejandro García Prada <114813960+AlexGarciaPrada@users.noreply.github.com> +Matthew Treinish +Clayton Rabideau +Victoria Koval +Voaides Negustor Robert <134785947+voaidesr@users.noreply.github.com> +Ovsk Mendov +David Brooks +Nicholas Laustrup <124007393+nicklaustrup@users.noreply.github.com> +Harikrishna Srinivasan +Mathis Cros +Arnav Mummineni <45217840+RCoder01@users.noreply.github.com> +Thangaraju Sibiraj <85477603+t-sibiraj@users.noreply.github.com> +KJaybhaye diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/licenses/LICENSE b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0744f229d697ca3ed1b1b257bfdb70e3eecf0b9e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/licenses/LICENSE @@ -0,0 +1,153 @@ +Copyright (c) 2006-2023 SymPy Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. 
Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +-------------------------------------------------------------------------------- + +Patches that were taken from the Diofant project (https://github.com/diofant/diofant) +are licensed as: + +Copyright (c) 2006-2018 SymPy Development Team, + 2013-2023 Sergey B Kirpichev + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of Diofant or SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +-------------------------------------------------------------------------------- + +Submodules taken from the multipledispatch project (https://github.com/mrocklin/multipledispatch) +are licensed as: + +Copyright (c) 2014 Matthew Rocklin + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. 
Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of multipledispatch nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +-------------------------------------------------------------------------------- + +The files under the directory sympy/parsing/autolev/tests/pydy-example-repo +are directly copied from PyDy project and are licensed as: + +Copyright (c) 2009-2023, PyDy Authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +* Neither the name of this project nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL PYDY AUTHORS BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +-------------------------------------------------------------------------------- + +The files under the directory sympy/parsing/latex +are directly copied from latex2sympy project and are licensed as: + +Copyright 2016, latex2sympy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/top_level.txt b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..0aa85c253508b68bb075f556be3c3f76dc4467ad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy-1.14.0.dist-info/top_level.txt @@ -0,0 +1,2 @@ +isympy +sympy diff --git a/phivenv/Lib/site-packages/sympy/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d18ac9ff023e1da595eb6d5c11a0014fc101c76 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/__pycache__/abc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/__pycache__/abc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f19f038ae4cdd2962ee4e57fbafbc625846527a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/__pycache__/abc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/__pycache__/conftest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/__pycache__/conftest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a1e6f60e369eec15c6b554f5c50c0fdddef2b2e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/__pycache__/conftest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/__pycache__/galgebra.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/__pycache__/galgebra.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa4467c9cec1281a9a97c7d2735a32d78479712f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/__pycache__/galgebra.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/__pycache__/release.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/__pycache__/release.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08c7a7212e575cd61e875cc92a964c00c7e3a835 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/__pycache__/release.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/__pycache__/this.cpython-39.pyc 
b/phivenv/Lib/site-packages/sympy/__pycache__/this.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d67209d8f6a3ba612bef990105042effcd32ccea Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/__pycache__/this.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8552d7d595e5b68c72daed26b416987cad188ae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/matrices.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/matrices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c8ddf9613769ad7a3c9e8327eccd62024b88080 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/matrices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/pytest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/pytest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda2b2568a7d1a27a3022141957512ba4996c5e4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/pytest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/quality_unicode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/quality_unicode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d5feb14eddcc506673511479ce31ed740a21456 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/quality_unicode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/randtest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/randtest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7371d389285c247ad5c4a41c4144859ac70efadf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/randtest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/runtests.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/runtests.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6b556c96a26d335b9fa18315b24fcc58d66d05 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/runtests.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/runtests_pytest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/runtests_pytest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4141f583b026ed77ffad6567cfbd93adcc3a75b0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/runtests_pytest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/__pycache__/tmpfiles.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/__pycache__/tmpfiles.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7de07d63d873ba999287ca778b8ad1fe01659d3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/__pycache__/tmpfiles.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__init__.py b/phivenv/Lib/site-packages/sympy/testing/tests/__init__.py 
new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e058e857a5eebfa47b44b8b918bc2f598316fbed Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/diagnose_imports.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/diagnose_imports.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c9b16a75a6b7f5bc49c2d61d65385bdb2a821b0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/diagnose_imports.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_code_quality.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_code_quality.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5287cd241a42525d335327a9f97ec08536d8acf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_code_quality.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_deprecated.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_deprecated.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba316ef584561c87ffacd57f2f34ad61257431c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_deprecated.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_module_imports.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_module_imports.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ba3741ef4617aad67bdb0146e853dfa3103e837 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_module_imports.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_pytest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_pytest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f243f169deb2bba3acd9b5166d14504c418afa5a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_pytest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_runtests_pytest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_runtests_pytest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bec10bb7269bb440153e930d46789baca8883fb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/testing/tests/__pycache__/test_runtests_pytest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/__init__.py b/phivenv/Lib/site-packages/sympy/unify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c166f9163785f4aa5744324eb817bef79b33525 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/__init__.py @@ -0,0 +1,15 @@ +""" Unification in SymPy + +See sympy.unify.core docstring for algorithmic details + +See 
http://matthewrocklin.com/blog/work/2012/11/01/Unification/ for discussion +""" + +from .usympy import unify, rebuild +from .rewrite import rewriterule + +__all__ = [ + 'unify', 'rebuild', + + 'rewriterule', +] diff --git a/phivenv/Lib/site-packages/sympy/unify/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0a1a22d61f236d5613eca6d0152c62b03f11657 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..354d6def3bd1dd41f65a42b6a1495941133dd9c3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/__pycache__/rewrite.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/__pycache__/rewrite.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20fb82402e9f95524fcde5c0a662d9201aa9a2c2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/__pycache__/rewrite.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/__pycache__/usympy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/__pycache__/usympy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0704348c92a5d5ce20c7e50995991c17c155d0e5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/__pycache__/usympy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/core.py b/phivenv/Lib/site-packages/sympy/unify/core.py new file mode 100644 index 0000000000000000000000000000000000000000..5359d0bbd376e9fa9efacff1d90c0bf51414cebf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/core.py @@ -0,0 +1,234 @@ +""" Generic Unification algorithm for expression trees with lists of children + +This implementation is a direct translation of + +Artificial Intelligence: A Modern Approach by Stuart Russel and Peter Norvig +Second edition, section 9.2, page 276 + +It is modified in the following ways: + +1. We allow associative and commutative Compound expressions. This results in + combinatorial blowup. +2. We explore the tree lazily. +3. We provide generic interfaces to symbolic algebra libraries in Python. + +A more traditional version can be found here +http://aima.cs.berkeley.edu/python/logic.html +""" + +from sympy.utilities.iterables import kbins + +class Compound: + """ A little class to represent an interior node in the tree + + This is analogous to SymPy.Basic for non-Atoms + """ + def __init__(self, op, args): + self.op = op + self.args = args + + def __eq__(self, other): + return (type(self) is type(other) and self.op == other.op and + self.args == other.args) + + def __hash__(self): + return hash((type(self), self.op, self.args)) + + def __str__(self): + return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args))) + +class Variable: + """ A Wild token """ + def __init__(self, arg): + self.arg = arg + + def __eq__(self, other): + return type(self) is type(other) and self.arg == other.arg + + def __hash__(self): + return hash((type(self), self.arg)) + + def __str__(self): + return "Variable(%s)" % str(self.arg) + +class CondVariable: + """ A wild token that matches conditionally. + + arg - a wild token. 
+ valid - an additional constraining function on a match. + """ + def __init__(self, arg, valid): + self.arg = arg + self.valid = valid + + def __eq__(self, other): + return (type(self) is type(other) and + self.arg == other.arg and + self.valid == other.valid) + + def __hash__(self): + return hash((type(self), self.arg, self.valid)) + + def __str__(self): + return "CondVariable(%s)" % str(self.arg) + +def unify(x, y, s=None, **fns): + """ Unify two expressions. + + Parameters + ========== + + x, y - expression trees containing leaves, Compounds and Variables. + s - a mapping of variables to subtrees. + + Returns + ======= + + lazy sequence of mappings {Variable: subtree} + + Examples + ======== + + >>> from sympy.unify.core import unify, Compound, Variable + >>> expr = Compound("Add", ("x", "y")) + >>> pattern = Compound("Add", ("x", Variable("a"))) + >>> next(unify(expr, pattern, {})) + {Variable(a): 'y'} + """ + s = s or {} + + if x == y: + yield s + elif isinstance(x, (Variable, CondVariable)): + yield from unify_var(x, y, s, **fns) + elif isinstance(y, (Variable, CondVariable)): + yield from unify_var(y, x, s, **fns) + elif isinstance(x, Compound) and isinstance(y, Compound): + is_commutative = fns.get('is_commutative', lambda x: False) + is_associative = fns.get('is_associative', lambda x: False) + for sop in unify(x.op, y.op, s, **fns): + if is_associative(x) and is_associative(y): + a, b = (x, y) if len(x.args) < len(y.args) else (y, x) + if is_commutative(x) and is_commutative(y): + combs = allcombinations(a.args, b.args, 'commutative') + else: + combs = allcombinations(a.args, b.args, 'associative') + for aaargs, bbargs in combs: + aa = [unpack(Compound(a.op, arg)) for arg in aaargs] + bb = [unpack(Compound(b.op, arg)) for arg in bbargs] + yield from unify(aa, bb, sop, **fns) + elif len(x.args) == len(y.args): + yield from unify(x.args, y.args, sop, **fns) + + elif is_args(x) and is_args(y) and len(x) == len(y): + if len(x) == 0: + yield s + else: + for shead in unify(x[0], y[0], s, **fns): + yield from unify(x[1:], y[1:], shead, **fns) + +def unify_var(var, x, s, **fns): + if var in s: + yield from unify(s[var], x, s, **fns) + elif occur_check(var, x): + pass + elif isinstance(var, CondVariable) and var.valid(x): + yield assoc(s, var, x) + elif isinstance(var, Variable): + yield assoc(s, var, x) + +def occur_check(var, x): + """ var occurs in subtree owned by x? """ + if var == x: + return True + elif isinstance(x, Compound): + return occur_check(var, x.args) + elif is_args(x): + if any(occur_check(var, xi) for xi in x): return True + return False + +def assoc(d, key, val): + """ Return copy of d with key associated to val """ + d = d.copy() + d[key] = val + return d + +def is_args(x): + """ Is x a traditional iterable? """ + return type(x) in (tuple, list, set) + +def unpack(x): + if isinstance(x, Compound) and len(x.args) == 1: + return x.args[0] + else: + return x + +def allcombinations(A, B, ordered): + """ + Restructure A and B to have the same number of elements. + + Parameters + ========== + + ordered must be either 'commutative' or 'associative'. + + A and B can be rearranged so that the larger of the two lists is + reorganized into smaller sublists. 
+ + Examples + ======== + + >>> from sympy.unify.core import allcombinations + >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x) + (((1,), (2, 3)), ((5,), (6,))) + (((1, 2), (3,)), ((5,), (6,))) + + >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x) + (((1,), (2, 3)), ((5,), (6,))) + (((1, 2), (3,)), ((5,), (6,))) + (((1,), (3, 2)), ((5,), (6,))) + (((1, 3), (2,)), ((5,), (6,))) + (((2,), (1, 3)), ((5,), (6,))) + (((2, 1), (3,)), ((5,), (6,))) + (((2,), (3, 1)), ((5,), (6,))) + (((2, 3), (1,)), ((5,), (6,))) + (((3,), (1, 2)), ((5,), (6,))) + (((3, 1), (2,)), ((5,), (6,))) + (((3,), (2, 1)), ((5,), (6,))) + (((3, 2), (1,)), ((5,), (6,))) + """ + + if ordered == "commutative": + ordered = 11 + if ordered == "associative": + ordered = None + sm, bg = (A, B) if len(A) < len(B) else (B, A) + for part in kbins(list(range(len(bg))), len(sm), ordered=ordered): + if bg == B: + yield tuple((a,) for a in A), partition(B, part) + else: + yield partition(A, part), tuple((b,) for b in B) + +def partition(it, part): + """ Partition a tuple/list into pieces defined by indices. + + Examples + ======== + + >>> from sympy.unify.core import partition + >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]]) + ((10, 20, 30), (40,)) + """ + return type(it)([index(it, ind) for ind in part]) + +def index(it, ind): + """ Fancy indexing into an indexable iterable (tuple, list). + + Examples + ======== + + >>> from sympy.unify.core import index + >>> index([10, 20, 30], (1, 2, 0)) + [20, 30, 10] + """ + return type(it)([it[i] for i in ind]) diff --git a/phivenv/Lib/site-packages/sympy/unify/rewrite.py b/phivenv/Lib/site-packages/sympy/unify/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..95a6fa5ffd6a3fde94d17ee845c03bb2b44cf009 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/rewrite.py @@ -0,0 +1,55 @@ +""" Functions to support rewriting of SymPy expressions """ + +from sympy.core.expr import Expr +from sympy.assumptions import ask +from sympy.strategies.tools import subs +from sympy.unify.usympy import rebuild, unify + +def rewriterule(source, target, variables=(), condition=None, assume=None): + """ Rewrite rule. + + Transform expressions that match source into expressions that match target + treating all ``variables`` as wilds. + + Examples + ======== + + >>> from sympy.abc import w, x, y, z + >>> from sympy.unify.rewrite import rewriterule + >>> from sympy import default_sort_key + >>> rl = rewriterule(x + y, x**y, [x, y]) + >>> sorted(rl(z + 3), key=default_sort_key) + [3**z, z**3] + + Use ``condition`` to specify additional requirements. Inputs are taken in + the same order as is found in variables. + + >>> rl = rewriterule(x + y, x**y, [x, y], lambda x, y: x.is_integer) + >>> list(rl(z + 3)) + [3**z] + + Use ``assume`` to specify additional requirements using new assumptions. 
+ + >>> from sympy.assumptions import Q + >>> rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x)) + >>> list(rl(z + 3)) + [3**z] + + Assumptions for the local context are provided at rule runtime + + >>> list(rl(w + z, Q.integer(z))) + [z**w] + """ + + def rewrite_rl(expr, assumptions=True): + for match in unify(source, expr, {}, variables=variables): + if (condition and + not condition(*[match.get(var, var) for var in variables])): + continue + if (assume and not ask(assume.xreplace(match), assumptions)): + continue + expr2 = subs(match)(target) + if isinstance(expr2, Expr): + expr2 = rebuild(expr2) + yield expr2 + return rewrite_rl diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/__init__.py b/phivenv/Lib/site-packages/sympy/unify/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc44bc5ffff418625db5354656c401ad8e945122 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_rewrite.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_rewrite.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5cb6a45efebc503ac6b5be13e9ce9396632254c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_rewrite.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_sympy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_sympy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eecf0bc0e30956c0a544fc20483dea0e504364ab Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_sympy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_unify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_unify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feb97b1dd0b34693749a3e2ebbfb11071f3dc089 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/unify/tests/__pycache__/test_unify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/test_rewrite.py b/phivenv/Lib/site-packages/sympy/unify/tests/test_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..7b73e2856d5f6380c576220fa2780324df98091a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/tests/test_rewrite.py @@ -0,0 +1,74 @@ +from sympy.unify.rewrite import rewriterule +from sympy.core.basic import Basic +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.trigonometric import sin +from sympy.abc import x, y +from sympy.strategies.rl import rebuild +from sympy.assumptions import Q + +p, q = Symbol('p'), Symbol('q') + +def test_simple(): + rl = rewriterule(Basic(p, S(1)), Basic(p, S(2)), variables=(p,)) + assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))] + + p1 = p**2 + p2 = p**3 + rl = rewriterule(p1, p2, variables=(p,)) + + expr = x**2 + assert list(rl(expr)) == [x**3] + +def test_simple_variables(): + rl = rewriterule(Basic(x, S(1)), Basic(x, 
S(2)), variables=(x,)) + assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))] + + rl = rewriterule(x**2, x**3, variables=(x,)) + assert list(rl(y**2)) == [y**3] + +def test_moderate(): + p1 = p**2 + q**3 + p2 = (p*q)**4 + rl = rewriterule(p1, p2, (p, q)) + + expr = x**2 + y**3 + assert list(rl(expr)) == [(x*y)**4] + +def test_sincos(): + p1 = sin(p)**2 + sin(p)**2 + p2 = 1 + rl = rewriterule(p1, p2, (p, q)) + + assert list(rl(sin(x)**2 + sin(x)**2)) == [1] + assert list(rl(sin(y)**2 + sin(y)**2)) == [1] + +def test_Exprs_ok(): + rl = rewriterule(p+q, q+p, (p, q)) + next(rl(x+y)).is_commutative + str(next(rl(x+y))) + +def test_condition_simple(): + rl = rewriterule(x, x+1, [x], lambda x: x < 10) + assert not list(rl(S(15))) + assert rebuild(next(rl(S(5)))) == 6 + + +def test_condition_multiple(): + rl = rewriterule(x + y, x**y, [x,y], lambda x, y: x.is_integer) + + a = Symbol('a') + b = Symbol('b', integer=True) + expr = a + b + assert list(rl(expr)) == [b**a] + + c = Symbol('c', integer=True) + d = Symbol('d', integer=True) + assert set(rl(c + d)) == {c**d, d**c} + +def test_assumptions(): + rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x)) + + a, b = map(Symbol, 'ab') + expr = a + b + assert list(rl(expr, Q.integer(b))) == [b**a] diff --git a/phivenv/Lib/site-packages/sympy/unify/tests/test_sympy.py b/phivenv/Lib/site-packages/sympy/unify/tests/test_sympy.py new file mode 100644 index 0000000000000000000000000000000000000000..eca3933a91abfabdbad96f626e4da761a41b3fd2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/tests/test_sympy.py @@ -0,0 +1,162 @@ +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.logic.boolalg import And +from sympy.core.symbol import Str +from sympy.unify.core import Compound, Variable +from sympy.unify.usympy import (deconstruct, construct, unify, is_associative, + is_commutative) +from sympy.abc import x, y, z, n + +def test_deconstruct(): + expr = Basic(S(1), S(2), S(3)) + expected = Compound(Basic, (1, 2, 3)) + assert deconstruct(expr) == expected + + assert deconstruct(1) == 1 + assert deconstruct(x) == x + assert deconstruct(x, variables=(x,)) == Variable(x) + assert deconstruct(Add(1, x, evaluate=False)) == Compound(Add, (1, x)) + assert deconstruct(Add(1, x, evaluate=False), variables=(x,)) == \ + Compound(Add, (1, Variable(x))) + +def test_construct(): + expr = Compound(Basic, (S(1), S(2), S(3))) + expected = Basic(S(1), S(2), S(3)) + assert construct(expr) == expected + +def test_nested(): + expr = Basic(S(1), Basic(S(2)), S(3)) + cmpd = Compound(Basic, (S(1), Compound(Basic, Tuple(2)), S(3))) + assert deconstruct(expr) == cmpd + assert construct(cmpd) == expr + +def test_unify(): + expr = Basic(S(1), S(2), S(3)) + a, b, c = map(Symbol, 'abc') + pattern = Basic(a, b, c) + assert list(unify(expr, pattern, {}, (a, b, c))) == [{a: 1, b: 2, c: 3}] + assert list(unify(expr, pattern, variables=(a, b, c))) == \ + [{a: 1, b: 2, c: 3}] + +def test_unify_variables(): + assert list(unify(Basic(S(1), S(2)), Basic(S(1), x), {}, variables=(x,))) == [{x: 2}] + +def test_s_input(): + expr = Basic(S(1), S(2)) + a, b = map(Symbol, 'ab') + pattern = Basic(a, b) + assert list(unify(expr, pattern, {}, (a, b))) == [{a: 1, b: 2}] + assert list(unify(expr, pattern, {a: 5}, (a, b))) == [] + +def iterdicteq(a, b): + a = tuple(a) + b = tuple(b) + return len(a) == len(b) and all(x in b for x in a) + +def 
test_unify_commutative(): + expr = Add(1, 2, 3, evaluate=False) + a, b, c = map(Symbol, 'abc') + pattern = Add(a, b, c, evaluate=False) + + result = tuple(unify(expr, pattern, {}, (a, b, c))) + expected = ({a: 1, b: 2, c: 3}, + {a: 1, b: 3, c: 2}, + {a: 2, b: 1, c: 3}, + {a: 2, b: 3, c: 1}, + {a: 3, b: 1, c: 2}, + {a: 3, b: 2, c: 1}) + + assert iterdicteq(result, expected) + +def test_unify_iter(): + expr = Add(1, 2, 3, evaluate=False) + a, b, c = map(Symbol, 'abc') + pattern = Add(a, c, evaluate=False) + assert is_associative(deconstruct(pattern)) + assert is_commutative(deconstruct(pattern)) + + result = list(unify(expr, pattern, {}, (a, c))) + expected = [{a: 1, c: Add(2, 3, evaluate=False)}, + {a: 1, c: Add(3, 2, evaluate=False)}, + {a: 2, c: Add(1, 3, evaluate=False)}, + {a: 2, c: Add(3, 1, evaluate=False)}, + {a: 3, c: Add(1, 2, evaluate=False)}, + {a: 3, c: Add(2, 1, evaluate=False)}, + {a: Add(1, 2, evaluate=False), c: 3}, + {a: Add(2, 1, evaluate=False), c: 3}, + {a: Add(1, 3, evaluate=False), c: 2}, + {a: Add(3, 1, evaluate=False), c: 2}, + {a: Add(2, 3, evaluate=False), c: 1}, + {a: Add(3, 2, evaluate=False), c: 1}] + + assert iterdicteq(result, expected) + +def test_hard_match(): + from sympy.functions.elementary.trigonometric import (cos, sin) + expr = sin(x) + cos(x)**2 + p, q = map(Symbol, 'pq') + pattern = sin(p) + cos(p)**2 + assert list(unify(expr, pattern, {}, (p, q))) == [{p: x}] + +def test_matrix(): + from sympy.matrices.expressions.matexpr import MatrixSymbol + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 2, 2) + Z = MatrixSymbol('Z', 2, 3) + assert list(unify(X, Y, {}, variables=[n, Str('X')])) == [{Str('X'): Str('Y'), n: 2}] + assert list(unify(X, Z, {}, variables=[n, Str('X')])) == [] + +def test_non_frankenAdds(): + # the is_commutative property used to fail because of Basic.__new__ + # This caused is_commutative and str calls to fail + expr = x+y*2 + rebuilt = construct(deconstruct(expr)) + # Ensure that we can run these commands without causing an error + str(rebuilt) + rebuilt.is_commutative + +def test_FiniteSet_commutivity(): + from sympy.sets.sets import FiniteSet + a, b, c, x, y = symbols('a,b,c,x,y') + s = FiniteSet(a, b, c) + t = FiniteSet(x, y) + variables = (x, y) + assert {x: FiniteSet(a, c), y: b} in tuple(unify(s, t, variables=variables)) + +def test_FiniteSet_complex(): + from sympy.sets.sets import FiniteSet + a, b, c, x, y, z = symbols('a,b,c,x,y,z') + expr = FiniteSet(Basic(S(1), x), y, Basic(x, z)) + pattern = FiniteSet(a, Basic(x, b)) + variables = a, b + expected = ({b: 1, a: FiniteSet(y, Basic(x, z))}, + {b: z, a: FiniteSet(y, Basic(S(1), x))}) + assert iterdicteq(unify(expr, pattern, variables=variables), expected) + + +def test_and(): + variables = x, y + expected = ({x: z > 0, y: n < 3},) + assert iterdicteq(unify((z>0) & (n<3), And(x, y), variables=variables), + expected) + +def test_Union(): + from sympy.sets.sets import Interval + assert list(unify(Interval(0, 1) + Interval(10, 11), + Interval(0, 1) + Interval(12, 13), + variables=(Interval(12, 13),))) + +def test_is_commutative(): + assert is_commutative(deconstruct(x+y)) + assert is_commutative(deconstruct(x*y)) + assert not is_commutative(deconstruct(x**y)) + +def test_commutative_in_commutative(): + from sympy.abc import a,b,c,d + from sympy.functions.elementary.trigonometric import (cos, sin) + eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4) + pat = a*cos(b)*cos(c) + d*sin(b)*sin(c) + assert next(unify(eq, pat, variables=(a,b,c,d))) diff --git 
a/phivenv/Lib/site-packages/sympy/unify/tests/test_unify.py b/phivenv/Lib/site-packages/sympy/unify/tests/test_unify.py new file mode 100644 index 0000000000000000000000000000000000000000..31153242576e1ff55dd3097efbc985aced5d574a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/tests/test_unify.py @@ -0,0 +1,88 @@ +from sympy.unify.core import Compound, Variable, CondVariable, allcombinations +from sympy.unify import core + +a,b,c = 'a', 'b', 'c' +w,x,y,z = map(Variable, 'wxyz') + +C = Compound + +def is_associative(x): + return isinstance(x, Compound) and (x.op in ('Add', 'Mul', 'CAdd', 'CMul')) +def is_commutative(x): + return isinstance(x, Compound) and (x.op in ('CAdd', 'CMul')) + + +def unify(a, b, s={}): + return core.unify(a, b, s=s, is_associative=is_associative, + is_commutative=is_commutative) + +def test_basic(): + assert list(unify(a, x, {})) == [{x: a}] + assert list(unify(a, x, {x: 10})) == [] + assert list(unify(1, x, {})) == [{x: 1}] + assert list(unify(a, a, {})) == [{}] + assert list(unify((w, x), (y, z), {})) == [{w: y, x: z}] + assert list(unify(x, (a, b), {})) == [{x: (a, b)}] + + assert list(unify((a, b), (x, x), {})) == [] + assert list(unify((y, z), (x, x), {}))!= [] + assert list(unify((a, (b, c)), (a, (x, y)), {})) == [{x: b, y: c}] + +def test_ops(): + assert list(unify(C('Add', (a,b,c)), C('Add', (a,x,y)), {})) == \ + [{x:b, y:c}] + assert list(unify(C('Add', (C('Mul', (1,2)), b,c)), C('Add', (x,y,c)), {})) == \ + [{x: C('Mul', (1,2)), y:b}] + +def test_associative(): + c1 = C('Add', (1,2,3)) + c2 = C('Add', (x,y)) + assert tuple(unify(c1, c2, {})) == ({x: 1, y: C('Add', (2, 3))}, + {x: C('Add', (1, 2)), y: 3}) + +def test_commutative(): + c1 = C('CAdd', (1,2,3)) + c2 = C('CAdd', (x,y)) + result = list(unify(c1, c2, {})) + assert {x: 1, y: C('CAdd', (2, 3))} in result + assert ({x: 2, y: C('CAdd', (1, 3))} in result or + {x: 2, y: C('CAdd', (3, 1))} in result) + +def _test_combinations_assoc(): + assert set(allcombinations((1,2,3), (a,b), True)) == \ + {(((1, 2), (3,)), (a, b)), (((1,), (2, 3)), (a, b))} + +def _test_combinations_comm(): + assert set(allcombinations((1,2,3), (a,b), None)) == \ + {(((1,), (2, 3)), ('a', 'b')), (((2,), (3, 1)), ('a', 'b')), + (((3,), (1, 2)), ('a', 'b')), (((1, 2), (3,)), ('a', 'b')), + (((2, 3), (1,)), ('a', 'b')), (((3, 1), (2,)), ('a', 'b'))} + +def test_allcombinations(): + assert set(allcombinations((1,2), (1,2), 'commutative')) ==\ + {(((1,),(2,)), ((1,),(2,))), (((1,),(2,)), ((2,),(1,)))} + + +def test_commutativity(): + c1 = Compound('CAdd', (a, b)) + c2 = Compound('CAdd', (x, y)) + assert is_commutative(c1) and is_commutative(c2) + assert len(list(unify(c1, c2, {}))) == 2 + + +def test_CondVariable(): + expr = C('CAdd', (1, 2)) + x = Variable('x') + y = CondVariable('y', lambda a: a % 2 == 0) + z = CondVariable('z', lambda a: a > 3) + pattern = C('CAdd', (x, y)) + assert list(unify(expr, pattern, {})) == \ + [{x: 1, y: 2}] + + z = CondVariable('z', lambda a: a > 3) + pattern = C('CAdd', (z, y)) + + assert list(unify(expr, pattern, {})) == [] + +def test_defaultdict(): + assert next(unify(Variable('x'), 'foo')) == {Variable('x'): 'foo'} diff --git a/phivenv/Lib/site-packages/sympy/unify/usympy.py b/phivenv/Lib/site-packages/sympy/unify/usympy.py new file mode 100644 index 0000000000000000000000000000000000000000..3942b35ec549e5dbd08a3cf1cad2b2ecea733c7a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/unify/usympy.py @@ -0,0 +1,124 @@ +""" SymPy interface to Unification engine + +See sympy.unify for module 
level docstring +See sympy.unify.core for algorithmic docstring """ + +from sympy.core import Basic, Add, Mul, Pow +from sympy.core.operations import AssocOp, LatticeOp +from sympy.matrices import MatAdd, MatMul, MatrixExpr +from sympy.sets.sets import Union, Intersection, FiniteSet +from sympy.unify.core import Compound, Variable, CondVariable +from sympy.unify import core + +basic_new_legal = [MatrixExpr] +eval_false_legal = [AssocOp, Pow, FiniteSet] +illegal = [LatticeOp] + +def sympy_associative(op): + assoc_ops = (AssocOp, MatAdd, MatMul, Union, Intersection, FiniteSet) + return any(issubclass(op, aop) for aop in assoc_ops) + +def sympy_commutative(op): + comm_ops = (Add, MatAdd, Union, Intersection, FiniteSet) + return any(issubclass(op, cop) for cop in comm_ops) + +def is_associative(x): + return isinstance(x, Compound) and sympy_associative(x.op) + +def is_commutative(x): + if not isinstance(x, Compound): + return False + if sympy_commutative(x.op): + return True + if issubclass(x.op, Mul): + return all(construct(arg).is_commutative for arg in x.args) + +def mk_matchtype(typ): + def matchtype(x): + return (isinstance(x, typ) or + isinstance(x, Compound) and issubclass(x.op, typ)) + return matchtype + +def deconstruct(s, variables=()): + """ Turn a SymPy object into a Compound """ + if s in variables: + return Variable(s) + if isinstance(s, (Variable, CondVariable)): + return s + if not isinstance(s, Basic) or s.is_Atom: + return s + return Compound(s.__class__, + tuple(deconstruct(arg, variables) for arg in s.args)) + +def construct(t): + """ Turn a Compound into a SymPy object """ + if isinstance(t, (Variable, CondVariable)): + return t.arg + if not isinstance(t, Compound): + return t + if any(issubclass(t.op, cls) for cls in eval_false_legal): + return t.op(*map(construct, t.args), evaluate=False) + elif any(issubclass(t.op, cls) for cls in basic_new_legal): + return Basic.__new__(t.op, *map(construct, t.args)) + else: + return t.op(*map(construct, t.args)) + +def rebuild(s): + """ Rebuild a SymPy expression. + + This removes harm caused by Expr-Rules interactions. + """ + return construct(deconstruct(s)) + +def unify(x, y, s=None, variables=(), **kwargs): + """ Structural unification of two expressions/patterns. + + Examples + ======== + + >>> from sympy.unify.usympy import unify + >>> from sympy import Basic, S + >>> from sympy.abc import x, y, z, p, q + + >>> next(unify(Basic(S(1), S(2)), Basic(S(1), x), variables=[x])) + {x: 2} + + >>> expr = 2*x + y + z + >>> pattern = 2*p + q + >>> next(unify(expr, pattern, {}, variables=(p, q))) + {p: x, q: y + z} + + Unification supports commutative and associative matching + + >>> expr = x + y + z + >>> pattern = p + q + >>> len(list(unify(expr, pattern, {}, variables=(p, q)))) + 12 + + Symbols not indicated to be variables are treated as literal, + else they are wild-like and match anything in a sub-expression. + + >>> expr = x*y*z + 3 + >>> pattern = x*y + 3 + >>> next(unify(expr, pattern, {}, variables=[x, y])) + {x: y, y: x*z} + + The x and y of the pattern above were in a Mul and matched factors + in the Mul of expr. 
Here, a single symbol matches an entire term: + + >>> expr = x*y + 3 + >>> pattern = p + 3 + >>> next(unify(expr, pattern, {}, variables=[p])) + {p: x*y} + + """ + decons = lambda x: deconstruct(x, variables) + s = s or {} + s = {decons(k): decons(v) for k, v in s.items()} + + ds = core.unify(decons(x), decons(y), s, + is_associative=is_associative, + is_commutative=is_commutative, + **kwargs) + for d in ds: + yield {construct(k): construct(v) for k, v in d.items()} diff --git a/phivenv/Lib/site-packages/sympy/utilities/__init__.py b/phivenv/Lib/site-packages/sympy/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f35da4a84396618a33a12c3c6b5cf58e9d4742c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/__init__.py @@ -0,0 +1,30 @@ +"""This module contains some general purpose utilities that are used across +SymPy. +""" +from .iterables import (flatten, group, take, subsets, + variations, numbered_symbols, cartes, capture, dict_merge, + prefixes, postfixes, sift, topological_sort, unflatten, + has_dups, has_variety, reshape, rotations) + +from .misc import filldedent + +from .lambdify import lambdify + +from .decorator import threaded, xthreaded, public, memoize_property + +from .timeutils import timed + +__all__ = [ + 'flatten', 'group', 'take', 'subsets', 'variations', 'numbered_symbols', + 'cartes', 'capture', 'dict_merge', 'prefixes', 'postfixes', 'sift', + 'topological_sort', 'unflatten', 'has_dups', 'has_variety', 'reshape', + 'rotations', + + 'filldedent', + + 'lambdify', + + 'threaded', 'xthreaded', 'public', 'memoize_property', + + 'timed', +] diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4a9a7b60bbe5b5e362e207d2d84abaa165b2dd0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/autowrap.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/autowrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66f5abd6f69de2a3b5129f3eefbd1476444b97d8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/autowrap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/codegen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/codegen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d297ad4ade704c36e48763917bc81a247ec9c4c9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/codegen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/decorator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/decorator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ceeaf45de5d929e9d403b4bdca0c361a7b26ed5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/decorator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/enumerative.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/enumerative.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b154e8f6797eb3aa3bca80843a8b2e8373e73589 Binary files /dev/null and 
b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/enumerative.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/exceptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/exceptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aebeda25b2c86ad7ade1de1a7cd6132cd2c17c9e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/exceptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/iterables.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/iterables.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1887e8f20d692dd119f73b787ba1b2c3ad88e2b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/iterables.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/lambdify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/lambdify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4062b6d0ed19851b8d3815fe2444b0c29fe9a9a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/lambdify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/magic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/magic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3117cfeef6cbe62db2d303fc9278c8c2ec5a469f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/magic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/matchpy_connector.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/matchpy_connector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2ef505a5f66eb32a737a112a04e9a551ef9a72 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/matchpy_connector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/memoization.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/memoization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4ad81ff1a01aa00c54fa319ff44d9ce9aaa6aae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/memoization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/misc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..310dd1d30da6210a5a5b8e52d97dccef946141f2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/misc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/pkgdata.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/pkgdata.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aabf90802693cbe7f3dcbb6352cf778b7651dbf5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/pkgdata.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/pytest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/pytest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01ff9d78330cd32350c297d33c3137e00929d0d7 Binary 
files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/pytest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/randtest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/randtest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4bfbb64f7cf9d46763271747f9784b5044e82ef Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/randtest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/runtests.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/runtests.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90a3f39e9437bee307d563f253e23db5cd7d5dc1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/runtests.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/source.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/source.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8ed82b71707f588ae7131b4f377d00f32481194 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/source.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/timeutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/timeutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a59c1d2dfd76694d4f6d6e727f2ca5402e3d3360 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/timeutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/__pycache__/tmpfiles.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/tmpfiles.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0c1dc7c76d5c0471c59fbf1608c900782367c59 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/__pycache__/tmpfiles.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/__init__.py b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c05ad48a93493bb5434256c88dfd614ac47b2d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__init__.py @@ -0,0 +1,22 @@ +""" This sub-module is private, i.e. external code should not depend on it. + +These functions are used by tests run as part of continuous integration. +Once the implementation is mature (it should support the major +platforms: Windows, OS X & Linux) it may become official API which + may be relied upon by downstream libraries. Until then API may break +without prior notice. 
+ +TODO: +- (optionally) clean up after tempfile.mkdtemp() +- cross-platform testing +- caching of compiler choice and intermediate files + +""" + +from .compilation import compile_link_import_strings, compile_run_strings +from .availability import has_fortran, has_c, has_cxx + +__all__ = [ + 'compile_link_import_strings', 'compile_run_strings', + 'has_fortran', 'has_c', 'has_cxx', +] diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..582087d350bba5b6e2808d964fc9bfda7977b10e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/availability.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/availability.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a80ebd892f8b34b9436f7079fca75b499fa1d21 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/availability.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/compilation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/compilation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af91110e2cd80bd62e64bd7b76780d2ad384436d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/compilation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/runners.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/runners.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c2ff26a44919e94c4975b2e92d4faf948effe11 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/runners.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/util.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa6fee91e9dadcfe915923d64e673b597e3b3e2d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/__pycache__/util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/availability.py b/phivenv/Lib/site-packages/sympy/utilities/_compilation/availability.py new file mode 100644 index 0000000000000000000000000000000000000000..dc97b3e7b8c7e7307c6c21352ed4035d977aabb3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/_compilation/availability.py @@ -0,0 +1,77 @@ +import os +from .compilation import compile_run_strings +from .util import CompilerNotFoundError + +def has_fortran(): + if not hasattr(has_fortran, 'result'): + try: + (stdout, stderr), info = compile_run_strings( + [('main.f90', ( + 'program foo\n' + 'print *, "hello world"\n' + 'end program' + ))], clean=True + ) + except CompilerNotFoundError: + has_fortran.result = False + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise + else: + if info['exit_status'] != os.EX_OK or 'hello world' not in stdout: + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise 
ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr)) + has_fortran.result = False + else: + has_fortran.result = True + return has_fortran.result + + +def has_c(): + if not hasattr(has_c, 'result'): + try: + (stdout, stderr), info = compile_run_strings( + [('main.c', ( + '#include <stdio.h>\n' + 'int main(){\n' + 'printf("hello world\\n");\n' + 'return 0;\n' + '}' + ))], clean=True + ) + except CompilerNotFoundError: + has_c.result = False + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise + else: + if info['exit_status'] != os.EX_OK or 'hello world' not in stdout: + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr)) + has_c.result = False + else: + has_c.result = True + return has_c.result + + +def has_cxx(): + if not hasattr(has_cxx, 'result'): + try: + (stdout, stderr), info = compile_run_strings( + [('main.cxx', ( + '#include <iostream>\n' + 'int main(){\n' + 'std::cout << "hello world" << std::endl;\n' + '}' + ))], clean=True + ) + except CompilerNotFoundError: + has_cxx.result = False + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise + else: + if info['exit_status'] != os.EX_OK or 'hello world' not in stdout: + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr)) + has_cxx.result = False + else: + has_cxx.result = True + return has_cxx.result diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/compilation.py b/phivenv/Lib/site-packages/sympy/utilities/_compilation/compilation.py new file mode 100644 index 0000000000000000000000000000000000000000..2d50a20467c08086d6cb5fb5b0d478e7a953d720 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/_compilation/compilation.py @@ -0,0 +1,657 @@ +import glob +import os +import shutil +import subprocess +import sys +import tempfile +import warnings +from pathlib import Path +from sysconfig import get_config_var, get_config_vars, get_path + +from .runners import ( + CCompilerRunner, + CppCompilerRunner, + FortranCompilerRunner +) +from .util import ( + get_abspath, make_dirs, copy, Glob, ArbitraryDepthGlob, + glob_at_depth, import_module_from_file, pyx_is_cplus, + sha256_of_string, sha256_of_file, CompileError +) + +if os.name == 'posix': + objext = '.o' +elif os.name == 'nt': + objext = '.obj' +else: + warnings.warn("Unknown os.name: {}".format(os.name)) + objext = '.o' + + +def compile_sources(files, Runner=None, destdir=None, cwd=None, keep_dir_struct=False, + per_file_kwargs=None, **kwargs): + """ Compile source code files to object files. + + Parameters + ========== + + files : iterable of str + Paths to source files, if ``cwd`` is given, the paths are taken as relative. + Runner: CompilerRunner subclass (optional) + Could be e.g. ``FortranCompilerRunner``. Will be inferred from filename + extensions if missing. + destdir: str + Output directory, if cwd is given, the path is taken as relative. + cwd: str + Working directory. Specify to have compiler run in other directory. + also used as root of relative paths. + keep_dir_struct: bool + Reproduce directory structure in `destdir`. default: ``False`` + per_file_kwargs: dict + Dict mapping instances in ``files`` to keyword arguments. + \\*\\*kwargs: dict + Default keyword arguments to pass to ``Runner``. + + Returns + ======= + List of strings (paths of object files).
+ """ + _per_file_kwargs = {} + + if per_file_kwargs is not None: + for k, v in per_file_kwargs.items(): + if isinstance(k, Glob): + for path in glob.glob(k.pathname): + _per_file_kwargs[path] = v + elif isinstance(k, ArbitraryDepthGlob): + for path in glob_at_depth(k.filename, cwd): + _per_file_kwargs[path] = v + else: + _per_file_kwargs[k] = v + + # Set up destination directory + destdir = destdir or '.' + if not os.path.isdir(destdir): + if os.path.exists(destdir): + raise OSError("{} is not a directory".format(destdir)) + else: + make_dirs(destdir) + if cwd is None: + cwd = '.' + for f in files: + copy(f, destdir, only_update=True, dest_is_dir=True) + + # Compile files and return list of paths to the objects + dstpaths = [] + for f in files: + if keep_dir_struct: + name, ext = os.path.splitext(f) + else: + name, ext = os.path.splitext(os.path.basename(f)) + file_kwargs = kwargs.copy() + file_kwargs.update(_per_file_kwargs.get(f, {})) + dstpaths.append(src2obj(f, Runner, cwd=cwd, **file_kwargs)) + return dstpaths + + +def get_mixed_fort_c_linker(vendor=None, cplus=False, cwd=None): + vendor = vendor or os.environ.get('SYMPY_COMPILER_VENDOR', 'gnu') + + if vendor.lower() == 'intel': + if cplus: + return (FortranCompilerRunner, + {'flags': ['-nofor_main', '-cxxlib']}, vendor) + else: + return (FortranCompilerRunner, + {'flags': ['-nofor_main']}, vendor) + elif vendor.lower() == 'gnu' or 'llvm': + if cplus: + return (CppCompilerRunner, + {'lib_options': ['fortran']}, vendor) + else: + return (FortranCompilerRunner, + {}, vendor) + else: + raise ValueError("No vendor found.") + + +def link(obj_files, out_file=None, shared=False, Runner=None, + cwd=None, cplus=False, fort=False, extra_objs=None, **kwargs): + """ Link object files. + + Parameters + ========== + + obj_files: iterable of str + Paths to object files. + out_file: str (optional) + Path to executable/shared library, if ``None`` it will be + deduced from the last item in obj_files. + shared: bool + Generate a shared library? + Runner: CompilerRunner subclass (optional) + If not given the ``cplus`` and ``fort`` flags will be inspected + (fallback is the C compiler). + cwd: str + Path to the root of relative paths and working directory for compiler. + cplus: bool + C++ objects? default: ``False``. + fort: bool + Fortran objects? default: ``False``. + extra_objs: list + List of paths to extra object files / static libraries. + \\*\\*kwargs: dict + Keyword arguments passed to ``Runner``. + + Returns + ======= + + The absolute path to the generated shared object / executable. 
+ + """ + if out_file is None: + out_file, ext = os.path.splitext(os.path.basename(obj_files[-1])) + if shared: + out_file += get_config_var('EXT_SUFFIX') + + if not Runner: + if fort: + Runner, extra_kwargs, vendor = \ + get_mixed_fort_c_linker( + vendor=kwargs.get('vendor', None), + cplus=cplus, + cwd=cwd, + ) + for k, v in extra_kwargs.items(): + if k in kwargs: + kwargs[k].expand(v) + else: + kwargs[k] = v + else: + if cplus: + Runner = CppCompilerRunner + else: + Runner = CCompilerRunner + + flags = kwargs.pop('flags', []) + if shared: + if '-shared' not in flags: + flags.append('-shared') + run_linker = kwargs.pop('run_linker', True) + if not run_linker: + raise ValueError("run_linker was set to False (nonsensical).") + + out_file = get_abspath(out_file, cwd=cwd) + runner = Runner(obj_files+(extra_objs or []), out_file, flags, cwd=cwd, **kwargs) + runner.run() + return out_file + + +def link_py_so(obj_files, so_file=None, cwd=None, libraries=None, + cplus=False, fort=False, extra_objs=None, **kwargs): + """ Link Python extension module (shared object) for importing + + Parameters + ========== + + obj_files: iterable of str + Paths to object files to be linked. + so_file: str + Name (path) of shared object file to create. If not specified it will + have the basname of the last object file in `obj_files` but with the + extension '.so' (Unix). + cwd: path string + Root of relative paths and working directory of linker. + libraries: iterable of strings + Libraries to link against, e.g. ['m']. + cplus: bool + Any C++ objects? default: ``False``. + fort: bool + Any Fortran objects? default: ``False``. + extra_objs: list + List of paths of extra object files / static libraries to link against. + kwargs**: dict + Keyword arguments passed to ``link(...)``. + + Returns + ======= + + Absolute path to the generate shared object. + """ + libraries = libraries or [] + + include_dirs = kwargs.pop('include_dirs', []) + library_dirs = kwargs.pop('library_dirs', []) + + # Add Python include and library directories + # PY_LDFLAGS does not available on all python implementations + # e.g. when with pypy, so it's LDFLAGS we need to use + if sys.platform == "win32": + warnings.warn("Windows not yet supported.") + elif sys.platform == 'darwin': + cfgDict = get_config_vars() + kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']] + library_dirs += [cfgDict['LIBDIR']] + + # In macOS, linker needs to compile frameworks + # e.g. 
"-framework CoreFoundation" + is_framework = False + for opt in cfgDict['LIBS'].split(): + if is_framework: + kwargs['linkline'] = kwargs.get('linkline', []) + ['-framework', opt] + is_framework = False + elif opt.startswith('-l'): + libraries.append(opt[2:]) + elif opt.startswith('-framework'): + is_framework = True + # The python library is not included in LIBS + libfile = cfgDict['LIBRARY'] + libname = ".".join(libfile.split('.')[:-1])[3:] + libraries.append(libname) + + elif sys.platform[:3] == 'aix': + # Don't use the default code below + pass + else: + if get_config_var('Py_ENABLE_SHARED'): + cfgDict = get_config_vars() + kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']] + library_dirs += [cfgDict['LIBDIR']] + for opt in cfgDict['BLDLIBRARY'].split(): + if opt.startswith('-l'): + libraries += [opt[2:]] + else: + pass + + flags = kwargs.pop('flags', []) + needed_flags = ('-pthread',) + for flag in needed_flags: + if flag not in flags: + flags.append(flag) + + return link(obj_files, shared=True, flags=flags, cwd=cwd, cplus=cplus, fort=fort, + include_dirs=include_dirs, libraries=libraries, + library_dirs=library_dirs, extra_objs=extra_objs, **kwargs) + + +def simple_cythonize(src, destdir=None, cwd=None, **cy_kwargs): + """ Generates a C file from a Cython source file. + + Parameters + ========== + + src: str + Path to Cython source. + destdir: str (optional) + Path to output directory (default: '.'). + cwd: path string (optional) + Root of relative paths (default: '.'). + **cy_kwargs: + Second argument passed to cy_compile. Generates a .cpp file if ``cplus=True`` in ``cy_kwargs``, + else a .c file. + """ + from Cython.Compiler.Main import ( + default_options, CompilationOptions + ) + from Cython.Compiler.Main import compile as cy_compile + + assert src.lower().endswith('.pyx') or src.lower().endswith('.py') + cwd = cwd or '.' + destdir = destdir or '.' + + ext = '.cpp' if cy_kwargs.get('cplus', False) else '.c' + c_name = os.path.splitext(os.path.basename(src))[0] + ext + + dstfile = os.path.join(destdir, c_name) + + if cwd: + ori_dir = os.getcwd() + else: + ori_dir = '.' 
+ os.chdir(cwd) + try: + cy_options = CompilationOptions(default_options) + cy_options.__dict__.update(cy_kwargs) + # Set language_level if not set by cy_kwargs + # as not setting it is deprecated + if 'language_level' not in cy_kwargs: + cy_options.__dict__['language_level'] = 3 + cy_result = cy_compile([src], cy_options) + if cy_result.num_errors > 0: + raise ValueError("Cython compilation failed.") + + # Move generated C file to destination + # In macOS, the generated C file is in the same directory as the source + # but the /var is a symlink to /private/var, so we need to use realpath + if os.path.realpath(os.path.dirname(src)) != os.path.realpath(destdir): + if os.path.exists(dstfile): + os.unlink(dstfile) + shutil.move(os.path.join(os.path.dirname(src), c_name), destdir) + finally: + os.chdir(ori_dir) + return dstfile + + +extension_mapping = { + '.c': (CCompilerRunner, None), + '.cpp': (CppCompilerRunner, None), + '.cxx': (CppCompilerRunner, None), + '.f': (FortranCompilerRunner, None), + '.for': (FortranCompilerRunner, None), + '.ftn': (FortranCompilerRunner, None), + '.f90': (FortranCompilerRunner, None), # ifort only knows about .f90 + '.f95': (FortranCompilerRunner, 'f95'), + '.f03': (FortranCompilerRunner, 'f2003'), + '.f08': (FortranCompilerRunner, 'f2008'), +} + + +def src2obj(srcpath, Runner=None, objpath=None, cwd=None, inc_py=False, **kwargs): + """ Compiles a source code file to an object file. + + Files ending with '.pyx' assumed to be cython files and + are dispatched to pyx2obj. + + Parameters + ========== + + srcpath: str + Path to source file. + Runner: CompilerRunner subclass (optional) + If ``None``: deduced from extension of srcpath. + objpath : str (optional) + Path to generated object. If ``None``: deduced from ``srcpath``. + cwd: str (optional) + Working directory and root of relative paths. If ``None``: current dir. + inc_py: bool + Add Python include path to kwarg "include_dirs". Default: False + \\*\\*kwargs: dict + keyword arguments passed to Runner or pyx2obj + + """ + name, ext = os.path.splitext(os.path.basename(srcpath)) + if objpath is None: + if os.path.isabs(srcpath): + objpath = '.' + else: + objpath = os.path.dirname(srcpath) + objpath = objpath or '.' # avoid objpath == '' + + if os.path.isdir(objpath): + objpath = os.path.join(objpath, name + objext) + + include_dirs = kwargs.pop('include_dirs', []) + if inc_py: + py_inc_dir = get_path('include') + if py_inc_dir not in include_dirs: + include_dirs.append(py_inc_dir) + + if ext.lower() == '.pyx': + return pyx2obj(srcpath, objpath=objpath, include_dirs=include_dirs, cwd=cwd, + **kwargs) + + if Runner is None: + Runner, std = extension_mapping[ext.lower()] + if 'std' not in kwargs: + kwargs['std'] = std + + flags = kwargs.pop('flags', []) + needed_flags = ('-fPIC',) + for flag in needed_flags: + if flag not in flags: + flags.append(flag) + + # src2obj implies not running the linker... 
+ run_linker = kwargs.pop('run_linker', False) + if run_linker: + raise CompileError("src2obj called with run_linker=True") + + runner = Runner([srcpath], objpath, include_dirs=include_dirs, + run_linker=run_linker, cwd=cwd, flags=flags, **kwargs) + runner.run() + return objpath + + +def pyx2obj(pyxpath, objpath=None, destdir=None, cwd=None, + include_dirs=None, cy_kwargs=None, cplus=None, **kwargs): + """ + Convenience function + + If cwd is specified, pyxpath and dst are taken to be relative + If only_update is set to `True` the modification time is checked + and compilation is only run if the source is newer than the + destination + + Parameters + ========== + + pyxpath: str + Path to Cython source file. + objpath: str (optional) + Path to object file to generate. + destdir: str (optional) + Directory to put generated C file. When ``None``: directory of ``objpath``. + cwd: str (optional) + Working directory and root of relative paths. + include_dirs: iterable of path strings (optional) + Passed onto src2obj and via cy_kwargs['include_path'] + to simple_cythonize. + cy_kwargs: dict (optional) + Keyword arguments passed onto `simple_cythonize` + cplus: bool (optional) + Indicate whether C++ is used. default: auto-detect using ``.util.pyx_is_cplus``. + compile_kwargs: dict + keyword arguments passed onto src2obj + + Returns + ======= + + Absolute path of generated object file. + + """ + assert pyxpath.endswith('.pyx') + cwd = cwd or '.' + objpath = objpath or '.' + destdir = destdir or os.path.dirname(objpath) + + abs_objpath = get_abspath(objpath, cwd=cwd) + + if os.path.isdir(abs_objpath): + pyx_fname = os.path.basename(pyxpath) + name, ext = os.path.splitext(pyx_fname) + objpath = os.path.join(objpath, name + objext) + + cy_kwargs = cy_kwargs or {} + cy_kwargs['output_dir'] = cwd + if cplus is None: + cplus = pyx_is_cplus(pyxpath) + cy_kwargs['cplus'] = cplus + + interm_c_file = simple_cythonize(pyxpath, destdir=destdir, cwd=cwd, **cy_kwargs) + + include_dirs = include_dirs or [] + flags = kwargs.pop('flags', []) + needed_flags = ('-fwrapv', '-pthread', '-fPIC') + for flag in needed_flags: + if flag not in flags: + flags.append(flag) + + options = kwargs.pop('options', []) + + if kwargs.pop('strict_aliasing', False): + raise CompileError("Cython requires strict aliasing to be disabled.") + + # Let's be explicit about standard + if cplus: + std = kwargs.pop('std', 'c++98') + else: + std = kwargs.pop('std', 'c99') + + return src2obj(interm_c_file, objpath=objpath, cwd=cwd, + include_dirs=include_dirs, flags=flags, std=std, + options=options, inc_py=True, strict_aliasing=False, + **kwargs) + + +def _any_X(srcs, cls): + for src in srcs: + name, ext = os.path.splitext(src) + key = ext.lower() + if key in extension_mapping: + if extension_mapping[key][0] == cls: + return True + return False + + +def any_fortran_src(srcs): + return _any_X(srcs, FortranCompilerRunner) + + +def any_cplus_src(srcs): + return _any_X(srcs, CppCompilerRunner) + + +def compile_link_import_py_ext(sources, extname=None, build_dir='.', compile_kwargs=None, + link_kwargs=None, extra_objs=None): + """ Compiles sources to a shared object (Python extension) and imports it + + Sources in ``sources`` which is imported. If shared object is newer than the sources, they + are not recompiled but instead it is imported. + + Parameters + ========== + + sources : list of strings + List of paths to sources. + extname : string + Name of extension (default: ``None``). 
+ If ``None``: taken from the last file in ``sources`` without extension. + build_dir: str + Path to directory in which objects files etc. are generated. + compile_kwargs: dict + keyword arguments passed to ``compile_sources`` + link_kwargs: dict + keyword arguments passed to ``link_py_so`` + extra_objs: list + List of paths to (prebuilt) object files / static libraries to link against. + + Returns + ======= + + The imported module from of the Python extension. + """ + if extname is None: + extname = os.path.splitext(os.path.basename(sources[-1]))[0] + + compile_kwargs = compile_kwargs or {} + link_kwargs = link_kwargs or {} + + try: + mod = import_module_from_file(os.path.join(build_dir, extname), sources) + except ImportError: + objs = compile_sources(list(map(get_abspath, sources)), destdir=build_dir, + cwd=build_dir, **compile_kwargs) + so = link_py_so(objs, cwd=build_dir, fort=any_fortran_src(sources), + cplus=any_cplus_src(sources), extra_objs=extra_objs, **link_kwargs) + mod = import_module_from_file(so) + return mod + + +def _write_sources_to_build_dir(sources, build_dir): + build_dir = build_dir or tempfile.mkdtemp() + if not os.path.isdir(build_dir): + raise OSError("Non-existent directory: ", build_dir) + + source_files = [] + for name, src in sources: + dest = os.path.join(build_dir, name) + differs = True + sha256_in_mem = sha256_of_string(src.encode('utf-8')).hexdigest() + if os.path.exists(dest): + if os.path.exists(dest + '.sha256'): + sha256_on_disk = Path(dest + '.sha256').read_text() + else: + sha256_on_disk = sha256_of_file(dest).hexdigest() + + differs = sha256_on_disk != sha256_in_mem + if differs: + with open(dest, 'wt') as fh: + fh.write(src) + with open(dest + '.sha256', 'wt') as fh: + fh.write(sha256_in_mem) + source_files.append(dest) + return source_files, build_dir + + +def compile_link_import_strings(sources, build_dir=None, **kwargs): + """ Compiles, links and imports extension module from source. + + Parameters + ========== + + sources : iterable of name/source pair tuples + build_dir : string (default: None) + Path. ``None`` implies use a temporary directory. + **kwargs: + Keyword arguments passed onto `compile_link_import_py_ext`. + + Returns + ======= + + mod : module + The compiled and imported extension module. + info : dict + Containing ``build_dir`` as 'build_dir'. + + """ + source_files, build_dir = _write_sources_to_build_dir(sources, build_dir) + mod = compile_link_import_py_ext(source_files, build_dir=build_dir, **kwargs) + info = {"build_dir": build_dir} + return mod, info + + +def compile_run_strings(sources, build_dir=None, clean=False, compile_kwargs=None, link_kwargs=None): + """ Compiles, links and runs a program built from sources. + + Parameters + ========== + + sources : iterable of name/source pair tuples + build_dir : string (default: None) + Path. ``None`` implies use a temporary directory. + clean : bool + Whether to remove build_dir after use. This will only have an + effect if ``build_dir`` is ``None`` (which creates a temporary directory). + Passing ``clean == True`` and ``build_dir != None`` raises a ``ValueError``. + This will also set ``build_dir`` in returned info dictionary to ``None``. 
+ compile_kwargs: dict + Keyword arguments passed onto ``compile_sources`` + link_kwargs: dict + Keyword arguments passed onto ``link`` + + Returns + ======= + + (stdout, stderr): pair of strings + info: dict + Containing exit status as 'exit_status' and ``build_dir`` as 'build_dir' + + """ + if clean and build_dir is not None: + raise ValueError("Automatic removal of build_dir is only available for temporary directory.") + try: + source_files, build_dir = _write_sources_to_build_dir(sources, build_dir) + objs = compile_sources(list(map(get_abspath, source_files)), destdir=build_dir, + cwd=build_dir, **(compile_kwargs or {})) + prog = link(objs, cwd=build_dir, + fort=any_fortran_src(source_files), + cplus=any_cplus_src(source_files), **(link_kwargs or {})) + p = subprocess.Popen([prog], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + exit_status = p.wait() + stdout, stderr = [txt.decode('utf-8') for txt in p.communicate()] + finally: + if clean and os.path.isdir(build_dir): + shutil.rmtree(build_dir) + build_dir = None + info = {"exit_status": exit_status, "build_dir": build_dir} + return (stdout, stderr), info diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/runners.py b/phivenv/Lib/site-packages/sympy/utilities/_compilation/runners.py new file mode 100644 index 0000000000000000000000000000000000000000..1f37d6cf8ac47807da7f3f00dfc5cd847c03fa8d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/_compilation/runners.py @@ -0,0 +1,301 @@ +from __future__ import annotations +from typing import Callable, Optional + +from collections import OrderedDict +import os +import re +import subprocess +import warnings + +from .util import ( + find_binary_of_command, unique_list, CompileError +) + + +class CompilerRunner: + """ CompilerRunner base class. + + Parameters + ========== + + sources : list of str + Paths to sources. + out : str + flags : iterable of str + Compiler flags. + run_linker : bool + compiler_name_exe : (str, str) tuple + Tuple of compiler name & command to call. + cwd : str + Path of root of relative paths. + include_dirs : list of str + Include directories. + libraries : list of str + Libraries to link against. + library_dirs : list of str + Paths to search for shared libraries. + std : str + Standard string, e.g. ``'c++11'``, ``'c99'``, ``'f2003'``. + define: iterable of strings + macros to define + undef : iterable of strings + macros to undefine + preferred_vendor : string + name of preferred vendor e.g. 'gnu' or 'intel' + + Methods + ======= + + run(): + Invoke compilation as a subprocess. + + """ + + environ_key_compiler: str # e.g. 'CC', 'CXX', ... + environ_key_flags: str # e.g. 'CFLAGS', 'CXXFLAGS', ... + environ_key_ldflags: str = "LDFLAGS" # typically 'LDFLAGS' + + # Subclass to vendor/binary dict + compiler_dict: dict[str, str] + + # Standards should be a tuple of supported standards + # (first one will be the default) + standards: tuple[None | str, ...] + + # Subclass to dict of binary/formater-callback + std_formater: dict[str, Callable[[Optional[str]], str]] + + # subclass to be e.g. 
{'gcc': 'gnu', ...} + compiler_name_vendor_mapping: dict[str, str] + + def __init__(self, sources, out, flags=None, run_linker=True, compiler=None, cwd='.', + include_dirs=None, libraries=None, library_dirs=None, std=None, define=None, + undef=None, strict_aliasing=None, preferred_vendor=None, linkline=None, **kwargs): + if isinstance(sources, str): + raise ValueError("Expected argument sources to be a list of strings.") + self.sources = list(sources) + self.out = out + self.flags = flags or [] + if os.environ.get(self.environ_key_flags): + self.flags += os.environ[self.environ_key_flags].split() + self.cwd = cwd + if compiler: + self.compiler_name, self.compiler_binary = compiler + elif os.environ.get(self.environ_key_compiler): + self.compiler_binary = os.environ[self.environ_key_compiler] + for k, v in self.compiler_dict.items(): + if k in self.compiler_binary: + self.compiler_vendor = k + self.compiler_name = v + break + else: + self.compiler_vendor, self.compiler_name = list(self.compiler_dict.items())[0] + warnings.warn("failed to determine what kind of compiler %s is, assuming %s" % + (self.compiler_binary, self.compiler_name)) + else: + # Find a compiler + if preferred_vendor is None: + preferred_vendor = os.environ.get('SYMPY_COMPILER_VENDOR', None) + self.compiler_name, self.compiler_binary, self.compiler_vendor = self.find_compiler(preferred_vendor) + if self.compiler_binary is None: + raise ValueError("No compiler found (searched: {})".format(', '.join(self.compiler_dict.values()))) + self.define = define or [] + self.undef = undef or [] + self.include_dirs = include_dirs or [] + self.libraries = libraries or [] + self.library_dirs = library_dirs or [] + self.std = std or self.standards[0] + self.run_linker = run_linker + if self.run_linker: + # both gnu and intel compilers use '-c' for disabling linker + self.flags = list(filter(lambda x: x != '-c', self.flags)) + else: + if '-c' not in self.flags: + self.flags.append('-c') + + if self.std: + self.flags.append(self.std_formater[ + self.compiler_name](self.std)) + + self.linkline = (linkline or []) + [lf for lf in map( + str.strip, os.environ.get(self.environ_key_ldflags, "").split() + ) if lf != ""] + + if strict_aliasing is not None: + nsa_re = re.compile("no-strict-aliasing$") + sa_re = re.compile("strict-aliasing$") + if strict_aliasing is True: + if any(map(nsa_re.match, flags)): + raise CompileError("Strict aliasing cannot be both enforced and disabled") + elif any(map(sa_re.match, flags)): + pass # already enforced + else: + flags.append('-fstrict-aliasing') + elif strict_aliasing is False: + if any(map(nsa_re.match, flags)): + pass # already disabled + else: + if any(map(sa_re.match, flags)): + raise CompileError("Strict aliasing cannot be both enforced and disabled") + else: + flags.append('-fno-strict-aliasing') + else: + msg = "Expected argument strict_aliasing to be True/False, got {}" + raise ValueError(msg.format(strict_aliasing)) + + @classmethod + def find_compiler(cls, preferred_vendor=None): + """ Identify a suitable C/fortran/other compiler. """ + candidates = list(cls.compiler_dict.keys()) + if preferred_vendor: + if preferred_vendor in candidates: + candidates = [preferred_vendor]+candidates + else: + raise ValueError("Unknown vendor {}".format(preferred_vendor)) + name, path = find_binary_of_command([cls.compiler_dict[x] for x in candidates]) + return name, path, cls.compiler_name_vendor_mapping[name] + + def cmd(self): + """ List of arguments (str) to be passed to e.g. ``subprocess.Popen``. 
""" + cmd = ( + [self.compiler_binary] + + self.flags + + ['-U'+x for x in self.undef] + + ['-D'+x for x in self.define] + + ['-I'+x for x in self.include_dirs] + + self.sources + ) + if self.run_linker: + cmd += (['-L'+x for x in self.library_dirs] + + ['-l'+x for x in self.libraries] + + self.linkline) + counted = [] + for envvar in re.findall(r'\$\{(\w+)\}', ' '.join(cmd)): + if os.getenv(envvar) is None: + if envvar not in counted: + counted.append(envvar) + msg = "Environment variable '{}' undefined.".format(envvar) + raise CompileError(msg) + return cmd + + def run(self): + self.flags = unique_list(self.flags) + + # Append output flag and name to tail of flags + self.flags.extend(['-o', self.out]) + env = os.environ.copy() + env['PWD'] = self.cwd + + # NOTE: intel compilers seems to need shell=True + p = subprocess.Popen(' '.join(self.cmd()), + shell=True, + cwd=self.cwd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env) + comm = p.communicate() + try: + self.cmd_outerr = comm[0].decode('utf-8') + except UnicodeDecodeError: + self.cmd_outerr = comm[0].decode('iso-8859-1') # win32 + self.cmd_returncode = p.returncode + + # Error handling + if self.cmd_returncode != 0: + msg = "Error executing '{}' in {} (exited status {}):\n {}\n".format( + ' '.join(self.cmd()), self.cwd, str(self.cmd_returncode), self.cmd_outerr + ) + raise CompileError(msg) + + return self.cmd_outerr, self.cmd_returncode + + +class CCompilerRunner(CompilerRunner): + + environ_key_compiler = 'CC' + environ_key_flags = 'CFLAGS' + + compiler_dict = OrderedDict([ + ('gnu', 'gcc'), + ('intel', 'icc'), + ('llvm', 'clang'), + ]) + + standards = ('c89', 'c90', 'c99', 'c11') # First is default + + std_formater = { + 'gcc': '-std={}'.format, + 'icc': '-std={}'.format, + 'clang': '-std={}'.format, + } + + compiler_name_vendor_mapping = { + 'gcc': 'gnu', + 'icc': 'intel', + 'clang': 'llvm' + } + + +def _mk_flag_filter(cmplr_name): # helper for class initialization + not_welcome = {'g++': ("Wimplicit-interface",)} # "Wstrict-prototypes",)} + if cmplr_name in not_welcome: + def fltr(x): + for nw in not_welcome[cmplr_name]: + if nw in x: + return False + return True + else: + def fltr(x): + return True + return fltr + + +class CppCompilerRunner(CompilerRunner): + + environ_key_compiler = 'CXX' + environ_key_flags = 'CXXFLAGS' + + compiler_dict = OrderedDict([ + ('gnu', 'g++'), + ('intel', 'icpc'), + ('llvm', 'clang++'), + ]) + + # First is the default, c++0x == c++11 + standards = ('c++98', 'c++0x') + + std_formater = { + 'g++': '-std={}'.format, + 'icpc': '-std={}'.format, + 'clang++': '-std={}'.format, + } + + compiler_name_vendor_mapping = { + 'g++': 'gnu', + 'icpc': 'intel', + 'clang++': 'llvm' + } + + +class FortranCompilerRunner(CompilerRunner): + + environ_key_compiler = 'FC' + environ_key_flags = 'FFLAGS' + + standards = (None, 'f77', 'f95', 'f2003', 'f2008') + + std_formater = { + 'gfortran': lambda x: '-std=gnu' if x is None else '-std=legacy' if x == 'f77' else '-std={}'.format(x), + 'ifort': lambda x: '-stand f08' if x is None else '-stand f{}'.format(x[-2:]), # f2008 => f08 + } + + compiler_dict = OrderedDict([ + ('gnu', 'gfortran'), + ('intel', 'ifort'), + ]) + + compiler_name_vendor_mapping = { + 'gfortran': 'gnu', + 'ifort': 'intel', + } diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__init__.py b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e7ce7806e1ec059ff76777f970ec08674ec2d5d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0aad763779236d53421d214d04d352ed796f733 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/test_compilation.py b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/test_compilation.py new file mode 100644 index 0000000000000000000000000000000000000000..ff7cf86644a9645665e62b49cfc4e7ea73b2c1ab --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/_compilation/tests/test_compilation.py @@ -0,0 +1,104 @@ +import shutil +import os +import subprocess +import tempfile +from sympy.external import import_module +from sympy.testing.pytest import skip, skip_under_pyodide + +from sympy.utilities._compilation.compilation import compile_link_import_py_ext, compile_link_import_strings, compile_sources, get_abspath + +numpy = import_module('numpy') +cython = import_module('cython') + +_sources1 = [ + ('sigmoid.c', r""" +#include + +void sigmoid(int n, const double * const restrict in, + double * const restrict out, double lim){ + for (int i=0; i 0: + if not os.path.exists(parent): + make_dirs(parent) + + if not os.path.exists(path): + os.mkdir(path, 0o777) + else: + assert os.path.isdir(path) + +def missing_or_other_newer(path, other_path, cwd=None): + """ + Investigate if path is non-existent or older than provided reference + path. + + Parameters + ========== + path: string + path to path which might be missing or too old + other_path: string + reference path + cwd: string + working directory (root of relative paths) + + Returns + ======= + True if path is older or missing. + """ + cwd = cwd or '.' + path = get_abspath(path, cwd=cwd) + other_path = get_abspath(other_path, cwd=cwd) + if not os.path.exists(path): + return True + if os.path.getmtime(other_path) - 1e-6 >= os.path.getmtime(path): + # 1e-6 is needed because http://stackoverflow.com/questions/17086426/ + return True + return False + +def copy(src, dst, only_update=False, copystat=True, cwd=None, + dest_is_dir=False, create_dest_dirs=False): + """ Variation of ``shutil.copy`` with extra options. + + Parameters + ========== + + src : str + Path to source file. + dst : str + Path to destination. + only_update : bool + Only copy if source is newer than destination + (returns None if it was newer), default: ``False``. + copystat : bool + See ``shutil.copystat``. default: ``True``. + cwd : str + Path to working directory (root of relative paths). + dest_is_dir : bool + Ensures that dst is treated as a directory. default: ``False`` + create_dest_dirs : bool + Creates directories if needed. + + Returns + ======= + + Path to the copied file. 
+ + """ + if cwd: # Handle working directory + if not os.path.isabs(src): + src = os.path.join(cwd, src) + if not os.path.isabs(dst): + dst = os.path.join(cwd, dst) + + if not os.path.exists(src): # Make sure source file exists + raise FileNotFoundError("Source: `{}` does not exist".format(src)) + + # We accept both (re)naming destination file _or_ + # passing a (possible non-existent) destination directory + if dest_is_dir: + if not dst[-1] == '/': + dst = dst+'/' + else: + if os.path.exists(dst) and os.path.isdir(dst): + dest_is_dir = True + + if dest_is_dir: + dest_dir = dst + dest_fname = os.path.basename(src) + dst = os.path.join(dest_dir, dest_fname) + else: + dest_dir = os.path.dirname(dst) + + if not os.path.exists(dest_dir): + if create_dest_dirs: + make_dirs(dest_dir) + else: + raise FileNotFoundError("You must create directory first.") + + if only_update: + if not missing_or_other_newer(dst, src): + return + + if os.path.islink(dst): + dst = os.path.abspath(os.path.realpath(dst), cwd=cwd) + + shutil.copy(src, dst) + if copystat: + shutil.copystat(src, dst) + + return dst + +Glob = namedtuple('Glob', 'pathname') +ArbitraryDepthGlob = namedtuple('ArbitraryDepthGlob', 'filename') + +def glob_at_depth(filename_glob, cwd=None): + if cwd is not None: + cwd = '.' + globbed = [] + for root, dirs, filenames in os.walk(cwd): + for fn in filenames: + # This is not tested: + if fnmatch.fnmatch(fn, filename_glob): + globbed.append(os.path.join(root, fn)) + return globbed + +def sha256_of_file(path, nblocks=128): + """ Computes the SHA256 hash of a file. + + Parameters + ========== + + path : string + Path to file to compute hash of. + nblocks : int + Number of blocks to read per iteration. + + Returns + ======= + + hashlib sha256 hash object. Use ``.digest()`` or ``.hexdigest()`` + on returned object to get binary or hex encoded string. + """ + sh = sha256() + with open(path, 'rb') as f: + for chunk in iter(lambda: f.read(nblocks*sh.block_size), b''): + sh.update(chunk) + return sh + + +def sha256_of_string(string): + """ Computes the SHA256 hash of a string. """ + sh = sha256() + sh.update(string) + return sh + + +def pyx_is_cplus(path): + """ + Inspect a Cython source file (.pyx) and look for comment line like: + + # distutils: language = c++ + + Returns True if such a file is present in the file, else False. + """ + with open(path) as fh: + for line in fh: + if line.startswith('#') and '=' in line: + splitted = line.split('=') + if len(splitted) != 2: + continue + lhs, rhs = splitted + if lhs.strip().split()[-1].lower() == 'language' and \ + rhs.strip().split()[0].lower() == 'c++': + return True + return False + +def import_module_from_file(filename, only_if_newer_than=None): + """ Imports Python extension (from shared object file) + + Provide a list of paths in `only_if_newer_than` to check + timestamps of dependencies. import_ raises an ImportError + if any is newer. + + Word of warning: The OS may cache shared objects which makes + reimporting same path of an shared object file very problematic. + + It will not detect the new time stamp, nor new checksum, but will + instead silently use old module. Use unique names for this reason. + + Parameters + ========== + + filename : str + Path to shared object. + only_if_newer_than : iterable of strings + Paths to dependencies of the shared object. + + Raises + ====== + + ``ImportError`` if any of the files specified in ``only_if_newer_than`` are newer + than the file given by filename. 
+ """ + path, name = os.path.split(filename) + name, ext = os.path.splitext(name) + name = name.split('.')[0] + if sys.version_info[0] == 2: + from imp import find_module, load_module + fobj, filename, data = find_module(name, [path]) + if only_if_newer_than: + for dep in only_if_newer_than: + if os.path.getmtime(filename) < os.path.getmtime(dep): + raise ImportError("{} is newer than {}".format(dep, filename)) + mod = load_module(name, fobj, filename, data) + else: + import importlib.util + spec = importlib.util.spec_from_file_location(name, filename) + if spec is None: + raise ImportError("Failed to import: '%s'" % filename) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def find_binary_of_command(candidates): + """ Finds binary first matching name among candidates. + + Calls ``which`` from shutils for provided candidates and returns + first hit. + + Parameters + ========== + + candidates : iterable of str + Names of candidate commands + + Raises + ====== + + CompilerNotFoundError if no candidates match. + """ + from shutil import which + for c in candidates: + binary_path = which(c) + if c and binary_path: + return c, binary_path + + raise CompilerNotFoundError('No binary located for candidates: {}'.format(candidates)) + + +def unique_list(l): + """ Uniquify a list (skip duplicate items). """ + result = [] + for x in l: + if x not in result: + result.append(x) + return result diff --git a/phivenv/Lib/site-packages/sympy/utilities/autowrap.py b/phivenv/Lib/site-packages/sympy/utilities/autowrap.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c33b2f0f72f89b5680f910929618a821e924c3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/autowrap.py @@ -0,0 +1,1178 @@ +"""Module for compiling codegen output, and wrap the binary for use in +python. + +.. note:: To use the autowrap module it must first be imported + + >>> from sympy.utilities.autowrap import autowrap + +This module provides a common interface for different external backends, such +as f2py, fwrap, Cython, SWIG(?) etc. (Currently only f2py and Cython are +implemented) The goal is to provide access to compiled binaries of acceptable +performance with a one-button user interface, e.g., + + >>> from sympy.abc import x,y + >>> expr = (x - y)**25 + >>> flat = expr.expand() + >>> binary_callable = autowrap(flat) + >>> binary_callable(2, 3) + -1.0 + +Although a SymPy user might primarily be interested in working with +mathematical expressions and not in the details of wrapping tools +needed to evaluate such expressions efficiently in numerical form, +the user cannot do so without some understanding of the +limits in the target language. For example, the expanded expression +contains large coefficients which result in loss of precision when +computing the expression: + + >>> binary_callable(3, 2) + 0.0 + >>> binary_callable(4, 5), binary_callable(5, 4) + (-22925376.0, 25165824.0) + +Wrapping the unexpanded expression gives the expected behavior: + + >>> e = autowrap(expr) + >>> e(4, 5), e(5, 4) + (-1.0, 1.0) + +The callable returned from autowrap() is a binary Python function, not a +SymPy object. If it is desired to use the compiled function in symbolic +expressions, it is better to use binary_function() which returns a SymPy +Function object. The binary callable is attached as the _imp_ attribute and +invoked when a numerical evaluation is requested with evalf(), or with +lambdify(). 
+ + >>> from sympy.utilities.autowrap import binary_function + >>> f = binary_function('f', expr) + >>> 2*f(x, y) + y + y + 2*f(x, y) + >>> (2*f(x, y) + y).evalf(2, subs={x: 1, y:2}) + 0.e-110 + +When is this useful? + + 1) For computations on large arrays, Python iterations may be too slow, + and depending on the mathematical expression, it may be difficult to + exploit the advanced index operations provided by NumPy. + + 2) For *really* long expressions that will be called repeatedly, the + compiled binary should be significantly faster than SymPy's .evalf() + + 3) If you are generating code with the codegen utility in order to use + it in another project, the automatic Python wrappers let you test the + binaries immediately from within SymPy. + + 4) To create customized ufuncs for use with numpy arrays. + See *ufuncify*. + +When is this module NOT the best approach? + + 1) If you are really concerned about speed or memory optimizations, + you will probably get better results by working directly with the + wrapper tools and the low level code. However, the files generated + by this utility may provide a useful starting point and reference + code. Temporary files will be left intact if you supply the keyword + tempdir="path/to/files/". + + 2) If the array computation can be handled easily by numpy, and you + do not need the binaries for another project. + +""" + +import sys +import os +import shutil +import tempfile +from pathlib import Path +from subprocess import STDOUT, CalledProcessError, check_output +from string import Template +from warnings import warn + +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy, Symbol +from sympy.tensor.indexed import Idx, IndexedBase +from sympy.utilities.codegen import (make_routine, get_code_generator, + OutputArgument, InOutArgument, + InputArgument, CodeGenArgumentListError, + Result, ResultBase, C99CodeGen) +from sympy.utilities.iterables import iterable +from sympy.utilities.lambdify import implemented_function +from sympy.utilities.decorator import doctest_depends_on + +_doctest_depends_on = {'exe': ('f2py', 'gfortran', 'gcc'), + 'modules': ('numpy',)} + + +class CodeWrapError(Exception): + pass + + +class CodeWrapper: + """Base Class for code wrappers""" + _filename = "wrapped_code" + _module_basename = "wrapper_module" + _module_counter = 0 + + @property + def filename(self): + return "%s_%s" % (self._filename, CodeWrapper._module_counter) + + @property + def module_name(self): + return "%s_%s" % (self._module_basename, CodeWrapper._module_counter) + + def __init__(self, generator, filepath=None, flags=[], verbose=False): + """ + generator -- the code generator to use + """ + self.generator = generator + self.filepath = filepath + self.flags = flags + self.quiet = not verbose + + @property + def include_header(self): + return bool(self.filepath) + + @property + def include_empty(self): + return bool(self.filepath) + + def _generate_code(self, main_routine, routines): + routines.append(main_routine) + self.generator.write( + routines, self.filename, True, self.include_header, + self.include_empty) + + def wrap_code(self, routine, helpers=None): + helpers = helpers or [] + if self.filepath: + workdir = os.path.abspath(self.filepath) + else: + workdir = tempfile.mkdtemp("_sympy_compile") + if not os.access(workdir, os.F_OK): + os.mkdir(workdir) + oldwork = os.getcwd() + os.chdir(workdir) + try: + sys.path.append(workdir) + 
self._generate_code(routine, helpers) + self._prepare_files(routine) + self._process_files(routine) + mod = __import__(self.module_name) + finally: + sys.path.remove(workdir) + CodeWrapper._module_counter += 1 + os.chdir(oldwork) + if not self.filepath: + try: + shutil.rmtree(workdir) + except OSError: + # Could be some issues on Windows + pass + + return self._get_wrapped_function(mod, routine.name) + + def _process_files(self, routine): + command = self.command + command.extend(self.flags) + try: + retoutput = check_output(command, stderr=STDOUT) + except CalledProcessError as e: + raise CodeWrapError( + "Error while executing command: %s. Command output is:\n%s" % ( + " ".join(command), e.output.decode('utf-8'))) + if not self.quiet: + print(retoutput) + + +class DummyWrapper(CodeWrapper): + """Class used for testing independent of backends """ + + template = """# dummy module for testing of SymPy +def %(name)s(): + return "%(expr)s" +%(name)s.args = "%(args)s" +%(name)s.returns = "%(retvals)s" +""" + + def _prepare_files(self, routine): + return + + def _generate_code(self, routine, helpers): + with open('%s.py' % self.module_name, 'w') as f: + printed = ", ".join( + [str(res.expr) for res in routine.result_variables]) + # convert OutputArguments to return value like f2py + args = filter(lambda x: not isinstance( + x, OutputArgument), routine.arguments) + retvals = [] + for val in routine.result_variables: + if isinstance(val, Result): + retvals.append('nameless') + else: + retvals.append(val.result_var) + + print(DummyWrapper.template % { + 'name': routine.name, + 'expr': printed, + 'args': ", ".join([str(a.name) for a in args]), + 'retvals': ", ".join([str(val) for val in retvals]) + }, end="", file=f) + + def _process_files(self, routine): + return + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name) + + +class CythonCodeWrapper(CodeWrapper): + """Wrapper that uses Cython""" + + setup_template = """\ +from setuptools import setup +from setuptools import Extension +from Cython.Build import cythonize +cy_opts = {cythonize_options} +{np_import} +ext_mods = [Extension( + {ext_args}, + include_dirs={include_dirs}, + library_dirs={library_dirs}, + libraries={libraries}, + extra_compile_args={extra_compile_args}, + extra_link_args={extra_link_args} +)] +setup(ext_modules=cythonize(ext_mods, **cy_opts)) +""" + + _cythonize_options = {'compiler_directives':{'language_level' : "3"}} + + pyx_imports = ( + "import numpy as np\n" + "cimport numpy as np\n\n") + + pyx_header = ( + "cdef extern from '{header_file}.h':\n" + " {prototype}\n\n") + + pyx_func = ( + "def {name}_c({arg_string}):\n" + "\n" + "{declarations}" + "{body}") + + std_compile_flag = '-std=c99' + + def __init__(self, *args, **kwargs): + """Instantiates a Cython code wrapper. + + The following optional parameters get passed to ``setuptools.Extension`` + for building the Python extension module. Read its documentation to + learn more. + + Parameters + ========== + include_dirs : [list of strings] + A list of directories to search for C/C++ header files (in Unix + form for portability). + library_dirs : [list of strings] + A list of directories to search for C/C++ libraries at link time. + libraries : [list of strings] + A list of library names (not filenames or paths) to link against. + extra_compile_args : [list of strings] + Any extra platform- and compiler-specific information to use when + compiling the source files in 'sources'. 
For platforms and + compilers where "command line" makes sense, this is typically a + list of command-line arguments, but for other platforms it could be + anything. Note that the attribute ``std_compile_flag`` will be + appended to this list. + extra_link_args : [list of strings] + Any extra platform- and compiler-specific information to use when + linking object files together to create the extension (or to create + a new static Python interpreter). Similar interpretation as for + 'extra_compile_args'. + cythonize_options : [dictionary] + Keyword arguments passed on to cythonize. + + """ + + self._include_dirs = kwargs.pop('include_dirs', []) + self._library_dirs = kwargs.pop('library_dirs', []) + self._libraries = kwargs.pop('libraries', []) + self._extra_compile_args = kwargs.pop('extra_compile_args', []) + self._extra_compile_args.append(self.std_compile_flag) + self._extra_link_args = kwargs.pop('extra_link_args', []) + self._cythonize_options = kwargs.pop('cythonize_options', self._cythonize_options) + + self._need_numpy = False + + super().__init__(*args, **kwargs) + + @property + def command(self): + command = [sys.executable, "setup.py", "build_ext", "--inplace"] + return command + + def _prepare_files(self, routine, build_dir=os.curdir): + # NOTE : build_dir is used for testing purposes. + pyxfilename = self.module_name + '.pyx' + codefilename = "%s.%s" % (self.filename, self.generator.code_extension) + + # pyx + with open(os.path.join(build_dir, pyxfilename), 'w') as f: + self.dump_pyx([routine], f, self.filename) + + # setup.py + ext_args = [repr(self.module_name), repr([pyxfilename, codefilename])] + if self._need_numpy: + np_import = 'import numpy as np\n' + self._include_dirs.append('np.get_include()') + else: + np_import = '' + + includes = str(self._include_dirs).replace("'np.get_include()'", + 'np.get_include()') + code = self.setup_template.format( + ext_args=", ".join(ext_args), + np_import=np_import, + include_dirs=includes, + library_dirs=self._library_dirs, + libraries=self._libraries, + extra_compile_args=self._extra_compile_args, + extra_link_args=self._extra_link_args, + cythonize_options=self._cythonize_options) + Path(os.path.join(build_dir, 'setup.py')).write_text(code) + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name + '_c') + + def dump_pyx(self, routines, f, prefix): + """Write a Cython file with Python wrappers + + This file contains all the definitions of the routines in c code and + refers to the header file. + + Arguments + --------- + routines + List of Routine instances + f + File-like object to write the file to + prefix + The filename prefix, used to refer to the proper header file. + Only the basename of the prefix is used. 
+ """ + headers = [] + functions = [] + for routine in routines: + prototype = self.generator.get_prototype(routine) + + # C Function Header Import + headers.append(self.pyx_header.format(header_file=prefix, + prototype=prototype)) + + # Partition the C function arguments into categories + py_rets, py_args, py_loc, py_inf = self._partition_args(routine.arguments) + + # Function prototype + name = routine.name + arg_string = ", ".join(self._prototype_arg(arg) for arg in py_args) + + # Local Declarations + local_decs = [] + for arg, val in py_inf.items(): + proto = self._prototype_arg(arg) + mat, ind = [self._string_var(v) for v in val] + local_decs.append(" cdef {} = {}.shape[{}]".format(proto, mat, ind)) + local_decs.extend([" cdef {}".format(self._declare_arg(a)) for a in py_loc]) + declarations = "\n".join(local_decs) + if declarations: + declarations = declarations + "\n" + + # Function Body + args_c = ", ".join([self._call_arg(a) for a in routine.arguments]) + rets = ", ".join([self._string_var(r.name) for r in py_rets]) + if routine.results: + body = ' return %s(%s)' % (routine.name, args_c) + if rets: + body = body + ', ' + rets + else: + body = ' %s(%s)\n' % (routine.name, args_c) + body = body + ' return ' + rets + + functions.append(self.pyx_func.format(name=name, arg_string=arg_string, + declarations=declarations, body=body)) + + # Write text to file + if self._need_numpy: + # Only import numpy if required + f.write(self.pyx_imports) + f.write('\n'.join(headers)) + f.write('\n'.join(functions)) + + def _partition_args(self, args): + """Group function arguments into categories.""" + py_args = [] + py_returns = [] + py_locals = [] + py_inferred = {} + for arg in args: + if isinstance(arg, OutputArgument): + py_returns.append(arg) + py_locals.append(arg) + elif isinstance(arg, InOutArgument): + py_returns.append(arg) + py_args.append(arg) + else: + py_args.append(arg) + # Find arguments that are array dimensions. These can be inferred + # locally in the Cython code. 
+ if isinstance(arg, (InputArgument, InOutArgument)) and arg.dimensions: + dims = [d[1] + 1 for d in arg.dimensions] + sym_dims = [(i, d) for (i, d) in enumerate(dims) if + isinstance(d, Symbol)] + for (i, d) in sym_dims: + py_inferred[d] = (arg.name, i) + for arg in args: + if arg.name in py_inferred: + py_inferred[arg] = py_inferred.pop(arg.name) + # Filter inferred arguments from py_args + py_args = [a for a in py_args if a not in py_inferred] + return py_returns, py_args, py_locals, py_inferred + + def _prototype_arg(self, arg): + mat_dec = "np.ndarray[{mtype}, ndim={ndim}] {name}" + np_types = {'double': 'np.double_t', + 'int': 'np.int_t'} + t = arg.get_datatype('c') + if arg.dimensions: + self._need_numpy = True + ndim = len(arg.dimensions) + mtype = np_types[t] + return mat_dec.format(mtype=mtype, ndim=ndim, name=self._string_var(arg.name)) + else: + return "%s %s" % (t, self._string_var(arg.name)) + + def _declare_arg(self, arg): + proto = self._prototype_arg(arg) + if arg.dimensions: + shape = '(' + ','.join(self._string_var(i[1] + 1) for i in arg.dimensions) + ')' + return proto + " = np.empty({shape})".format(shape=shape) + else: + return proto + " = 0" + + def _call_arg(self, arg): + if arg.dimensions: + t = arg.get_datatype('c') + return "<{}*> {}.data".format(t, self._string_var(arg.name)) + elif isinstance(arg, ResultBase): + return "&{}".format(self._string_var(arg.name)) + else: + return self._string_var(arg.name) + + def _string_var(self, var): + printer = self.generator.printer.doprint + return printer(var) + + +class F2PyCodeWrapper(CodeWrapper): + """Wrapper that uses f2py""" + + def __init__(self, *args, **kwargs): + + ext_keys = ['include_dirs', 'library_dirs', 'libraries', + 'extra_compile_args', 'extra_link_args'] + msg = ('The compilation option kwarg {} is not supported with the f2py ' + 'backend.') + + for k in ext_keys: + if k in kwargs.keys(): + warn(msg.format(k)) + kwargs.pop(k, None) + + super().__init__(*args, **kwargs) + + @property + def command(self): + filename = self.filename + '.' + self.generator.code_extension + args = ['-c', '-m', self.module_name, filename] + command = [sys.executable, "-c", "import numpy.f2py as f2py2e;f2py2e.main()"]+args + return command + + def _prepare_files(self, routine): + pass + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name) + + +# Here we define a lookup of backends -> tuples of languages. For now, each +# tuple is of length 1, but if a backend supports more than one language, +# the most preferable language is listed first. 
+_lang_lookup = {'CYTHON': ('C99', 'C89', 'C'), + 'F2PY': ('F95',), + 'NUMPY': ('C99', 'C89', 'C'), + 'DUMMY': ('F95',)} # Dummy here just for testing + + +def _infer_language(backend): + """For a given backend, return the top choice of language""" + langs = _lang_lookup.get(backend.upper(), False) + if not langs: + raise ValueError("Unrecognized backend: " + backend) + return langs[0] + + +def _validate_backend_language(backend, language): + """Throws error if backend and language are incompatible""" + langs = _lang_lookup.get(backend.upper(), False) + if not langs: + raise ValueError("Unrecognized backend: " + backend) + if language.upper() not in langs: + raise ValueError(("Backend {} and language {} are " + "incompatible").format(backend, language)) + + +@cacheit +@doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',)) +def autowrap(expr, language=None, backend='f2py', tempdir=None, args=None, + flags=None, verbose=False, helpers=None, code_gen=None, **kwargs): + """Generates Python callable binaries based on the math expression. + + Parameters + ========== + + expr + The SymPy expression that should be wrapped as a binary routine. + language : string, optional + If supplied, (options: 'C' or 'F95'), specifies the language of the + generated code. If ``None`` [default], the language is inferred based + upon the specified backend. + backend : string, optional + Backend used to wrap the generated code. Either 'f2py' [default], + or 'cython'. + tempdir : string, optional + Path to directory for temporary files. If this argument is supplied, + the generated code and the wrapper input files are left intact in the + specified path. + args : iterable, optional + An ordered iterable of symbols. Specifies the argument sequence for the + function. + flags : iterable, optional + Additional option flags that will be passed to the backend. + verbose : bool, optional + If True, autowrap will not mute the command line backends. This can be + helpful for debugging. + helpers : 3-tuple or iterable of 3-tuples, optional + Used to define auxiliary functions needed for the main expression. + Each tuple should be of the form (name, expr, args) where: + + - name : str, the function name + - expr : sympy expression, the function + - args : iterable, the function arguments (can be any iterable of symbols) + + code_gen : CodeGen instance + An instance of a CodeGen subclass. Overrides ``language``. + include_dirs : [string] + A list of directories to search for C/C++ header files (in Unix form + for portability). + library_dirs : [string] + A list of directories to search for C/C++ libraries at link time. + libraries : [string] + A list of library names (not filenames or paths) to link against. + extra_compile_args : [string] + Any extra platform- and compiler-specific information to use when + compiling the source files in 'sources'. For platforms and compilers + where "command line" makes sense, this is typically a list of + command-line arguments, but for other platforms it could be anything. + extra_link_args : [string] + Any extra platform- and compiler-specific information to use when + linking object files together to create the extension (or to create a + new static Python interpreter). Similar interpretation as for + 'extra_compile_args'. 
+ + Examples + ======== + + Basic usage: + + >>> from sympy.abc import x, y, z + >>> from sympy.utilities.autowrap import autowrap + >>> expr = ((x - y + z)**(13)).expand() + >>> binary_func = autowrap(expr) + >>> binary_func(1, 4, 2) + -1.0 + + Using helper functions: + + >>> from sympy.abc import x, t + >>> from sympy import Function + >>> helper_func = Function('helper_func') # Define symbolic function + >>> expr = 3*x + helper_func(t) # Main expression using helper function + >>> # Define helper_func(x) = 4*x using f2py backend + >>> binary_func = autowrap(expr, args=[x, t], + ... helpers=('helper_func', 4*x, [x])) + >>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20 + 26.0 + >>> # Same example using cython backend + >>> binary_func = autowrap(expr, args=[x, t], backend='cython', + ... helpers=[('helper_func', 4*x, [x])]) + >>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20 + 26.0 + + Type handling example: + + >>> import numpy as np + >>> expr = x + y + >>> f_cython = autowrap(expr, backend='cython') + >>> f_cython(1, 2) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int) + >>> f_cython(np.array([1.0]), np.array([2.0])) + array([ 3.]) + + """ + if language: + if not isinstance(language, type): + _validate_backend_language(backend, language) + else: + language = _infer_language(backend) + + # two cases 1) helpers is an iterable of 3-tuples and 2) helpers is a + # 3-tuple + if iterable(helpers) and len(helpers) != 0 and iterable(helpers[0]): + helpers = helpers if helpers else () + else: + helpers = [helpers] if helpers else () + args = list(args) if iterable(args, exclude=set) else args + + if code_gen is None: + code_gen = get_code_generator(language, "autowrap") + + CodeWrapperClass = { + 'F2PY': F2PyCodeWrapper, + 'CYTHON': CythonCodeWrapper, + 'DUMMY': DummyWrapper + }[backend.upper()] + code_wrapper = CodeWrapperClass(code_gen, tempdir, flags if flags else (), + verbose, **kwargs) + + helps = [] + for name_h, expr_h, args_h in helpers: + helps.append(code_gen.routine(name_h, expr_h, args_h)) + + for name_h, expr_h, args_h in helpers: + if expr.has(expr_h): + name_h = binary_function(name_h, expr_h, backend='dummy') + expr = expr.subs(expr_h, name_h(*args_h)) + try: + routine = code_gen.routine('autofunc', expr, args) + except CodeGenArgumentListError as e: + # if all missing arguments are for pure output, we simply attach them + # at the end and try again, because the wrappers will silently convert + # them to return values anyway. + new_args = [] + for missing in e.missing_args: + if not isinstance(missing, OutputArgument): + raise + new_args.append(missing.name) + routine = code_gen.routine('autofunc', expr, args + new_args) + + return code_wrapper.wrap_code(routine, helpers=helps) + + +@doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',)) +def binary_function(symfunc, expr, **kwargs): + """Returns a SymPy function with expr as binary implementation + + This is a convenience function that automates the steps needed to + autowrap the SymPy expression and attaching it to a Function object + with implemented_function(). + + Parameters + ========== + + symfunc : SymPy Function + The function to bind the callable to. + expr : SymPy Expression + The expression used to generate the function. + kwargs : dict + Any kwargs accepted by autowrap. 
+ + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.utilities.autowrap import binary_function + >>> expr = ((x - y)**(25)).expand() + >>> f = binary_function('f', expr) + >>> type(f) + + >>> 2*f(x, y) + 2*f(x, y) + >>> f(x, y).evalf(2, subs={x: 1, y: 2}) + -1.0 + + """ + binary = autowrap(expr, **kwargs) + return implemented_function(symfunc, binary) + +################################################################# +# UFUNCIFY # +################################################################# + +_ufunc_top = Template("""\ +#include "Python.h" +#include "math.h" +#include "numpy/ndarraytypes.h" +#include "numpy/ufuncobject.h" +#include "numpy/halffloat.h" +#include ${include_file} + +static PyMethodDef ${module}Methods[] = { + {NULL, NULL, 0, NULL} +};""") + +_ufunc_outcalls = Template("*((double *)out${outnum}) = ${funcname}(${call_args});") + +_ufunc_body = Template("""\ +#ifdef NPY_1_19_API_VERSION +static void ${funcname}_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data) +#else +static void ${funcname}_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data) +#endif +{ + npy_intp i; + npy_intp n = dimensions[0]; + ${declare_args} + ${declare_steps} + for (i = 0; i < n; i++) { + ${outcalls} + ${step_increments} + } +} +PyUFuncGenericFunction ${funcname}_funcs[1] = {&${funcname}_ufunc}; +static char ${funcname}_types[${n_types}] = ${types} +static void *${funcname}_data[1] = {NULL};""") + +_ufunc_bottom = Template("""\ +#if PY_VERSION_HEX >= 0x03000000 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "${module}", + NULL, + -1, + ${module}Methods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC PyInit_${module}(void) +{ + PyObject *m, *d; + ${function_creation} + m = PyModule_Create(&moduledef); + if (!m) { + return NULL; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ${ufunc_init} + return m; +} +#else +PyMODINIT_FUNC init${module}(void) +{ + PyObject *m, *d; + ${function_creation} + m = Py_InitModule("${module}", ${module}Methods); + if (m == NULL) { + return; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ${ufunc_init} +} +#endif\ +""") + +_ufunc_init_form = Template("""\ +ufunc${ind} = PyUFunc_FromFuncAndData(${funcname}_funcs, ${funcname}_data, ${funcname}_types, 1, ${n_in}, ${n_out}, + PyUFunc_None, "${module}", ${docstring}, 0); + PyDict_SetItemString(d, "${funcname}", ufunc${ind}); + Py_DECREF(ufunc${ind});""") + +_ufunc_setup = Template("""\ +from setuptools.extension import Extension +from setuptools import setup + +from numpy import get_include + +if __name__ == "__main__": + setup(ext_modules=[ + Extension('${module}', + sources=['${module}.c', '${filename}.c'], + include_dirs=[get_include()])]) +""") + + +class UfuncifyCodeWrapper(CodeWrapper): + """Wrapper for Ufuncify""" + + def __init__(self, *args, **kwargs): + + ext_keys = ['include_dirs', 'library_dirs', 'libraries', + 'extra_compile_args', 'extra_link_args'] + msg = ('The compilation option kwarg {} is not supported with the numpy' + ' backend.') + + for k in ext_keys: + if k in kwargs.keys(): + warn(msg.format(k)) + kwargs.pop(k, None) + + super().__init__(*args, **kwargs) + + @property + def command(self): + command = [sys.executable, "setup.py", "build_ext", "--inplace"] + return command + + def wrap_code(self, routines, helpers=None): + # This routine overrides CodeWrapper because we can't assume funcname == routines[0].name + # Therefore we have to break the CodeWrapper 
private API. + # There isn't an obvious way to extend multi-expr support to + # the other autowrap backends, so we limit this change to ufuncify. + helpers = helpers if helpers is not None else [] + # We just need a consistent name + funcname = 'wrapped_' + str(id(routines) + id(helpers)) + + workdir = self.filepath or tempfile.mkdtemp("_sympy_compile") + if not os.access(workdir, os.F_OK): + os.mkdir(workdir) + oldwork = os.getcwd() + os.chdir(workdir) + try: + sys.path.append(workdir) + self._generate_code(routines, helpers) + self._prepare_files(routines, funcname) + self._process_files(routines) + mod = __import__(self.module_name) + finally: + sys.path.remove(workdir) + CodeWrapper._module_counter += 1 + os.chdir(oldwork) + if not self.filepath: + try: + shutil.rmtree(workdir) + except OSError: + # Could be some issues on Windows + pass + + return self._get_wrapped_function(mod, funcname) + + def _generate_code(self, main_routines, helper_routines): + all_routines = main_routines + helper_routines + self.generator.write( + all_routines, self.filename, True, self.include_header, + self.include_empty) + + def _prepare_files(self, routines, funcname): + + # C + codefilename = self.module_name + '.c' + with open(codefilename, 'w') as f: + self.dump_c(routines, f, self.filename, funcname=funcname) + + # setup.py + with open('setup.py', 'w') as f: + self.dump_setup(f) + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name) + + def dump_setup(self, f): + setup = _ufunc_setup.substitute(module=self.module_name, + filename=self.filename) + f.write(setup) + + def dump_c(self, routines, f, prefix, funcname=None): + """Write a C file with Python wrappers + + This file contains all the definitions of the routines in c code. + + Arguments + --------- + routines + List of Routine instances + f + File-like object to write the file to + prefix + The filename prefix, used to name the imported module. + funcname + Name of the main function to be returned. 
+ """ + if funcname is None: + if len(routines) == 1: + funcname = routines[0].name + else: + msg = 'funcname must be specified for multiple output routines' + raise ValueError(msg) + functions = [] + function_creation = [] + ufunc_init = [] + module = self.module_name + include_file = "\"{}.h\"".format(prefix) + top = _ufunc_top.substitute(include_file=include_file, module=module) + + name = funcname + + # Partition the C function arguments into categories + # Here we assume all routines accept the same arguments + r_index = 0 + py_in, _ = self._partition_args(routines[0].arguments) + n_in = len(py_in) + n_out = len(routines) + + # Declare Args + form = "char *{0}{1} = args[{2}];" + arg_decs = [form.format('in', i, i) for i in range(n_in)] + arg_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)]) + declare_args = '\n '.join(arg_decs) + + # Declare Steps + form = "npy_intp {0}{1}_step = steps[{2}];" + step_decs = [form.format('in', i, i) for i in range(n_in)] + step_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)]) + declare_steps = '\n '.join(step_decs) + + # Call Args + form = "*(double *)in{0}" + call_args = ', '.join([form.format(a) for a in range(n_in)]) + + # Step Increments + form = "{0}{1} += {0}{1}_step;" + step_incs = [form.format('in', i) for i in range(n_in)] + step_incs.extend([form.format('out', i, i) for i in range(n_out)]) + step_increments = '\n '.join(step_incs) + + # Types + n_types = n_in + n_out + types = "{" + ', '.join(["NPY_DOUBLE"]*n_types) + "};" + + # Docstring + docstring = '"Created in SymPy with Ufuncify"' + + # Function Creation + function_creation.append("PyObject *ufunc{};".format(r_index)) + + # Ufunc initialization + init_form = _ufunc_init_form.substitute(module=module, + funcname=name, + docstring=docstring, + n_in=n_in, n_out=n_out, + ind=r_index) + ufunc_init.append(init_form) + + outcalls = [_ufunc_outcalls.substitute( + outnum=i, call_args=call_args, funcname=routines[i].name) for i in + range(n_out)] + + body = _ufunc_body.substitute(module=module, funcname=name, + declare_args=declare_args, + declare_steps=declare_steps, + call_args=call_args, + step_increments=step_increments, + n_types=n_types, types=types, + outcalls='\n '.join(outcalls)) + functions.append(body) + + body = '\n\n'.join(functions) + ufunc_init = '\n '.join(ufunc_init) + function_creation = '\n '.join(function_creation) + bottom = _ufunc_bottom.substitute(module=module, + ufunc_init=ufunc_init, + function_creation=function_creation) + text = [top, body, bottom] + f.write('\n\n'.join(text)) + + def _partition_args(self, args): + """Group function arguments into categories.""" + py_in = [] + py_out = [] + for arg in args: + if isinstance(arg, OutputArgument): + py_out.append(arg) + elif isinstance(arg, InOutArgument): + raise ValueError("Ufuncify doesn't support InOutArguments") + else: + py_in.append(arg) + return py_in, py_out + + +@cacheit +@doctest_depends_on(exe=('f2py', 'gfortran', 'gcc'), modules=('numpy',)) +def ufuncify(args, expr, language=None, backend='numpy', tempdir=None, + flags=None, verbose=False, helpers=None, **kwargs): + """Generates a binary function that supports broadcasting on numpy arrays. + + Parameters + ========== + + args : iterable + Either a Symbol or an iterable of symbols. Specifies the argument + sequence for the function. + expr + A SymPy expression that defines the element wise operation. + language : string, optional + If supplied, (options: 'C' or 'F95'), specifies the language of the + generated code. 
If ``None`` [default], the language is inferred based + upon the specified backend. + backend : string, optional + Backend used to wrap the generated code. Either 'numpy' [default], + 'cython', or 'f2py'. + tempdir : string, optional + Path to directory for temporary files. If this argument is supplied, + the generated code and the wrapper input files are left intact in + the specified path. + flags : iterable, optional + Additional option flags that will be passed to the backend. + verbose : bool, optional + If True, autowrap will not mute the command line backends. This can + be helpful for debugging. + helpers : 3-tuple or iterable of 3-tuples, optional + Used to define auxiliary functions needed for the main expression. + Each tuple should be of the form (name, expr, args) where: + + - name : str, the function name + - expr : sympy expression, the function + - args : iterable, the function arguments (can be any iterable of symbols) + + kwargs : dict + These kwargs will be passed to autowrap if the `f2py` or `cython` + backend is used and ignored if the `numpy` backend is used. + + Notes + ===== + + The default backend ('numpy') will create actual instances of + ``numpy.ufunc``. These support ndimensional broadcasting, and implicit type + conversion. Use of the other backends will result in a "ufunc-like" + function, which requires equal length 1-dimensional arrays for all + arguments, and will not perform any type conversions. + + References + ========== + + .. [1] https://numpy.org/doc/stable/reference/ufuncs.html + + Examples + ======== + + Basic usage: + + >>> from sympy.utilities.autowrap import ufuncify + >>> from sympy.abc import x, y + >>> import numpy as np + >>> f = ufuncify((x, y), y + x**2) + >>> type(f) + + >>> f([1, 2, 3], 2) + array([ 3., 6., 11.]) + >>> f(np.arange(5), 3) + array([ 3., 4., 7., 12., 19.]) + + Using helper functions: + + >>> from sympy import Function + >>> helper_func = Function('helper_func') # Define symbolic function + >>> expr = x**2 + y*helper_func(x) # Main expression using helper function + >>> # Define helper_func(x) = x**3 + >>> f = ufuncify((x, y), expr, helpers=[('helper_func', x**3, [x])]) + >>> f([1, 2], [3, 4]) + array([ 4., 36.]) + + Type handling with different backends: + + For the 'f2py' and 'cython' backends, inputs are required to be equal length + 1-dimensional arrays. The 'f2py' backend will perform type conversion, but + the Cython backend will error if the inputs are not of the expected type. + + >>> f_fortran = ufuncify((x, y), y + x**2, backend='f2py') + >>> f_fortran(1, 2) + array([ 3.]) + >>> f_fortran(np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0])) + array([ 2., 6., 12.]) + >>> f_cython = ufuncify((x, y), y + x**2, backend='Cython') + >>> f_cython(1, 2) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... 
+ TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int) + >>> f_cython(np.array([1.0]), np.array([2.0])) + array([ 3.]) + + """ + + if isinstance(args, Symbol): + args = (args,) + else: + args = tuple(args) + + if language: + _validate_backend_language(backend, language) + else: + language = _infer_language(backend) + + helpers = helpers if helpers else () + flags = flags if flags else () + + if backend.upper() == 'NUMPY': + # maxargs is set by numpy compile-time constant NPY_MAXARGS + # If a future version of numpy modifies or removes this restriction + # this variable should be changed or removed + maxargs = 32 + helps = [] + for name, expr, args in helpers: + helps.append(make_routine(name, expr, args)) + code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"), tempdir, + flags, verbose) + if not isinstance(expr, (list, tuple)): + expr = [expr] + if len(expr) == 0: + raise ValueError('Expression iterable has zero length') + if len(expr) + len(args) > maxargs: + msg = ('Cannot create ufunc with more than {0} total arguments: ' + 'got {1} in, {2} out') + raise ValueError(msg.format(maxargs, len(args), len(expr))) + routines = [make_routine('autofunc{}'.format(idx), exprx, args) for + idx, exprx in enumerate(expr)] + return code_wrapper.wrap_code(routines, helpers=helps) + else: + # Dummies are used for all added expressions to prevent name clashes + # within the original expression. + y = IndexedBase(Dummy('y')) + m = Dummy('m', integer=True) + i = Idx(Dummy('i', integer=True), m) + f_dummy = Dummy('f') + f = implemented_function('%s_%d' % (f_dummy.name, f_dummy.dummy_index), Lambda(args, expr)) + # For each of the args create an indexed version. + indexed_args = [IndexedBase(Dummy(str(a))) for a in args] + # Order the arguments (out, args, dim) + args = [y] + indexed_args + [m] + args_with_indices = [a[i] for a in indexed_args] + return autowrap(Eq(y[i], f(*args_with_indices)), language, backend, + tempdir, args, flags, verbose, helpers, **kwargs) diff --git a/phivenv/Lib/site-packages/sympy/utilities/codegen.py b/phivenv/Lib/site-packages/sympy/utilities/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac8772fc000d707ae33c67eaa44b4c281157ab0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/codegen.py @@ -0,0 +1,2237 @@ +""" +module for generating C, C++, Fortran77, Fortran90, Julia, Rust +and Octave/Matlab routines that evaluate SymPy expressions. +This module is work in progress. +Only the milestones with a '+' character in the list below have been completed. + +--- How is sympy.utilities.codegen different from sympy.printing.ccode? --- + +We considered the idea to extend the printing routines for SymPy functions in +such a way that it prints complete compilable code, but this leads to a few +unsurmountable issues that can only be tackled with dedicated code generator: + +- For C, one needs both a code and a header file, while the printing routines + generate just one string. This code generator can be extended to support + .pyf files for f2py. + +- SymPy functions are not concerned with programming-technical issues, such + as input, output and input-output arguments. Other examples are contiguous + or non-contiguous arrays, including headers of other libraries such as gsl + or others. + +- It is highly interesting to evaluate several SymPy functions in one C + routine, eventually sharing common intermediate results with the help + of the cse routine. This is more than just printing. 
+ +- From the programming perspective, expressions with constants should be + evaluated in the code generator as much as possible. This is different + for printing. + +--- Basic assumptions --- + +* A generic Routine data structure describes the routine that must be + translated into C/Fortran/... code. This data structure covers all + features present in one or more of the supported languages. + +* Descendants from the CodeGen class transform multiple Routine instances + into compilable code. Each derived class translates into a specific + language. + +* In many cases, one wants a simple workflow. The friendly functions in the + last part are a simple api on top of the Routine/CodeGen stuff. They are + easier to use, but are less powerful. + +--- Milestones --- + ++ First working version with scalar input arguments, generating C code, + tests ++ Friendly functions that are easier to use than the rigorous + Routine/CodeGen workflow. ++ Integer and Real numbers as input and output ++ Output arguments ++ InputOutput arguments ++ Sort input/output arguments properly ++ Contiguous array arguments (numpy matrices) ++ Also generate .pyf code for f2py (in autowrap module) ++ Isolate constants and evaluate them beforehand in double precision ++ Fortran 90 ++ Octave/Matlab + +- Common Subexpression Elimination +- User defined comments in the generated code +- Optional extra include lines for libraries/objects that can eval special + functions +- Test other C compilers and libraries: gcc, tcc, libtcc, gcc+gsl, ... +- Contiguous array arguments (SymPy matrices) +- Non-contiguous array arguments (SymPy matrices) +- ccode must raise an error when it encounters something that cannot be + translated into c. ccode(integrate(sin(x)/x, x)) does not make sense. +- Complex numbers as input and output +- A default complex datatype +- Include extra information in the header: date, user, hostname, sha1 + hash, ... +- Fortran 77 +- C++ +- Python +- Julia +- Rust +- ... + +""" + +import os +import textwrap +from io import StringIO + +from sympy import __version__ as sympy_version +from sympy.core import Symbol, S, Tuple, Equality, Function, Basic +from sympy.printing.c import c_code_printers +from sympy.printing.codeprinter import AssignmentError +from sympy.printing.fortran import FCodePrinter +from sympy.printing.julia import JuliaCodePrinter +from sympy.printing.octave import OctaveCodePrinter +from sympy.printing.rust import RustCodePrinter +from sympy.tensor import Idx, Indexed, IndexedBase +from sympy.matrices import (MatrixSymbol, ImmutableMatrix, MatrixBase, + MatrixExpr, MatrixSlice) +from sympy.utilities.iterables import is_sequence + + +__all__ = [ + # description of routines + "Routine", "DataType", "default_datatypes", "get_default_datatype", + "Argument", "InputArgument", "OutputArgument", "Result", + # routines -> code + "CodeGen", "CCodeGen", "FCodeGen", "JuliaCodeGen", "OctaveCodeGen", + "RustCodeGen", + # friendly functions + "codegen", "make_routine", +] + + +# +# Description of routines +# + + +class Routine: + """Generic description of evaluation routine for set of expressions. + + A CodeGen class can translate instances of this class into code in a + particular language. The routine specification covers all the features + present in these languages. The CodeGen part must raise an exception + when certain features are not present in the target language. For + example, multiple return values are possible in Python, but not in C or + Fortran. 
Another example: Fortran and Python support complex numbers, + while C does not. + + """ + + def __init__(self, name, arguments, results, local_vars, global_vars): + """Initialize a Routine instance. + + Parameters + ========== + + name : string + Name of the routine. + + arguments : list of Arguments + These are things that appear in arguments of a routine, often + appearing on the right-hand side of a function call. These are + commonly InputArguments but in some languages, they can also be + OutputArguments or InOutArguments (e.g., pass-by-reference in C + code). + + results : list of Results + These are the return values of the routine, often appearing on + the left-hand side of a function call. The difference between + Results and OutputArguments and when you should use each is + language-specific. + + local_vars : list of Results + These are variables that will be defined at the beginning of the + function. + + global_vars : list of Symbols + Variables which will not be passed into the function. + + """ + + # extract all input symbols and all symbols appearing in an expression + input_symbols = set() + symbols = set() + for arg in arguments: + if isinstance(arg, OutputArgument): + symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed)) + elif isinstance(arg, InputArgument): + input_symbols.add(arg.name) + elif isinstance(arg, InOutArgument): + input_symbols.add(arg.name) + symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed)) + else: + raise ValueError("Unknown Routine argument: %s" % arg) + + for r in results: + if not isinstance(r, Result): + raise ValueError("Unknown Routine result: %s" % r) + symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed)) + + local_symbols = set() + for r in local_vars: + if isinstance(r, Result): + symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed)) + local_symbols.add(r.name) + else: + local_symbols.add(r) + + symbols = {s.label if isinstance(s, Idx) else s for s in symbols} + + # Check that all symbols in the expressions are covered by + # InputArguments/InOutArguments---subset because user could + # specify additional (unused) InputArguments or local_vars. + notcovered = symbols.difference( + input_symbols.union(local_symbols).union(global_vars)) + if notcovered != set(): + raise ValueError("Symbols needed for output are not in input " + + ", ".join([str(x) for x in notcovered])) + + self.name = name + self.arguments = arguments + self.results = results + self.local_vars = local_vars + self.global_vars = global_vars + + def __str__(self): + return self.__class__.__name__ + "({name!r}, {arguments}, {results}, {local_vars}, {global_vars})".format(**self.__dict__) + + __repr__ = __str__ + + @property + def variables(self): + """Returns a set of all variables possibly used in the routine. + + For routines with unnamed return values, the dummies that may or + may not be used will be included in the set. + + """ + v = set(self.local_vars) + v.update(arg.name for arg in self.arguments) + v.update(res.result_var for res in self.results) + return v + + @property + def result_variables(self): + """Returns a list of OutputArgument, InOutArgument and Result. + + If return values are present, they are at the end of the list. 
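A small usage sketch (added for illustration; it relies on the friendly
``make_routine`` constructor defined near the end of this module):

    >>> from sympy.abc import x, y
    >>> from sympy.utilities.codegen import make_routine
    >>> r = make_routine('test', x*y)
    >>> [type(arg).__name__ for arg in r.arguments]
    ['InputArgument', 'InputArgument']
    >>> [type(obj).__name__ for obj in r.result_variables]
    ['Result']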
+ """ + args = [arg for arg in self.arguments if isinstance( + arg, (OutputArgument, InOutArgument))] + args.extend(self.results) + return args + + +class DataType: + """Holds strings for a certain datatype in different languages.""" + def __init__(self, cname, fname, pyname, jlname, octname, rsname): + self.cname = cname + self.fname = fname + self.pyname = pyname + self.jlname = jlname + self.octname = octname + self.rsname = rsname + + +default_datatypes = { + "int": DataType("int", "INTEGER*4", "int", "", "", "i32"), + "float": DataType("double", "REAL*8", "float", "", "", "f64"), + "complex": DataType("double", "COMPLEX*16", "complex", "", "", "float") #FIXME: + # complex is only supported in fortran, python, julia, and octave. + # So to not break c or rust code generation, we stick with double or + # float, respectively (but actually should raise an exception for + # explicitly complex variables (x.is_complex==True)) +} + + +COMPLEX_ALLOWED = False +def get_default_datatype(expr, complex_allowed=None): + """Derives an appropriate datatype based on the expression.""" + if complex_allowed is None: + complex_allowed = COMPLEX_ALLOWED + if complex_allowed: + final_dtype = "complex" + else: + final_dtype = "float" + if expr.is_integer: + return default_datatypes["int"] + elif expr.is_real: + return default_datatypes["float"] + elif isinstance(expr, MatrixBase): + #check all entries + dt = "int" + for element in expr: + if dt == "int" and not element.is_integer: + dt = "float" + if dt == "float" and not element.is_real: + return default_datatypes[final_dtype] + return default_datatypes[dt] + else: + return default_datatypes[final_dtype] + + +class Variable: + """Represents a typed variable.""" + + def __init__(self, name, datatype=None, dimensions=None, precision=None): + """Return a new variable. + + Parameters + ========== + + name : Symbol or MatrixSymbol + + datatype : optional + When not given, the data type will be guessed based on the + assumptions on the symbol argument. + + dimensions : sequence containing tuples, optional + If present, the argument is interpreted as an array, where this + sequence of tuples specifies (lower, upper) bounds for each + index of the array. + + precision : int, optional + Controls the precision of floating point constants. + + """ + if not isinstance(name, (Symbol, MatrixSymbol)): + raise TypeError("The first argument must be a SymPy symbol.") + if datatype is None: + datatype = get_default_datatype(name) + elif not isinstance(datatype, DataType): + raise TypeError("The (optional) `datatype' argument must be an " + "instance of the DataType class.") + if dimensions and not isinstance(dimensions, (tuple, list)): + raise TypeError( + "The dimensions argument must be a sequence of tuples") + + self._name = name + self._datatype = { + 'C': datatype.cname, + 'FORTRAN': datatype.fname, + 'JULIA': datatype.jlname, + 'OCTAVE': datatype.octname, + 'PYTHON': datatype.pyname, + 'RUST': datatype.rsname, + } + self.dimensions = dimensions + self.precision = precision + + def __str__(self): + return "%s(%r)" % (self.__class__.__name__, self.name) + + __repr__ = __str__ + + @property + def name(self): + return self._name + + def get_datatype(self, language): + """Returns the datatype string for the requested language. 
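(An illustrative aside added here, not part of the upstream docstring: when no
``datatype`` is passed, ``get_default_datatype`` above infers one from the
symbol's assumptions.)

    >>> from sympy import Symbol
    >>> from sympy.utilities.codegen import get_default_datatype
    >>> get_default_datatype(Symbol('n', integer=True)).cname
    'int'
    >>> get_default_datatype(Symbol('x')).fname
    'REAL*8'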
+ + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.utilities.codegen import Variable + >>> x = Variable(Symbol('x')) + >>> x.get_datatype('c') + 'double' + >>> x.get_datatype('fortran') + 'REAL*8' + + """ + try: + return self._datatype[language.upper()] + except KeyError: + raise CodeGenError("Has datatypes for languages: %s" % + ", ".join(self._datatype)) + + +class Argument(Variable): + """An abstract Argument data structure: a name and a data type. + + This structure is refined in the descendants below. + + """ + pass + + +class InputArgument(Argument): + pass + + +class ResultBase: + """Base class for all "outgoing" information from a routine. + + Objects of this class stores a SymPy expression, and a SymPy object + representing a result variable that will be used in the generated code + only if necessary. + + """ + def __init__(self, expr, result_var): + self.expr = expr + self.result_var = result_var + + def __str__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.expr, + self.result_var) + + __repr__ = __str__ + + +class OutputArgument(Argument, ResultBase): + """OutputArgument are always initialized in the routine.""" + + def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None): + """Return a new variable. + + Parameters + ========== + + name : Symbol, MatrixSymbol + The name of this variable. When used for code generation, this + might appear, for example, in the prototype of function in the + argument list. + + result_var : Symbol, Indexed + Something that can be used to assign a value to this variable. + Typically the same as `name` but for Indexed this should be e.g., + "y[i]" whereas `name` should be the Symbol "y". + + expr : object + The expression that should be output, typically a SymPy + expression. + + datatype : optional + When not given, the data type will be guessed based on the + assumptions on the symbol argument. + + dimensions : sequence containing tuples, optional + If present, the argument is interpreted as an array, where this + sequence of tuples specifies (lower, upper) bounds for each + index of the array. + + precision : int, optional + Controls the precision of floating point constants. + + """ + + Argument.__init__(self, name, datatype, dimensions, precision) + ResultBase.__init__(self, expr, result_var) + + def __str__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.result_var, self.expr) + + __repr__ = __str__ + + +class InOutArgument(Argument, ResultBase): + """InOutArgument are never initialized in the routine.""" + + def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None): + if not datatype: + datatype = get_default_datatype(expr) + Argument.__init__(self, name, datatype, dimensions, precision) + ResultBase.__init__(self, expr, result_var) + __init__.__doc__ = OutputArgument.__init__.__doc__ + + + def __str__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.expr, + self.result_var) + + __repr__ = __str__ + + +class Result(Variable, ResultBase): + """An expression for a return value. + + The name result is used to avoid conflicts with the reserved word + "return" in the Python language. It is also shorter than ReturnValue. + + These may or may not need a name in the destination (e.g., "return(x*y)" + might return a value without ever naming it). + + """ + + def __init__(self, expr, name=None, result_var=None, datatype=None, + dimensions=None, precision=None): + """Initialize a return value. 
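For instance (a sketch added here for illustration, not from the original
docstring):

    >>> from sympy.abc import x, y
    >>> from sympy.utilities.codegen import Result
    >>> r = Result(x*y, name='z')
    >>> r.expr, r.name, r.result_var
    (x*y, z, z)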
+ + Parameters + ========== + + expr : SymPy expression + + name : Symbol, MatrixSymbol, optional + The name of this return variable. When used for code generation, + this might appear, for example, in the prototype of function in a + list of return values. A dummy name is generated if omitted. + + result_var : Symbol, Indexed, optional + Something that can be used to assign a value to this variable. + Typically the same as `name` but for Indexed this should be e.g., + "y[i]" whereas `name` should be the Symbol "y". Defaults to + `name` if omitted. + + datatype : optional + When not given, the data type will be guessed based on the + assumptions on the expr argument. + + dimensions : sequence containing tuples, optional + If present, this variable is interpreted as an array, + where this sequence of tuples specifies (lower, upper) + bounds for each index of the array. + + precision : int, optional + Controls the precision of floating point constants. + + """ + # Basic because it is the base class for all types of expressions + if not isinstance(expr, (Basic, MatrixBase)): + raise TypeError("The first argument must be a SymPy expression.") + + if name is None: + name = 'result_%d' % abs(hash(expr)) + + if datatype is None: + #try to infer data type from the expression + datatype = get_default_datatype(expr) + + if isinstance(name, str): + if isinstance(expr, (MatrixBase, MatrixExpr)): + name = MatrixSymbol(name, *expr.shape) + else: + name = Symbol(name) + + if result_var is None: + result_var = name + + Variable.__init__(self, name, datatype=datatype, + dimensions=dimensions, precision=precision) + ResultBase.__init__(self, expr, result_var) + + def __str__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.expr, self.name, + self.result_var) + + __repr__ = __str__ + + +# +# Transformation of routine objects into code +# + +class CodeGen: + """Abstract class for the code generators.""" + + printer = None # will be set to an instance of a CodePrinter subclass + + def _indent_code(self, codelines): + return self.printer.indent_code(codelines) + + def _printer_method_with_settings(self, method, settings=None, *args, **kwargs): + settings = settings or {} + ori = {k: self.printer._settings[k] for k in settings} + for k, v in settings.items(): + self.printer._settings[k] = v + result = getattr(self.printer, method)(*args, **kwargs) + for k, v in ori.items(): + self.printer._settings[k] = v + return result + + def _get_symbol(self, s): + """Returns the symbol as fcode prints it.""" + if self.printer._settings['human']: + expr_str = self.printer.doprint(s) + else: + constants, not_supported, expr_str = self.printer.doprint(s) + if constants or not_supported: + raise ValueError("Failed to print %s" % str(s)) + return expr_str.strip() + + def __init__(self, project="project", cse=False): + """Initialize a code generator. + + Derived classes will offer more options that affect the generated + code. + + """ + self.project = project + self.cse = cse + + def routine(self, name, expr, argument_sequence=None, global_vars=None): + """Creates an Routine object that is appropriate for this language. + + This implementation is appropriate for at least C/Fortran. Subclasses + can override this if necessary. + + Here, we assume at most one return value (the l-value) which must be + scalar. Additional outputs are OutputArguments (e.g., pointers on + right-hand-side or pass-by-reference). Matrices are always returned + via OutputArguments. 
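For example (a sketch added here, not part of the upstream docstring), a
matrix expression becomes a dummy OutputArgument rather than a return value:

    >>> from sympy import Matrix
    >>> from sympy.abc import x, y
    >>> from sympy.utilities.codegen import make_routine
    >>> r = make_routine('f', Matrix([[x, y]]))
    >>> [type(a).__name__ for a in r.arguments]
    ['InputArgument', 'InputArgument', 'OutputArgument']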
If ``argument_sequence`` is None, arguments will + be ordered alphabetically, but with all InputArguments first, and then + OutputArgument and InOutArguments. + + """ + + if self.cse: + from sympy.simplify.cse_main import cse + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + for e in expr: + if not e.is_Equality: + raise CodeGenError("Lists of expressions must all be Equalities. {} is not.".format(e)) + + # create a list of right hand sides and simplify them + rhs = [e.rhs for e in expr] + common, simplified = cse(rhs) + + # pack the simplified expressions back up with their left hand sides + expr = [Equality(e.lhs, rhs) for e, rhs in zip(expr, simplified)] + else: + if isinstance(expr, Equality): + common, simplified = cse(expr.rhs) #, ignore=in_out_args) + expr = Equality(expr.lhs, simplified[0]) + else: + common, simplified = cse(expr) + expr = simplified + + local_vars = [Result(b,a) for a,b in common] + local_symbols = {a for a,_ in common} + local_expressions = Tuple(*[b for _,b in common]) + else: + local_expressions = Tuple() + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + if self.cse: + if {i.label for i in expressions.atoms(Idx)} != set(): + raise CodeGenError("CSE and Indexed expressions do not play well together yet") + else: + # local variables for indexed expressions + local_vars = {i.label for i in expressions.atoms(Idx)} + local_symbols = local_vars + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + symbols = (expressions.free_symbols | local_expressions.free_symbols) - local_symbols - global_vars + new_symbols = set() + new_symbols.update(symbols) + + for symbol in symbols: + if isinstance(symbol, Idx): + new_symbols.remove(symbol) + new_symbols.update(symbol.args[1].free_symbols) + if isinstance(symbol, Indexed): + new_symbols.remove(symbol) + symbols = new_symbols + + # Decide whether to use output argument or return value + return_val = [] + output_args = [] + for expr in expressions: + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + if isinstance(out_arg, Indexed): + dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape]) + symbol = out_arg.base.label + elif isinstance(out_arg, Symbol): + dims = [] + symbol = out_arg + elif isinstance(out_arg, MatrixSymbol): + dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape]) + symbol = out_arg + else: + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + if expr.has(symbol): + output_args.append( + InOutArgument(symbol, out_arg, expr, dimensions=dims)) + else: + output_args.append( + OutputArgument(symbol, out_arg, expr, dimensions=dims)) + + # remove duplicate arguments when they are not local variables + if symbol not in local_vars: + # avoid duplicate arguments + symbols.remove(symbol) + elif isinstance(expr, (ImmutableMatrix, MatrixSlice)): + # Create a "dummy" MatrixSymbol to use as the Output arg + out_arg = MatrixSymbol('out_%s' % abs(hash(expr)), *expr.shape) + dims = tuple([(S.Zero, dim - 1) for dim in out_arg.shape]) + output_args.append( + OutputArgument(out_arg, out_arg, expr, dimensions=dims)) + else: + return_val.append(Result(expr)) + + arg_list = [] + + # setup input argument list + + # helper to get dimensions for data 
for array-like args + def dimensions(s): + return [(S.Zero, dim - 1) for dim in s.shape] + + array_symbols = {} + for array in expressions.atoms(Indexed) | local_expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol) | local_expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + if symbol in array_symbols: + array = array_symbols[symbol] + metadata = {'dimensions': dimensions(array)} + else: + metadata = {} + + arg_list.append(InputArgument(symbol, **metadata)) + + output_args.sort(key=lambda x: str(x.name)) + arg_list.extend(output_args) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + if isinstance(symbol, (IndexedBase, MatrixSymbol)): + metadata = {'dimensions': dimensions(symbol)} + else: + metadata = {} + new_args.append(InputArgument(symbol, **metadata)) + arg_list = new_args + + return Routine(name, arg_list, return_val, local_vars, global_vars) + + def write(self, routines, prefix, to_files=False, header=True, empty=True): + """Writes all the source code files for the given routines. + + The generated source is returned as a list of (filename, contents) + tuples, or is written to files (see below). Each filename consists + of the given prefix, appended with an appropriate extension. + + Parameters + ========== + + routines : list + A list of Routine instances to be written + + prefix : string + The prefix for the output files + + to_files : bool, optional + When True, the output is written to files. Otherwise, a list + of (filename, contents) tuples is returned. [default: False] + + header : bool, optional + When True, a header comment is included on top of each source + file. [default: True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default: True] + + """ + if to_files: + for dump_fn in self.dump_fns: + filename = "%s.%s" % (prefix, dump_fn.extension) + with open(filename, "w") as f: + dump_fn(self, routines, f, prefix, header, empty) + else: + result = [] + for dump_fn in self.dump_fns: + filename = "%s.%s" % (prefix, dump_fn.extension) + contents = StringIO() + dump_fn(self, routines, contents, prefix, header, empty) + result.append((filename, contents.getvalue())) + return result + + def dump_code(self, routines, f, prefix, header=True, empty=True): + """Write the code by calling language specific methods. + + The generated file contains all the definitions of the routines in + low-level code and refers to the header file if appropriate. + + Parameters + ========== + + routines : list + A list of Routine instances. + + f : file-like + Where to write the file. + + prefix : string + The filename prefix, used to refer to the proper header file. + Only the basename of the prefix is used. 
+ + header : bool, optional + When True, a header comment is included on top of each source + file. [default : True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default : True] + + """ + + code_lines = self._preprocessor_statements(prefix) + + for routine in routines: + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_opening(routine)) + code_lines.extend(self._declare_arguments(routine)) + code_lines.extend(self._declare_globals(routine)) + code_lines.extend(self._declare_locals(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._call_printer(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_ending(routine)) + + code_lines = self._indent_code(''.join(code_lines)) + + if header: + code_lines = ''.join(self._get_header() + [code_lines]) + + if code_lines: + f.write(code_lines) + + +class CodeGenError(Exception): + pass + + +class CodeGenArgumentListError(Exception): + @property + def missing_args(self): + return self.args[1] + + +header_comment = """Code generated with SymPy %(version)s + +See http://www.sympy.org/ for more information. + +This file is part of '%(project)s' +""" + + +class CCodeGen(CodeGen): + """Generator for C code. + + The .write() method inherited from CodeGen will output a code file and + an interface file, .c and .h respectively. + + """ + + code_extension = "c" + interface_extension = "h" + standard = 'c99' + + def __init__(self, project="project", printer=None, + preprocessor_statements=None, cse=False): + super().__init__(project=project, cse=cse) + self.printer = printer or c_code_printers[self.standard.lower()]() + + self.preprocessor_statements = preprocessor_statements + if preprocessor_statements is None: + self.preprocessor_statements = ['#include '] + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + code_lines.append("/" + "*"*78 + '\n') + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + code_lines.append(" *%s*\n" % line.center(76)) + code_lines.append(" " + "*"*78 + "/\n") + return code_lines + + def get_prototype(self, routine): + """Returns a string for the function prototype of the routine. + + If the routine has multiple result objects, an CodeGenError is + raised. 
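For example (a sketch added here, not part of the upstream docstring):

    >>> from sympy.abc import x, y
    >>> from sympy.utilities.codegen import C99CodeGen
    >>> gen = C99CodeGen()
    >>> r = gen.routine('test', x + y)
    >>> gen.get_prototype(r)
    'double test(double x, double y)'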
+ + See: https://en.wikipedia.org/wiki/Function_prototype + + """ + if len(routine.results) > 1: + raise CodeGenError("C only supports a single or no return value.") + elif len(routine.results) == 1: + ctype = routine.results[0].get_datatype('C') + else: + ctype = "void" + + type_args = [] + for arg in routine.arguments: + name = self.printer.doprint(arg.name) + if arg.dimensions or isinstance(arg, ResultBase): + type_args.append((arg.get_datatype('C'), "*%s" % name)) + else: + type_args.append((arg.get_datatype('C'), name)) + arguments = ", ".join([ "%s %s" % t for t in type_args]) + return "%s %s(%s)" % (ctype, routine.name, arguments) + + def _preprocessor_statements(self, prefix): + code_lines = [] + code_lines.append('#include "{}.h"'.format(os.path.basename(prefix))) + code_lines.extend(self.preprocessor_statements) + code_lines = ['{}\n'.format(l) for l in code_lines] + return code_lines + + def _get_routine_opening(self, routine): + prototype = self.get_prototype(routine) + return ["%s {\n" % prototype] + + def _declare_arguments(self, routine): + # arguments are declared in prototype + return [] + + def _declare_globals(self, routine): + # global variables are not explicitly declared within C functions + return [] + + def _declare_locals(self, routine): + + # Compose a list of symbols to be dereferenced in the function + # body. These are the arguments that were passed by a reference + # pointer, excluding arrays. + dereference = [] + for arg in routine.arguments: + if isinstance(arg, ResultBase) and not arg.dimensions: + dereference.append(arg.name) + + code_lines = [] + for result in routine.local_vars: + + # local variables that are simple symbols such as those used as indices into + # for loops are defined declared elsewhere. + if not isinstance(result, Result): + continue + + if result.name != result.result_var: + raise CodeGen("Result variable and name should match: {}".format(result)) + assign_to = result.name + t = result.get_datatype('c') + if isinstance(result.expr, (MatrixBase, MatrixExpr)): + dims = result.expr.shape + code_lines.append("{} {}[{}];\n".format(t, str(assign_to), dims[0]*dims[1])) + prefix = "" + else: + prefix = "const {} ".format(t) + + constants, not_c, c_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "dereference": dereference, "strict": False}, + result.expr, assign_to=assign_to) + + for name, value in sorted(constants, key=str): + code_lines.append("double const %s = %s;\n" % (name, value)) + + code_lines.append("{}{}\n".format(prefix, c_expr)) + + return code_lines + + def _call_printer(self, routine): + code_lines = [] + + # Compose a list of symbols to be dereferenced in the function + # body. These are the arguments that were passed by a reference + # pointer, excluding arrays. 
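        # (added note, not in the upstream source) e.g. a scalar OutputArgument
        # ``y`` appears in the prototype as ``double *y``; listing it here makes
        # the C printer emit ``(*y)`` instead of ``y`` inside the function body.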
+ dereference = [] + for arg in routine.arguments: + if isinstance(arg, ResultBase) and not arg.dimensions: + dereference.append(arg.name) + + return_val = None + for result in routine.result_variables: + if isinstance(result, Result): + assign_to = routine.name + "_result" + t = result.get_datatype('c') + code_lines.append("{} {};\n".format(t, str(assign_to))) + return_val = assign_to + else: + assign_to = result.result_var + + try: + constants, not_c, c_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "dereference": dereference, "strict": False}, + result.expr, assign_to=assign_to) + except AssignmentError: + assign_to = result.result_var + code_lines.append( + "%s %s;\n" % (result.get_datatype('c'), str(assign_to))) + constants, not_c, c_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "dereference": dereference, "strict": False}, + result.expr, assign_to=assign_to) + + for name, value in sorted(constants, key=str): + code_lines.append("double const %s = %s;\n" % (name, value)) + code_lines.append("%s\n" % c_expr) + + if return_val: + code_lines.append(" return %s;\n" % return_val) + return code_lines + + def _get_routine_ending(self, routine): + return ["}\n"] + + def dump_c(self, routines, f, prefix, header=True, empty=True): + self.dump_code(routines, f, prefix, header, empty) + dump_c.extension = code_extension # type: ignore + dump_c.__doc__ = CodeGen.dump_code.__doc__ + + def dump_h(self, routines, f, prefix, header=True, empty=True): + """Writes the C header file. + + This file contains all the function declarations. + + Parameters + ========== + + routines : list + A list of Routine instances. + + f : file-like + Where to write the file. + + prefix : string + The filename prefix, used to construct the include guards. + Only the basename of the prefix is used. + + header : bool, optional + When True, a header comment is included on top of each source + file. [default : True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default : True] + + """ + if header: + print(''.join(self._get_header()), file=f) + guard_name = "%s__%s__H" % (self.project.replace( + " ", "_").upper(), prefix.replace("/", "_").upper()) + # include guards + if empty: + print(file=f) + print("#ifndef %s" % guard_name, file=f) + print("#define %s" % guard_name, file=f) + if empty: + print(file=f) + # declaration of the function prototypes + for routine in routines: + prototype = self.get_prototype(routine) + print("%s;" % prototype, file=f) + # end if include guards + if empty: + print(file=f) + print("#endif", file=f) + if empty: + print(file=f) + dump_h.extension = interface_extension # type: ignore + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_c, dump_h] + +class C89CodeGen(CCodeGen): + standard = 'C89' + +class C99CodeGen(CCodeGen): + standard = 'C99' + +class FCodeGen(CodeGen): + """Generator for Fortran 95 code + + The .write() method inherited from CodeGen will output a code file and + an interface file, .f90 and .h respectively. + + """ + + code_extension = "f90" + interface_extension = "h" + + def __init__(self, project='project', printer=None): + super().__init__(project) + self.printer = printer or FCodePrinter() + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + code_lines.append("!" 
+ "*"*78 + '\n') + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + code_lines.append("!*%s*\n" % line.center(76)) + code_lines.append("!" + "*"*78 + '\n') + return code_lines + + def _preprocessor_statements(self, prefix): + return [] + + def _get_routine_opening(self, routine): + """Returns the opening statements of the fortran routine.""" + code_list = [] + if len(routine.results) > 1: + raise CodeGenError( + "Fortran only supports a single or no return value.") + elif len(routine.results) == 1: + result = routine.results[0] + code_list.append(result.get_datatype('fortran')) + code_list.append("function") + else: + code_list.append("subroutine") + + args = ", ".join("%s" % self._get_symbol(arg.name) + for arg in routine.arguments) + + call_sig = "{}({})\n".format(routine.name, args) + # Fortran 95 requires all lines be less than 132 characters, so wrap + # this line before appending. + call_sig = ' &\n'.join(textwrap.wrap(call_sig, + width=60, + break_long_words=False)) + '\n' + code_list.append(call_sig) + code_list = [' '.join(code_list)] + code_list.append('implicit none\n') + return code_list + + def _declare_arguments(self, routine): + # argument type declarations + code_list = [] + array_list = [] + scalar_list = [] + for arg in routine.arguments: + + if isinstance(arg, InputArgument): + typeinfo = "%s, intent(in)" % arg.get_datatype('fortran') + elif isinstance(arg, InOutArgument): + typeinfo = "%s, intent(inout)" % arg.get_datatype('fortran') + elif isinstance(arg, OutputArgument): + typeinfo = "%s, intent(out)" % arg.get_datatype('fortran') + else: + raise CodeGenError("Unknown Argument type: %s" % type(arg)) + + fprint = self._get_symbol + + if arg.dimensions: + # fortran arrays start at 1 + dimstr = ", ".join(["%s:%s" % ( + fprint(dim[0] + 1), fprint(dim[1] + 1)) + for dim in arg.dimensions]) + typeinfo += ", dimension(%s)" % dimstr + array_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name))) + else: + scalar_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name))) + + # scalars first, because they can be used in array declarations + code_list.extend(scalar_list) + code_list.extend(array_list) + + return code_list + + def _declare_globals(self, routine): + # Global variables not explicitly declared within Fortran 90 functions. + # Note: a future F77 mode may need to generate "common" blocks. + return [] + + def _declare_locals(self, routine): + code_list = [] + for var in sorted(routine.local_vars, key=str): + typeinfo = get_default_datatype(var) + code_list.append("%s :: %s\n" % ( + typeinfo.fname, self._get_symbol(var))) + return code_list + + def _get_routine_ending(self, routine): + """Returns the closing statements of the fortran routine.""" + if len(routine.results) == 1: + return ["end function\n"] + else: + return ["end subroutine\n"] + + def get_interface(self, routine): + """Returns a string for the function interface. + + The routine should have a single result object, which can be None. + If the routine has multiple result objects, a CodeGenError is + raised. 
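A rough sketch of the result (added for illustration; exact whitespace and
line wrapping may differ):

    >>> from sympy.abc import x, y
    >>> from sympy.utilities.codegen import FCodeGen
    >>> gen = FCodeGen()
    >>> print(gen.get_interface(gen.routine('test', x + y)))  # doctest: +SKIP
    interface
    REAL*8 function test(x, y)
    implicit none
    REAL*8, intent(in) :: x
    REAL*8, intent(in) :: y
    end function
    end interface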
+ + See: https://en.wikipedia.org/wiki/Function_prototype + + """ + prototype = [ "interface\n" ] + prototype.extend(self._get_routine_opening(routine)) + prototype.extend(self._declare_arguments(routine)) + prototype.extend(self._get_routine_ending(routine)) + prototype.append("end interface\n") + + return "".join(prototype) + + def _call_printer(self, routine): + declarations = [] + code_lines = [] + for result in routine.result_variables: + if isinstance(result, Result): + assign_to = routine.name + elif isinstance(result, (OutputArgument, InOutArgument)): + assign_to = result.result_var + + constants, not_fortran, f_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "source_format": 'free', "standard": 95, "strict": False}, + result.expr, assign_to=assign_to) + + for obj, v in sorted(constants, key=str): + t = get_default_datatype(obj) + declarations.append( + "%s, parameter :: %s = %s\n" % (t.fname, obj, v)) + for obj in sorted(not_fortran, key=str): + t = get_default_datatype(obj) + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append("%s :: %s\n" % (t.fname, name)) + + code_lines.append("%s\n" % f_expr) + return declarations + code_lines + + def _indent_code(self, codelines): + return self._printer_method_with_settings( + 'indent_code', {"human": False, "source_format": 'free', "strict": False}, codelines) + + def dump_f95(self, routines, f, prefix, header=True, empty=True): + # check that symbols are unique with ignorecase + for r in routines: + lowercase = {str(x).lower() for x in r.variables} + orig_case = {str(x) for x in r.variables} + if len(lowercase) < len(orig_case): + raise CodeGenError("Fortran ignores case. Got symbols: %s" % + (", ".join([str(var) for var in r.variables]))) + self.dump_code(routines, f, prefix, header, empty) + dump_f95.extension = code_extension # type: ignore + dump_f95.__doc__ = CodeGen.dump_code.__doc__ + + def dump_h(self, routines, f, prefix, header=True, empty=True): + """Writes the interface to a header file. + + This file contains all the function declarations. + + Parameters + ========== + + routines : list + A list of Routine instances. + + f : file-like + Where to write the file. + + prefix : string + The filename prefix. + + header : bool, optional + When True, a header comment is included on top of each source + file. [default : True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default : True] + + """ + if header: + print(''.join(self._get_header()), file=f) + if empty: + print(file=f) + # declaration of the function prototypes + for routine in routines: + prototype = self.get_interface(routine) + f.write(prototype) + if empty: + print(file=f) + dump_h.extension = interface_extension # type: ignore + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_f95, dump_h] + + +class JuliaCodeGen(CodeGen): + """Generator for Julia code. + + The .write() method inherited from CodeGen will output a code file + .jl. 
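An illustrative call (added here, not part of the upstream docstring; the
printed code is only a sketch):

    >>> from sympy.abc import x, y
    >>> from sympy.utilities.codegen import codegen
    >>> [(filename, code)] = codegen(('f', x + y), language='Julia',
    ...                              header=False, empty=False)
    >>> print(code)  # doctest: +SKIP
    function f(x, y)
        out1 = x + y
        return out1
    end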
+ + """ + + code_extension = "jl" + + def __init__(self, project='project', printer=None): + super().__init__(project) + self.printer = printer or JuliaCodePrinter() + + def routine(self, name, expr, argument_sequence, global_vars): + """Specialized Routine creation for Julia.""" + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + # local variables + local_vars = {i.label for i in expressions.atoms(Idx)} + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + old_symbols = expressions.free_symbols - local_vars - global_vars + symbols = set() + for s in old_symbols: + if isinstance(s, Idx): + symbols.update(s.args[1].free_symbols) + elif not isinstance(s, Indexed): + symbols.add(s) + + # Julia supports multiple return values + return_vals = [] + output_args = [] + for (i, expr) in enumerate(expressions): + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + symbol = out_arg + if isinstance(out_arg, Indexed): + dims = tuple([ (S.One, dim) for dim in out_arg.shape]) + symbol = out_arg.base.label + output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims)) + if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)): + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + return_vals.append(Result(expr, name=symbol, result_var=out_arg)) + if not expr.has(symbol): + # this is a pure output: remove from the symbols list, so + # it doesn't become an input. + symbols.remove(symbol) + + else: + # we have no name for this output + return_vals.append(Result(expr, name='out%d' % (i+1))) + + # setup input argument list + output_args.sort(key=lambda x: str(x.name)) + arg_list = list(output_args) + array_symbols = {} + for array in expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + arg_list.append(InputArgument(symbol)) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + new_args.append(InputArgument(symbol)) + arg_list = new_args + + return Routine(name, arg_list, return_vals, local_vars, global_vars) + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + if line == '': + code_lines.append("#\n") + else: + code_lines.append("# %s\n" % line) + return code_lines + + def _preprocessor_statements(self, prefix): + return [] + + def _get_routine_opening(self, routine): + """Returns the opening statements 
of the routine.""" + code_list = [] + code_list.append("function ") + + # Inputs + args = [] + for arg in routine.arguments: + if isinstance(arg, OutputArgument): + raise CodeGenError("Julia: invalid argument of type %s" % + str(type(arg))) + if isinstance(arg, (InputArgument, InOutArgument)): + args.append("%s" % self._get_symbol(arg.name)) + args = ", ".join(args) + code_list.append("%s(%s)\n" % (routine.name, args)) + code_list = [ "".join(code_list) ] + + return code_list + + def _declare_arguments(self, routine): + return [] + + def _declare_globals(self, routine): + return [] + + def _declare_locals(self, routine): + return [] + + def _get_routine_ending(self, routine): + outs = [] + for result in routine.results: + if isinstance(result, Result): + # Note: name not result_var; want `y` not `y[i]` for Indexed + s = self._get_symbol(result.name) + else: + raise CodeGenError("unexpected object in Routine results") + outs.append(s) + return ["return " + ", ".join(outs) + "\nend\n"] + + def _call_printer(self, routine): + declarations = [] + code_lines = [] + for result in routine.results: + if isinstance(result, Result): + assign_to = result.result_var + else: + raise CodeGenError("unexpected object in Routine results") + + constants, not_supported, jl_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to) + + for obj, v in sorted(constants, key=str): + declarations.append( + "%s = %s\n" % (obj, v)) + for obj in sorted(not_supported, key=str): + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append( + "# unsupported: %s\n" % (name)) + code_lines.append("%s\n" % (jl_expr)) + return declarations + code_lines + + def _indent_code(self, codelines): + # Note that indenting seems to happen twice, first + # statement-by-statement by JuliaPrinter then again here. + p = JuliaCodePrinter({'human': False, "strict": False}) + return p.indent_code(codelines) + + def dump_jl(self, routines, f, prefix, header=True, empty=True): + self.dump_code(routines, f, prefix, header, empty) + + dump_jl.extension = code_extension # type: ignore + dump_jl.__doc__ = CodeGen.dump_code.__doc__ + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_jl] + + +class OctaveCodeGen(CodeGen): + """Generator for Octave code. + + The .write() method inherited from CodeGen will output a code file + .m. + + Octave .m files usually contain one function. That function name should + match the filename (``prefix``). If you pass multiple ``name_expr`` pairs, + the latter ones are presumed to be private functions accessed by the + primary function. + + You should only pass inputs to ``argument_sequence``: outputs are ordered + according to their order in ``name_expr``. + + """ + + code_extension = "m" + + def __init__(self, project='project', printer=None): + super().__init__(project) + self.printer = printer or OctaveCodePrinter() + + def routine(self, name, expr, argument_sequence, global_vars): + """Specialized Routine creation for Octave.""" + + # FIXME: this is probably general enough for other high-level + # languages, perhaps its the C/Fortran one that is specialized! 
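        # (added overview, not in the upstream source) The steps below mirror the
        # Julia generator: Idx labels become local loop variables, the remaining
        # free symbols become InputArguments, Equality left-hand sides become
        # named return values, and an explicit argument_sequence is honoured last.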
+ + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + # local variables + local_vars = {i.label for i in expressions.atoms(Idx)} + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + old_symbols = expressions.free_symbols - local_vars - global_vars + symbols = set() + for s in old_symbols: + if isinstance(s, Idx): + symbols.update(s.args[1].free_symbols) + elif not isinstance(s, Indexed): + symbols.add(s) + + # Octave supports multiple return values + return_vals = [] + for (i, expr) in enumerate(expressions): + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + symbol = out_arg + if isinstance(out_arg, Indexed): + symbol = out_arg.base.label + if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)): + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + return_vals.append(Result(expr, name=symbol, result_var=out_arg)) + if not expr.has(symbol): + # this is a pure output: remove from the symbols list, so + # it doesn't become an input. + symbols.remove(symbol) + + else: + # we have no name for this output + return_vals.append(Result(expr, name='out%d' % (i+1))) + + # setup input argument list + arg_list = [] + array_symbols = {} + for array in expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + arg_list.append(InputArgument(symbol)) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + new_args.append(InputArgument(symbol)) + arg_list = new_args + + return Routine(name, arg_list, return_vals, local_vars, global_vars) + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + if line == '': + code_lines.append("%\n") + else: + code_lines.append("%% %s\n" % line) + return code_lines + + def _preprocessor_statements(self, prefix): + return [] + + def _get_routine_opening(self, routine): + """Returns the opening statements of the routine.""" + code_list = [] + code_list.append("function ") + + # Outputs + outs = [] + for result in routine.results: + if isinstance(result, Result): + # Note: name not result_var; want `y` not `y(i)` for Indexed + s = self._get_symbol(result.name) + else: + raise CodeGenError("unexpected object in Routine results") + outs.append(s) + if len(outs) > 1: + code_list.append("[" + (", ".join(outs)) + "]") + else: + code_list.append("".join(outs)) + code_list.append(" = ") + + # 
Inputs + args = [] + for arg in routine.arguments: + if isinstance(arg, (OutputArgument, InOutArgument)): + raise CodeGenError("Octave: invalid argument of type %s" % + str(type(arg))) + if isinstance(arg, InputArgument): + args.append("%s" % self._get_symbol(arg.name)) + args = ", ".join(args) + code_list.append("%s(%s)\n" % (routine.name, args)) + code_list = [ "".join(code_list) ] + + return code_list + + def _declare_arguments(self, routine): + return [] + + def _declare_globals(self, routine): + if not routine.global_vars: + return [] + s = " ".join(sorted([self._get_symbol(g) for g in routine.global_vars])) + return ["global " + s + "\n"] + + def _declare_locals(self, routine): + return [] + + def _get_routine_ending(self, routine): + return ["end\n"] + + def _call_printer(self, routine): + declarations = [] + code_lines = [] + for result in routine.results: + if isinstance(result, Result): + assign_to = result.result_var + else: + raise CodeGenError("unexpected object in Routine results") + + constants, not_supported, oct_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to) + + for obj, v in sorted(constants, key=str): + declarations.append( + " %s = %s; %% constant\n" % (obj, v)) + for obj in sorted(not_supported, key=str): + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append( + " %% unsupported: %s\n" % (name)) + code_lines.append("%s\n" % (oct_expr)) + return declarations + code_lines + + def _indent_code(self, codelines): + return self._printer_method_with_settings( + 'indent_code', {"human": False, "strict": False}, codelines) + + def dump_m(self, routines, f, prefix, header=True, empty=True, inline=True): + # Note used to call self.dump_code() but we need more control for header + + code_lines = self._preprocessor_statements(prefix) + + for i, routine in enumerate(routines): + if i > 0: + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_opening(routine)) + if i == 0: + if routine.name != prefix: + raise ValueError('Octave function name should match prefix') + if header: + code_lines.append("%" + prefix.upper() + + " Autogenerated by SymPy\n") + code_lines.append(''.join(self._get_header())) + code_lines.extend(self._declare_arguments(routine)) + code_lines.extend(self._declare_globals(routine)) + code_lines.extend(self._declare_locals(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._call_printer(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_ending(routine)) + + code_lines = self._indent_code(''.join(code_lines)) + + if code_lines: + f.write(code_lines) + + dump_m.extension = code_extension # type: ignore + dump_m.__doc__ = CodeGen.dump_code.__doc__ + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_m] + +class RustCodeGen(CodeGen): + """Generator for Rust code. 
+ + The .write() method inherited from CodeGen will output a code file + .rs + + """ + + code_extension = "rs" + + def __init__(self, project="project", printer=None): + super().__init__(project=project) + self.printer = printer or RustCodePrinter() + + def routine(self, name, expr, argument_sequence, global_vars): + """Specialized Routine creation for Rust.""" + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + # local variables + local_vars = {i.label for i in expressions.atoms(Idx)} + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + symbols = expressions.free_symbols - local_vars - global_vars - expressions.atoms(Indexed) + + # Rust supports multiple return values + return_vals = [] + output_args = [] + for (i, expr) in enumerate(expressions): + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + symbol = out_arg + if isinstance(out_arg, Indexed): + dims = tuple([ (S.One, dim) for dim in out_arg.shape]) + symbol = out_arg.base.label + output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims)) + if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)): + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + return_vals.append(Result(expr, name=symbol, result_var=out_arg)) + if not expr.has(symbol): + # this is a pure output: remove from the symbols list, so + # it doesn't become an input. + symbols.remove(symbol) + + else: + # we have no name for this output + return_vals.append(Result(expr, name='out%d' % (i+1))) + + # setup input argument list + output_args.sort(key=lambda x: str(x.name)) + arg_list = list(output_args) + array_symbols = {} + for array in expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + arg_list.append(InputArgument(symbol)) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + new_args.append(InputArgument(symbol)) + arg_list = new_args + + return Routine(name, arg_list, return_vals, local_vars, global_vars) + + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + code_lines.append("/*\n") + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + code_lines.append((" *%s" % line.center(76)).rstrip() + "\n") + code_lines.append(" */\n") + return code_lines + + def get_prototype(self, routine): + """Returns a string for the function prototype of the routine. 
+ + If the routine has multiple result objects, an CodeGenError is + raised. + + See: https://en.wikipedia.org/wiki/Function_prototype + + """ + results = [i.get_datatype('Rust') for i in routine.results] + + if len(results) == 1: + rstype = " -> " + results[0] + elif len(routine.results) > 1: + rstype = " -> (" + ", ".join(results) + ")" + else: + rstype = "" + + type_args = [] + for arg in routine.arguments: + name = self.printer.doprint(arg.name) + if arg.dimensions or isinstance(arg, ResultBase): + type_args.append(("*%s" % name, arg.get_datatype('Rust'))) + else: + type_args.append((name, arg.get_datatype('Rust'))) + arguments = ", ".join([ "%s: %s" % t for t in type_args]) + return "fn %s(%s)%s" % (routine.name, arguments, rstype) + + def _preprocessor_statements(self, prefix): + code_lines = [] + # code_lines.append("use std::f64::consts::*;\n") + return code_lines + + def _get_routine_opening(self, routine): + prototype = self.get_prototype(routine) + return ["%s {\n" % prototype] + + def _declare_arguments(self, routine): + # arguments are declared in prototype + return [] + + def _declare_globals(self, routine): + # global variables are not explicitly declared within C functions + return [] + + def _declare_locals(self, routine): + # loop variables are declared in loop statement + return [] + + def _call_printer(self, routine): + + code_lines = [] + declarations = [] + returns = [] + + # Compose a list of symbols to be dereferenced in the function + # body. These are the arguments that were passed by a reference + # pointer, excluding arrays. + dereference = [] + for arg in routine.arguments: + if isinstance(arg, ResultBase) and not arg.dimensions: + dereference.append(arg.name) + + for result in routine.results: + if isinstance(result, Result): + assign_to = result.result_var + returns.append(str(result.result_var)) + else: + raise CodeGenError("unexpected object in Routine results") + + constants, not_supported, rs_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to) + + for name, value in sorted(constants, key=str): + declarations.append("const %s: f64 = %s;\n" % (name, value)) + + for obj in sorted(not_supported, key=str): + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append("// unsupported: %s\n" % (name)) + + code_lines.append("let %s\n" % rs_expr) + + if len(returns) > 1: + returns = ['(' + ', '.join(returns) + ')'] + + returns.append('\n') + + return declarations + code_lines + returns + + def _get_routine_ending(self, routine): + return ["}\n"] + + def dump_rs(self, routines, f, prefix, header=True, empty=True): + self.dump_code(routines, f, prefix, header, empty) + + dump_rs.extension = code_extension # type: ignore + dump_rs.__doc__ = CodeGen.dump_code.__doc__ + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_rs] + + + + +def get_code_generator(language, project=None, standard=None, printer = None): + if language == 'C': + if standard is None: + pass + elif standard.lower() == 'c89': + language = 'C89' + elif standard.lower() == 'c99': + language = 'C99' + CodeGenClass = {"C": CCodeGen, "C89": C89CodeGen, "C99": C99CodeGen, + "F95": FCodeGen, "JULIA": JuliaCodeGen, + "OCTAVE": OctaveCodeGen, + "RUST": RustCodeGen}.get(language.upper()) + if CodeGenClass is None: + raise ValueError("Language '%s' is not supported." 
% language) + return CodeGenClass(project, printer) + + +# +# Friendly functions +# + + +def codegen(name_expr, language=None, prefix=None, project="project", + to_files=False, header=True, empty=True, argument_sequence=None, + global_vars=None, standard=None, code_gen=None, printer=None): + """Generate source code for expressions in a given language. + + Parameters + ========== + + name_expr : tuple, or list of tuples + A single (name, expression) tuple or a list of (name, expression) + tuples. Each tuple corresponds to a routine. If the expression is + an equality (an instance of class Equality) the left hand side is + considered an output argument. If expression is an iterable, then + the routine will have multiple outputs. + + language : string, + A string that indicates the source code language. This is case + insensitive. Currently, 'C', 'F95' and 'Octave' are supported. + 'Octave' generates code compatible with both Octave and Matlab. + + prefix : string, optional + A prefix for the names of the files that contain the source code. + Language-dependent suffixes will be appended. If omitted, the name + of the first name_expr tuple is used. + + project : string, optional + A project name, used for making unique preprocessor instructions. + [default: "project"] + + to_files : bool, optional + When True, the code will be written to one or more files with the + given prefix, otherwise strings with the names and contents of + these files are returned. [default: False] + + header : bool, optional + When True, a header is written on top of each source file. + [default: True] + + empty : bool, optional + When True, empty lines are used to structure the code. + [default: True] + + argument_sequence : iterable, optional + Sequence of arguments for the routine in a preferred order. A + CodeGenError is raised if required arguments are missing. + Redundant arguments are used without warning. If omitted, + arguments will be ordered alphabetically, but with all input + arguments first, and then output or in-out arguments. + + global_vars : iterable, optional + Sequence of global variables used by the routine. Variables + listed here will not show up as function arguments. + + standard : string, optional + + code_gen : CodeGen instance, optional + An instance of a CodeGen subclass. Overrides ``language``. + + printer : Printer instance, optional + An instance of a Printer subclass. + + Examples + ======== + + >>> from sympy.utilities.codegen import codegen + >>> from sympy.abc import x, y, z + >>> [(c_name, c_code), (h_name, c_header)] = codegen( + ... ("f", x+y*z), "C89", "test", header=False, empty=False) + >>> print(c_name) + test.c + >>> print(c_code) + #include "test.h" + #include + double f(double x, double y, double z) { + double f_result; + f_result = x + y*z; + return f_result; + } + + >>> print(h_name) + test.h + >>> print(c_header) + #ifndef PROJECT__TEST__H + #define PROJECT__TEST__H + double f(double x, double y, double z); + #endif + + + Another example using Equality objects to give named outputs. Here the + filename (prefix) is taken from the first (name, expr) pair. + + >>> from sympy.abc import f, g + >>> from sympy import Eq + >>> [(c_name, c_code), (h_name, c_header)] = codegen( + ... [("myfcn", x + y), ("fcn2", [Eq(f, 2*x), Eq(g, y)])], + ... 
"C99", header=False, empty=False) + >>> print(c_name) + myfcn.c + >>> print(c_code) + #include "myfcn.h" + #include + double myfcn(double x, double y) { + double myfcn_result; + myfcn_result = x + y; + return myfcn_result; + } + void fcn2(double x, double y, double *f, double *g) { + (*f) = 2*x; + (*g) = y; + } + + + If the generated function(s) will be part of a larger project where various + global variables have been defined, the 'global_vars' option can be used + to remove the specified variables from the function signature + + >>> from sympy.utilities.codegen import codegen + >>> from sympy.abc import x, y, z + >>> [(f_name, f_code), header] = codegen( + ... ("f", x+y*z), "F95", header=False, empty=False, + ... argument_sequence=(x, y), global_vars=(z,)) + >>> print(f_code) + REAL*8 function f(x, y) + implicit none + REAL*8, intent(in) :: x + REAL*8, intent(in) :: y + f = x + y*z + end function + + + """ + + # Initialize the code generator. + if language is None: + if code_gen is None: + raise ValueError("Need either language or code_gen") + else: + if code_gen is not None: + raise ValueError("You cannot specify both language and code_gen.") + code_gen = get_code_generator(language, project, standard, printer) + + if isinstance(name_expr[0], str): + # single tuple is given, turn it into a singleton list with a tuple. + name_expr = [name_expr] + + if prefix is None: + prefix = name_expr[0][0] + + # Construct Routines appropriate for this code_gen from (name, expr) pairs. + routines = [] + for name, expr in name_expr: + routines.append(code_gen.routine(name, expr, argument_sequence, + global_vars)) + + # Write the code. + return code_gen.write(routines, prefix, to_files, header, empty) + + +def make_routine(name, expr, argument_sequence=None, + global_vars=None, language="F95"): + """A factory that makes an appropriate Routine from an expression. + + Parameters + ========== + + name : string + The name of this routine in the generated code. + + expr : expression or list/tuple of expressions + A SymPy expression that the Routine instance will represent. If + given a list or tuple of expressions, the routine will be + considered to have multiple return values and/or output arguments. + + argument_sequence : list or tuple, optional + List arguments for the routine in a preferred order. If omitted, + the results are language dependent, for example, alphabetical order + or in the same order as the given expressions. + + global_vars : iterable, optional + Sequence of global variables used by the routine. Variables + listed here will not show up as function arguments. + + language : string, optional + Specify a target language. The Routine itself should be + language-agnostic but the precise way one is created, error + checking, etc depend on the language. [default: "F95"]. + + Notes + ===== + + A decision about whether to use output arguments or return values is made + depending on both the language and the particular mathematical expressions. + For an expression of type Equality, the left hand side is typically made + into an OutputArgument (or perhaps an InOutArgument if appropriate). + Otherwise, typically, the calculated expression is made a return values of + the routine. 
+ + Examples + ======== + + >>> from sympy.utilities.codegen import make_routine + >>> from sympy.abc import x, y, f, g + >>> from sympy import Eq + >>> r = make_routine('test', [Eq(f, 2*x), Eq(g, x + y)]) + >>> [arg.result_var for arg in r.results] + [] + >>> [arg.name for arg in r.arguments] + [x, y, f, g] + >>> [arg.name for arg in r.result_variables] + [f, g] + >>> r.local_vars + set() + + Another more complicated example with a mixture of specified and + automatically-assigned names. Also has Matrix output. + + >>> from sympy import Matrix + >>> r = make_routine('fcn', [x*y, Eq(f, 1), Eq(g, x + g), Matrix([[x, 2]])]) + >>> [arg.result_var for arg in r.results] # doctest: +SKIP + [result_5397460570204848505] + >>> [arg.expr for arg in r.results] + [x*y] + >>> [arg.name for arg in r.arguments] # doctest: +SKIP + [x, y, f, g, out_8598435338387848786] + + We can examine the various arguments more closely: + + >>> from sympy.utilities.codegen import (InputArgument, OutputArgument, + ... InOutArgument) + >>> [a.name for a in r.arguments if isinstance(a, InputArgument)] + [x, y] + + >>> [a.name for a in r.arguments if isinstance(a, OutputArgument)] # doctest: +SKIP + [f, out_8598435338387848786] + >>> [a.expr for a in r.arguments if isinstance(a, OutputArgument)] + [1, Matrix([[x, 2]])] + + >>> [a.name for a in r.arguments if isinstance(a, InOutArgument)] + [g] + >>> [a.expr for a in r.arguments if isinstance(a, InOutArgument)] + [g + x] + + """ + + # initialize a new code generator + code_gen = get_code_generator(language) + + return code_gen.routine(name, expr, argument_sequence, global_vars) diff --git a/phivenv/Lib/site-packages/sympy/utilities/decorator.py b/phivenv/Lib/site-packages/sympy/utilities/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..82636c35e9f235a1d23aa7df5182418db3ef6b4f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/decorator.py @@ -0,0 +1,339 @@ +"""Useful utility decorators. """ + +from typing import TypeVar +import sys +import types +import inspect +from functools import wraps, update_wrapper + +from sympy.utilities.exceptions import sympy_deprecation_warning + + +T = TypeVar('T') +"""A generic type""" + + +def threaded_factory(func, use_add): + """A factory for ``threaded`` decorators. """ + from sympy.core import sympify + from sympy.matrices import MatrixBase + from sympy.utilities.iterables import iterable + + @wraps(func) + def threaded_func(expr, *args, **kwargs): + if isinstance(expr, MatrixBase): + return expr.applyfunc(lambda f: func(f, *args, **kwargs)) + elif iterable(expr): + try: + return expr.__class__([func(f, *args, **kwargs) for f in expr]) + except TypeError: + return expr + else: + expr = sympify(expr) + + if use_add and expr.is_Add: + return expr.__class__(*[ func(f, *args, **kwargs) for f in expr.args ]) + elif expr.is_Relational: + return expr.__class__(func(expr.lhs, *args, **kwargs), + func(expr.rhs, *args, **kwargs)) + else: + return func(expr, *args, **kwargs) + + return threaded_func + + +def threaded(func): + """Apply ``func`` to sub--elements of an object, including :class:`~.Add`. + + This decorator is intended to make it uniformly possible to apply a + function to all elements of composite objects, e.g. matrices, lists, tuples + and other iterable containers, or just expressions. + + This version of :func:`threaded` decorator allows threading over + elements of :class:`~.Add` class. If this behavior is not desirable + use :func:`xthreaded` decorator. 
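+
+    For example, a minimal sketch (``add_one`` is only an illustrative
+    name, not something defined in this module):
+
+    >>> from sympy.utilities.decorator import threaded
+    >>> from sympy.abc import x, y
+    >>> @threaded
+    ... def add_one(expr):
+    ...     return expr + 1
+    >>> add_one([x, y])
+    [x + 1, y + 1]
+    >>> add_one(x + y)
+    x + y + 2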
+ + Functions using this decorator must have the following signature:: + + @threaded + def function(expr, *args, **kwargs): + + """ + return threaded_factory(func, True) + + +def xthreaded(func): + """Apply ``func`` to sub--elements of an object, excluding :class:`~.Add`. + + This decorator is intended to make it uniformly possible to apply a + function to all elements of composite objects, e.g. matrices, lists, tuples + and other iterable containers, or just expressions. + + This version of :func:`threaded` decorator disallows threading over + elements of :class:`~.Add` class. If this behavior is not desirable + use :func:`threaded` decorator. + + Functions using this decorator must have the following signature:: + + @xthreaded + def function(expr, *args, **kwargs): + + """ + return threaded_factory(func, False) + + +def conserve_mpmath_dps(func): + """After the function finishes, resets the value of ``mpmath.mp.dps`` to + the value it had before the function was run.""" + import mpmath + + def func_wrapper(*args, **kwargs): + dps = mpmath.mp.dps + try: + return func(*args, **kwargs) + finally: + mpmath.mp.dps = dps + + func_wrapper = update_wrapper(func_wrapper, func) + return func_wrapper + + +class no_attrs_in_subclass: + """Don't 'inherit' certain attributes from a base class + + >>> from sympy.utilities.decorator import no_attrs_in_subclass + + >>> class A(object): + ... x = 'test' + + >>> A.x = no_attrs_in_subclass(A, A.x) + + >>> class B(A): + ... pass + + >>> hasattr(A, 'x') + True + >>> hasattr(B, 'x') + False + + """ + def __init__(self, cls, f): + self.cls = cls + self.f = f + + def __get__(self, instance, owner=None): + if owner == self.cls: + if hasattr(self.f, '__get__'): + return self.f.__get__(instance, owner) + return self.f + raise AttributeError + + +def doctest_depends_on(exe=None, modules=None, disable_viewers=None, + python_version=None, ground_types=None): + """ + Adds metadata about the dependencies which need to be met for doctesting + the docstrings of the decorated objects. + + ``exe`` should be a list of executables + + ``modules`` should be a list of modules + + ``disable_viewers`` should be a list of viewers for :func:`~sympy.printing.preview.preview` to disable + + ``python_version`` should be the minimum Python version required, as a tuple + (like ``(3, 0)``) + """ + dependencies = {} + if exe is not None: + dependencies['executables'] = exe + if modules is not None: + dependencies['modules'] = modules + if disable_viewers is not None: + dependencies['disable_viewers'] = disable_viewers + if python_version is not None: + dependencies['python_version'] = python_version + if ground_types is not None: + dependencies['ground_types'] = ground_types + + def skiptests(): + from sympy.testing.runtests import DependencyError, SymPyDocTests, PyTestReporter # lazy import + r = PyTestReporter() + t = SymPyDocTests(r, None) + try: + t._check_dependencies(**dependencies) + except DependencyError: + return True # Skip doctests + else: + return False # Run doctests + + def depends_on_deco(fn): + fn._doctest_depends_on = dependencies + fn.__doctest_skip__ = skiptests + + if inspect.isclass(fn): + fn._doctest_depdends_on = no_attrs_in_subclass( + fn, fn._doctest_depends_on) + fn.__doctest_skip__ = no_attrs_in_subclass( + fn, fn.__doctest_skip__) + return fn + + return depends_on_deco + + +def public(obj: T) -> T: + """ + Append ``obj``'s name to global ``__all__`` variable (call site). 
+ + By using this decorator on functions or classes you achieve the same goal + as by filling ``__all__`` variables manually, you just do not have to repeat + yourself (object's name). You also know if object is public at definition + site, not at some random location (where ``__all__`` was set). + + Note that in multiple decorator setup (in almost all cases) ``@public`` + decorator must be applied before any other decorators, because it relies + on the pointer to object's global namespace. If you apply other decorators + first, ``@public`` may end up modifying the wrong namespace. + + Examples + ======== + + >>> from sympy.utilities.decorator import public + + >>> __all__ # noqa: F821 + Traceback (most recent call last): + ... + NameError: name '__all__' is not defined + + >>> @public + ... def some_function(): + ... pass + + >>> __all__ # noqa: F821 + ['some_function'] + + """ + if isinstance(obj, types.FunctionType): + ns = obj.__globals__ + name = obj.__name__ + elif isinstance(obj, (type(type), type)): + ns = sys.modules[obj.__module__].__dict__ + name = obj.__name__ + else: + raise TypeError("expected a function or a class, got %s" % obj) + + if "__all__" not in ns: + ns["__all__"] = [name] + else: + ns["__all__"].append(name) + + return obj + + +def memoize_property(propfunc): + """Property decorator that caches the value of potentially expensive + ``propfunc`` after the first evaluation. The cached value is stored in + the corresponding property name with an attached underscore.""" + attrname = '_' + propfunc.__name__ + sentinel = object() + + @wraps(propfunc) + def accessor(self): + val = getattr(self, attrname, sentinel) + if val is sentinel: + val = propfunc(self) + setattr(self, attrname, val) + return val + + return property(accessor) + + +def deprecated(message, *, deprecated_since_version, + active_deprecations_target, stacklevel=3): + ''' + Mark a function as deprecated. + + This decorator should be used if an entire function or class is + deprecated. If only a certain functionality is deprecated, you should use + :func:`~.warns_deprecated_sympy` directly. This decorator is just a + convenience. There is no functional difference between using this + decorator and calling ``warns_deprecated_sympy()`` at the top of the + function. + + The decorator takes the same arguments as + :func:`~.warns_deprecated_sympy`. See its + documentation for details on what the keywords to this decorator do. + + See the :ref:`deprecation-policy` document for details on when and how + things should be deprecated in SymPy. + + Examples + ======== + + >>> from sympy.utilities.decorator import deprecated + >>> from sympy import simplify + >>> @deprecated("""\ + ... The simplify_this(expr) function is deprecated. Use simplify(expr) + ... instead.""", deprecated_since_version="1.1", + ... active_deprecations_target='simplify-this-deprecation') + ... def simplify_this(expr): + ... """ + ... Simplify ``expr``. + ... + ... .. deprecated:: 1.1 + ... + ... The ``simplify_this`` function is deprecated. Use :func:`simplify` + ... instead. See its documentation for more information. See + ... :ref:`simplify-this-deprecation` for details. + ... + ... """ + ... return simplify(expr) + >>> from sympy.abc import x + >>> simplify_this(x*(x + 1) - x**2) # doctest: +SKIP + :1: SymPyDeprecationWarning: + + The simplify_this(expr) function is deprecated. Use simplify(expr) + instead. + + See https://docs.sympy.org/latest/explanation/active-deprecations.html#simplify-this-deprecation + for details. 
+ + This has been deprecated since SymPy version 1.1. It + will be removed in a future version of SymPy. + + simplify_this(x) + x + + See Also + ======== + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.exceptions.ignore_warnings + sympy.testing.pytest.warns_deprecated_sympy + + ''' + decorator_kwargs = {"deprecated_since_version": deprecated_since_version, + "active_deprecations_target": active_deprecations_target} + def deprecated_decorator(wrapped): + if hasattr(wrapped, '__mro__'): # wrapped is actually a class + class wrapper(wrapped): + __doc__ = wrapped.__doc__ + __module__ = wrapped.__module__ + _sympy_deprecated_func = wrapped + if '__new__' in wrapped.__dict__: + def __new__(cls, *args, **kwargs): + sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel) + return super().__new__(cls, *args, **kwargs) + else: + def __init__(self, *args, **kwargs): + sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel) + super().__init__(*args, **kwargs) + wrapper.__name__ = wrapped.__name__ + else: + @wraps(wrapped) + def wrapper(*args, **kwargs): + sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel) + return wrapped(*args, **kwargs) + wrapper._sympy_deprecated_func = wrapped + return wrapper + return deprecated_decorator diff --git a/phivenv/Lib/site-packages/sympy/utilities/enumerative.py b/phivenv/Lib/site-packages/sympy/utilities/enumerative.py new file mode 100644 index 0000000000000000000000000000000000000000..dcdabf064ad6291282b5905e7aed0ba9eb412d1a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/enumerative.py @@ -0,0 +1,1155 @@ +""" +Algorithms and classes to support enumerative combinatorics. + +Currently just multiset partitions, but more could be added. + +Terminology (following Knuth, algorithm 7.1.2.5M TAOCP) +*multiset* aaabbcccc has a *partition* aaabc | bccc + +The submultisets, aaabc and bccc of the partition are called +*parts*, or sometimes *vectors*. (Knuth notes that multiset +partitions can be thought of as partitions of vectors of integers, +where the ith element of the vector gives the multiplicity of +element i.) + +The values a, b and c are *components* of the multiset. These +correspond to elements of a set, but in a multiset can be present +with a multiplicity greater than 1. + +The algorithm deserves some explanation. + +Think of the part aaabc from the multiset above. If we impose an +ordering on the components of the multiset, we can represent a part +with a vector, in which the value of the first element of the vector +corresponds to the multiplicity of the first component in that +part. Thus, aaabc can be represented by the vector [3, 1, 1]. We +can also define an ordering on parts, based on the lexicographic +ordering of the vector (leftmost vector element, i.e., the element +with the smallest component number, is the most significant), so +that [3, 1, 1] > [3, 1, 0] and [3, 1, 1] > [2, 1, 4]. The ordering +on parts can be extended to an ordering on partitions: First, sort +the parts in each partition, left-to-right in decreasing order. Then +partition A is greater than partition B if A's leftmost/greatest +part is greater than B's leftmost part. If the leftmost parts are +equal, compare the second parts, and so on. + +In this ordering, the greatest partition of a given multiset has only +one part. The least partition is the one in which the components +are spread out, one per part. 
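+
+For example, the multiset aabb (multiplicities [2, 2]) has nine
+partitions: the greatest is the single part aabb, represented by the
+vector [2, 2], and the least is a | a | b | b, four singleton parts.
+This count can be confirmed with ``count_partitions`` (defined below):
+
+>>> from sympy.utilities.enumerative import MultisetPartitionTraverser
+>>> MultisetPartitionTraverser().count_partitions([2, 2])
+9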
+ +The enumeration algorithms in this file yield the partitions of the +argument multiset in decreasing order. The main data structure is a +stack of parts, corresponding to the current partition. An +important invariant is that the parts on the stack are themselves in +decreasing order. This data structure is decremented to find the +next smaller partition. Most often, decrementing the partition will +only involve adjustments to the smallest parts at the top of the +stack, much as adjacent integers *usually* differ only in their last +few digits. + +Knuth's algorithm uses two main operations on parts: + +Decrement - change the part so that it is smaller in the + (vector) lexicographic order, but reduced by the smallest amount possible. + For example, if the multiset has vector [5, + 3, 1], and the bottom/greatest part is [4, 2, 1], this part would + decrement to [4, 2, 0], while [4, 0, 0] would decrement to [3, 3, + 1]. A singleton part is never decremented -- [1, 0, 0] is not + decremented to [0, 3, 1]. Instead, the decrement operator needs + to fail for this case. In Knuth's pseudocode, the decrement + operator is step m5. + +Spread unallocated multiplicity - Once a part has been decremented, + it cannot be the rightmost part in the partition. There is some + multiplicity that has not been allocated, and new parts must be + created above it in the stack to use up this multiplicity. To + maintain the invariant that the parts on the stack are in + decreasing order, these new parts must be less than or equal to + the decremented part. + For example, if the multiset is [5, 3, 1], and its most + significant part has just been decremented to [5, 3, 0], the + spread operation will add a new part so that the stack becomes + [[5, 3, 0], [0, 0, 1]]. If the most significant part (for the + same multiset) has been decremented to [2, 0, 0] the stack becomes + [[2, 0, 0], [2, 0, 0], [1, 3, 1]]. In the pseudocode, the spread + operation for one part is step m2. The complete spread operation + is a loop of steps m2 and m3. + +In order to facilitate the spread operation, Knuth stores, for each +component of each part, not just the multiplicity of that component +in the part, but also the total multiplicity available for this +component in this part or any lesser part above it on the stack. + +One added twist is that Knuth does not represent the part vectors as +arrays. Instead, he uses a sparse representation, in which a +component of a part is represented as a component number (c), plus +the multiplicity of the component in that part (v) as well as the +total multiplicity available for that component (u). This saves +time that would be spent skipping over zeros. + +""" + +class PartComponent: + """Internal class used in support of the multiset partitions + enumerators and the associated visitor functions. + + Represents one component of one part of the current partition. + + A stack of these, plus an auxiliary frame array, f, represents a + partition of the multiset. + + Knuth's pseudocode makes c, u, and v separate arrays. + """ + + __slots__ = ('c', 'u', 'v') + + def __init__(self): + self.c = 0 # Component number + self.u = 0 # The as yet unpartitioned amount in component c + # *before* it is allocated by this triple + self.v = 0 # Amount of c component in the current part + # (v<=u). An invariant of the representation is + # that the next higher triple for this component + # (if there is one) will have a value of u-v in + # its u attribute. 
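+
+        # For example: with components ordered (a, b) and the multiset
+        # 'aaabb' (multiplicities [3, 2]) still entirely unallocated, the
+        # bottom part 'aab' is stored as the two triples
+        # (c=0, u=3, v=2) and (c=1, u=2, v=1).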
+ + def __repr__(self): + "for debug/algorithm animation purposes" + return 'c:%d u:%d v:%d' % (self.c, self.u, self.v) + + def __eq__(self, other): + """Define value oriented equality, which is useful for testers""" + return (isinstance(other, self.__class__) and + self.c == other.c and + self.u == other.u and + self.v == other.v) + + def __ne__(self, other): + """Defined for consistency with __eq__""" + return not self == other + + +# This function tries to be a faithful implementation of algorithm +# 7.1.2.5M in Volume 4A, Combinatoral Algorithms, Part 1, of The Art +# of Computer Programming, by Donald Knuth. This includes using +# (mostly) the same variable names, etc. This makes for rather +# low-level Python. + +# Changes from Knuth's pseudocode include +# - use PartComponent struct/object instead of 3 arrays +# - make the function a generator +# - map (with some difficulty) the GOTOs to Python control structures. +# - Knuth uses 1-based numbering for components, this code is 0-based +# - renamed variable l to lpart. +# - flag variable x takes on values True/False instead of 1/0 +# +def multiset_partitions_taocp(multiplicities): + """Enumerates partitions of a multiset. + + Parameters + ========== + + multiplicities + list of integer multiplicities of the components of the multiset. + + Yields + ====== + + state + Internal data structure which encodes a particular partition. + This output is then usually processed by a visitor function + which combines the information from this data structure with + the components themselves to produce an actual partition. + + Unless they wish to create their own visitor function, users will + have little need to look inside this data structure. But, for + reference, it is a 3-element list with components: + + f + is a frame array, which is used to divide pstack into parts. + + lpart + points to the base of the topmost part. + + pstack + is an array of PartComponent objects. + + The ``state`` output offers a peek into the internal data + structures of the enumeration function. The client should + treat this as read-only; any modification of the data + structure will cause unpredictable (and almost certainly + incorrect) results. Also, the components of ``state`` are + modified in place at each iteration. Hence, the visitor must + be called at each loop iteration. Accumulating the ``state`` + instances and processing them later will not work. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import multiset_partitions_taocp + >>> # variables components and multiplicities represent the multiset 'abb' + >>> components = 'ab' + >>> multiplicities = [1, 2] + >>> states = multiset_partitions_taocp(multiplicities) + >>> list(list_visitor(state, components) for state in states) + [[['a', 'b', 'b']], + [['a', 'b'], ['b']], + [['a'], ['b', 'b']], + [['a'], ['b'], ['b']]] + + See Also + ======== + + sympy.utilities.iterables.multiset_partitions: Takes a multiset + as input and directly yields multiset partitions. It + dispatches to a number of functions, including this one, for + implementation. Most users will find it more convenient to + use than multiset_partitions_taocp. + + """ + + # Important variables. + # m is the number of components, i.e., number of distinct elements + m = len(multiplicities) + # n is the cardinality, total number of elements whether or not distinct + n = sum(multiplicities) + + # The main data structure, f segments pstack into parts. 
See + # list_visitor() for example code indicating how this internal + # state corresponds to a partition. + + # Note: allocation of space for stack is conservative. Knuth's + # exercise 7.2.1.5.68 gives some indication of how to tighten this + # bound, but this is not implemented. + pstack = [PartComponent() for i in range(n * m + 1)] + f = [0] * (n + 1) + + # Step M1 in Knuth (Initialize) + # Initial state - entire multiset in one part. + for j in range(m): + ps = pstack[j] + ps.c = j + ps.u = multiplicities[j] + ps.v = multiplicities[j] + + # Other variables + f[0] = 0 + a = 0 + lpart = 0 + f[1] = m + b = m # in general, current stack frame is from a to b - 1 + + while True: + while True: + # Step M2 (Subtract v from u) + k = b + x = False + for j in range(a, b): + pstack[k].u = pstack[j].u - pstack[j].v + if pstack[k].u == 0: + x = True + elif not x: + pstack[k].c = pstack[j].c + pstack[k].v = min(pstack[j].v, pstack[k].u) + x = pstack[k].u < pstack[j].v + k = k + 1 + else: # x is True + pstack[k].c = pstack[j].c + pstack[k].v = pstack[k].u + k = k + 1 + # Note: x is True iff v has changed + + # Step M3 (Push if nonzero.) + if k > b: + a = b + b = k + lpart = lpart + 1 + f[lpart + 1] = b + # Return to M2 + else: + break # Continue to M4 + + # M4 Visit a partition + state = [f, lpart, pstack] + yield state + + # M5 (Decrease v) + while True: + j = b-1 + while (pstack[j].v == 0): + j = j - 1 + if j == a and pstack[j].v == 1: + # M6 (Backtrack) + if lpart == 0: + return + lpart = lpart - 1 + b = a + a = f[lpart] + # Return to M5 + else: + pstack[j].v = pstack[j].v - 1 + for k in range(j + 1, b): + pstack[k].v = pstack[k].u + break # GOTO M2 + +# --------------- Visitor functions for multiset partitions --------------- +# A visitor takes the partition state generated by +# multiset_partitions_taocp or other enumerator, and produces useful +# output (such as the actual partition). + + +def factoring_visitor(state, primes): + """Use with multiset_partitions_taocp to enumerate the ways a + number can be expressed as a product of factors. For this usage, + the exponents of the prime factors of a number are arguments to + the partition enumerator, while the corresponding prime factors + are input here. + + Examples + ======== + + To enumerate the factorings of a number we can think of the elements of the + partition as being the prime factors and the multiplicities as being their + exponents. + + >>> from sympy.utilities.enumerative import factoring_visitor + >>> from sympy.utilities.enumerative import multiset_partitions_taocp + >>> from sympy import factorint + >>> primes, multiplicities = zip(*factorint(24).items()) + >>> primes + (2, 3) + >>> multiplicities + (3, 1) + >>> states = multiset_partitions_taocp(multiplicities) + >>> list(factoring_visitor(state, primes) for state in states) + [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3], [6, 2, 2], [2, 2, 2, 3]] + """ + f, lpart, pstack = state + factoring = [] + for i in range(lpart + 1): + factor = 1 + for ps in pstack[f[i]: f[i + 1]]: + if ps.v > 0: + factor *= primes[ps.c] ** ps.v + factoring.append(factor) + return factoring + + +def list_visitor(state, components): + """Return a list of lists to represent the partition. 
+ + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import multiset_partitions_taocp + >>> states = multiset_partitions_taocp([1, 2, 1]) + >>> s = next(states) + >>> list_visitor(s, 'abc') # for multiset 'a b b c' + [['a', 'b', 'b', 'c']] + >>> s = next(states) + >>> list_visitor(s, [1, 2, 3]) # for multiset '1 2 2 3 + [[1, 2, 2], [3]] + """ + f, lpart, pstack = state + + partition = [] + for i in range(lpart+1): + part = [] + for ps in pstack[f[i]:f[i+1]]: + if ps.v > 0: + part.extend([components[ps.c]] * ps.v) + partition.append(part) + + return partition + + +class MultisetPartitionTraverser(): + """ + Has methods to ``enumerate`` and ``count`` the partitions of a multiset. + + This implements a refactored and extended version of Knuth's algorithm + 7.1.2.5M [AOCP]_." + + The enumeration methods of this class are generators and return + data structures which can be interpreted by the same visitor + functions used for the output of ``multiset_partitions_taocp``. + + Examples + ======== + + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> m.count_partitions([4,4,4,2]) + 127750 + >>> m.count_partitions([3,3,3]) + 686 + + See Also + ======== + + multiset_partitions_taocp + sympy.utilities.iterables.multiset_partitions + + References + ========== + + .. [AOCP] Algorithm 7.1.2.5M in Volume 4A, Combinatoral Algorithms, + Part 1, of The Art of Computer Programming, by Donald Knuth. + + .. [Factorisatio] On a Problem of Oppenheim concerning + "Factorisatio Numerorum" E. R. Canfield, Paul Erdos, Carl + Pomerance, JOURNAL OF NUMBER THEORY, Vol. 17, No. 1. August + 1983. See section 7 for a description of an algorithm + similar to Knuth's. + + .. [Yorgey] Generating Multiset Partitions, Brent Yorgey, The + Monad.Reader, Issue 8, September 2007. + + """ + + def __init__(self): + self.debug = False + # TRACING variables. These are useful for gathering + # statistics on the algorithm itself, but have no particular + # benefit to a user of the code. + self.k1 = 0 + self.k2 = 0 + self.p1 = 0 + self.pstack = None + self.f = None + self.lpart = 0 + self.discarded = 0 + # dp_stack is list of lists of (part_key, start_count) pairs + self.dp_stack = [] + + # dp_map is map part_key-> count, where count represents the + # number of multiset which are descendants of a part with this + # key, **or any of its decrements** + + # Thus, when we find a part in the map, we add its count + # value to the running total, cut off the enumeration, and + # backtrack + + if not hasattr(self, 'dp_map'): + self.dp_map = {} + + def db_trace(self, msg): + """Useful for understanding/debugging the algorithms. Not + generally activated in end-user code.""" + if self.debug: + # XXX: animation_visitor is undefined... Clearly this does not + # work and was not tested. Previous code in comments below. + raise RuntimeError + #letters = 'abcdefghijklmnopqrstuvwxyz' + #state = [self.f, self.lpart, self.pstack] + #print("DBG:", msg, + # ["".join(part) for part in list_visitor(state, letters)], + # animation_visitor(state)) + + # + # Helper methods for enumeration + # + def _initialize_enumeration(self, multiplicities): + """Allocates and initializes the partition stack. 
+ + This is called from the enumeration/counting routines, so + there is no need to call it separately.""" + + num_components = len(multiplicities) + # cardinality is the total number of elements, whether or not distinct + cardinality = sum(multiplicities) + + # pstack is the partition stack, which is segmented by + # f into parts. + self.pstack = [PartComponent() for i in + range(num_components * cardinality + 1)] + self.f = [0] * (cardinality + 1) + + # Initial state - entire multiset in one part. + for j in range(num_components): + ps = self.pstack[j] + ps.c = j + ps.u = multiplicities[j] + ps.v = multiplicities[j] + + self.f[0] = 0 + self.f[1] = num_components + self.lpart = 0 + + # The decrement_part() method corresponds to step M5 in Knuth's + # algorithm. This is the base version for enum_all(). Modified + # versions of this method are needed if we want to restrict + # sizes of the partitions produced. + def decrement_part(self, part): + """Decrements part (a subrange of pstack), if possible, returning + True iff the part was successfully decremented. + + If you think of the v values in the part as a multi-digit + integer (least significant digit on the right) this is + basically decrementing that integer, but with the extra + constraint that the leftmost digit cannot be decremented to 0. + + Parameters + ========== + + part + The part, represented as a list of PartComponent objects, + which is to be decremented. + + """ + plen = len(part) + for j in range(plen - 1, -1, -1): + if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0: + # found val to decrement + part[j].v -= 1 + # Reset trailing parts back to maximum + for k in range(j + 1, plen): + part[k].v = part[k].u + return True + return False + + # Version to allow number of parts to be bounded from above. + # Corresponds to (a modified) step M5. + def decrement_part_small(self, part, ub): + """Decrements part (a subrange of pstack), if possible, returning + True iff the part was successfully decremented. + + Parameters + ========== + + part + part to be decremented (topmost part on the stack) + + ub + the maximum number of parts allowed in a partition + returned by the calling traversal. + + Notes + ===== + + The goal of this modification of the ordinary decrement method + is to fail (meaning that the subtree rooted at this part is to + be skipped) when it can be proved that this part can only have + child partitions which are larger than allowed by ``ub``. If a + decision is made to fail, it must be accurate, otherwise the + enumeration will miss some partitions. But, it is OK not to + capture all the possible failures -- if a part is passed that + should not be, the resulting too-large partitions are filtered + by the enumeration one level up. However, as is usual in + constrained enumerations, failing early is advantageous. + + The tests used by this method catch the most common cases, + although this implementation is by no means the last word on + this problem. The tests include: + + 1) ``lpart`` must be less than ``ub`` by at least 2. This is because + once a part has been decremented, the partition + will gain at least one child in the spread step. + + 2) If the leading component of the part is about to be + decremented, check for how many parts will be added in + order to use up the unallocated multiplicity in that + leading component, and fail if this number is greater than + allowed by ``ub``. (See code for the exact expression.) This + test is given in the answer to Knuth's problem 7.2.1.5.69. 
+ + 3) If there is *exactly* enough room to expand the leading + component by the above test, check the next component (if + it exists) once decrementing has finished. If this has + ``v == 0``, this next component will push the expansion over the + limit by 1, so fail. + """ + if self.lpart >= ub - 1: + self.p1 += 1 # increment to keep track of usefulness of tests + return False + plen = len(part) + for j in range(plen - 1, -1, -1): + # Knuth's mod, (answer to problem 7.2.1.5.69) + if j == 0 and (part[0].v - 1)*(ub - self.lpart) < part[0].u: + self.k1 += 1 + return False + + if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0: + # found val to decrement + part[j].v -= 1 + # Reset trailing parts back to maximum + for k in range(j + 1, plen): + part[k].v = part[k].u + + # Have now decremented part, but are we doomed to + # failure when it is expanded? Check one oddball case + # that turns out to be surprisingly common - exactly + # enough room to expand the leading component, but no + # room for the second component, which has v=0. + if (plen > 1 and part[1].v == 0 and + (part[0].u - part[0].v) == + ((ub - self.lpart - 1) * part[0].v)): + self.k2 += 1 + self.db_trace("Decrement fails test 3") + return False + return True + return False + + def decrement_part_large(self, part, amt, lb): + """Decrements part, while respecting size constraint. + + A part can have no children which are of sufficient size (as + indicated by ``lb``) unless that part has sufficient + unallocated multiplicity. When enforcing the size constraint, + this method will decrement the part (if necessary) by an + amount needed to ensure sufficient unallocated multiplicity. + + Returns True iff the part was successfully decremented. + + Parameters + ========== + + part + part to be decremented (topmost part on the stack) + + amt + Can only take values 0 or 1. A value of 1 means that the + part must be decremented, and then the size constraint is + enforced. A value of 0 means just to enforce the ``lb`` + size constraint. + + lb + The partitions produced by the calling enumeration must + have more parts than this value. + + """ + + if amt == 1: + # In this case we always need to decrement, *before* + # enforcing the "sufficient unallocated multiplicity" + # constraint. Easiest for this is just to call the + # regular decrement method. + if not self.decrement_part(part): + return False + + # Next, perform any needed additional decrementing to respect + # "sufficient unallocated multiplicity" (or fail if this is + # not possible). + min_unalloc = lb - self.lpart + if min_unalloc <= 0: + return True + total_mult = sum(pc.u for pc in part) + total_alloc = sum(pc.v for pc in part) + if total_mult <= min_unalloc: + return False + + deficit = min_unalloc - (total_mult - total_alloc) + if deficit <= 0: + return True + + for i in range(len(part) - 1, -1, -1): + if i == 0: + if part[0].v > deficit: + part[0].v -= deficit + return True + else: + return False # This shouldn't happen, due to above check + else: + if part[i].v >= deficit: + part[i].v -= deficit + return True + else: + deficit -= part[i].v + part[i].v = 0 + + def decrement_part_range(self, part, lb, ub): + """Decrements part (a subrange of pstack), if possible, returning + True iff the part was successfully decremented. + + Parameters + ========== + + part + part to be decremented (topmost part on the stack) + + ub + the maximum number of parts allowed in a partition + returned by the calling traversal. 
+ + lb + The partitions produced by the calling enumeration must + have more parts than this value. + + Notes + ===== + + Combines the constraints of _small and _large decrement + methods. If returns success, part has been decremented at + least once, but perhaps by quite a bit more if needed to meet + the lb constraint. + """ + + # Constraint in the range case is just enforcing both the + # constraints from _small and _large cases. Note the 0 as the + # second argument to the _large call -- this is the signal to + # decrement only as needed to for constraint enforcement. The + # short circuiting and left-to-right order of the 'and' + # operator is important for this to work correctly. + return self.decrement_part_small(part, ub) and \ + self.decrement_part_large(part, 0, lb) + + def spread_part_multiplicity(self): + """Returns True if a new part has been created, and + adjusts pstack, f and lpart as needed. + + Notes + ===== + + Spreads unallocated multiplicity from the current top part + into a new part created above the current on the stack. This + new part is constrained to be less than or equal to the old in + terms of the part ordering. + + This call does nothing (and returns False) if the current top + part has no unallocated multiplicity. + + """ + j = self.f[self.lpart] # base of current top part + k = self.f[self.lpart + 1] # ub of current; potential base of next + base = k # save for later comparison + + changed = False # Set to true when the new part (so far) is + # strictly less than (as opposed to less than + # or equal) to the old. + for j in range(self.f[self.lpart], self.f[self.lpart + 1]): + self.pstack[k].u = self.pstack[j].u - self.pstack[j].v + if self.pstack[k].u == 0: + changed = True + else: + self.pstack[k].c = self.pstack[j].c + if changed: # Put all available multiplicity in this part + self.pstack[k].v = self.pstack[k].u + else: # Still maintaining ordering constraint + if self.pstack[k].u < self.pstack[j].v: + self.pstack[k].v = self.pstack[k].u + changed = True + else: + self.pstack[k].v = self.pstack[j].v + k = k + 1 + if k > base: + # Adjust for the new part on stack + self.lpart = self.lpart + 1 + self.f[self.lpart + 1] = k + return True + return False + + def top_part(self): + """Return current top part on the stack, as a slice of pstack. + + """ + return self.pstack[self.f[self.lpart]:self.f[self.lpart + 1]] + + # Same interface and functionality as multiset_partitions_taocp(), + # but some might find this refactored version easier to follow. + def enum_all(self, multiplicities): + """Enumerate the partitions of a multiset. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_all([2,2]) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a', 'b', 'b']], + [['a', 'a', 'b'], ['b']], + [['a', 'a'], ['b', 'b']], + [['a', 'a'], ['b'], ['b']], + [['a', 'b', 'b'], ['a']], + [['a', 'b'], ['a', 'b']], + [['a', 'b'], ['a'], ['b']], + [['a'], ['a'], ['b', 'b']], + [['a'], ['a'], ['b'], ['b']]] + + See Also + ======== + + multiset_partitions_taocp: + which provides the same result as this method, but is + about twice as fast. Hence, enum_all is primarily useful + for testing. Also see the function for a discussion of + states and visitors. 
+ + """ + self._initialize_enumeration(multiplicities) + while True: + while self.spread_part_multiplicity(): + pass + + # M4 Visit a partition + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part(self.top_part()): + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + + def enum_small(self, multiplicities, ub): + """Enumerate multiset partitions with no more than ``ub`` parts. + + Equivalent to enum_range(multiplicities, 0, ub) + + Parameters + ========== + + multiplicities + list of multiplicities of the components of the multiset. + + ub + Maximum number of parts + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_small([2,2], 2) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a', 'b', 'b']], + [['a', 'a', 'b'], ['b']], + [['a', 'a'], ['b', 'b']], + [['a', 'b', 'b'], ['a']], + [['a', 'b'], ['a', 'b']]] + + The implementation is based, in part, on the answer given to + exercise 69, in Knuth [AOCP]_. + + See Also + ======== + + enum_all, enum_large, enum_range + + """ + + # Keep track of iterations which do not yield a partition. + # Clearly, we would like to keep this number small. + self.discarded = 0 + if ub <= 0: + return + self._initialize_enumeration(multiplicities) + while True: + while self.spread_part_multiplicity(): + self.db_trace('spread 1') + if self.lpart >= ub: + self.discarded += 1 + self.db_trace(' Discarding') + self.lpart = ub - 2 + break + else: + # M4 Visit a partition + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part_small(self.top_part(), ub): + self.db_trace("Failed decrement, going to backtrack") + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + self.db_trace("Backtracked to") + self.db_trace("decrement ok, about to expand") + + def enum_large(self, multiplicities, lb): + """Enumerate the partitions of a multiset with lb < num(parts) + + Equivalent to enum_range(multiplicities, lb, sum(multiplicities)) + + Parameters + ========== + + multiplicities + list of multiplicities of the components of the multiset. + + lb + Number of parts in the partition must be greater than + this lower bound. 
+ + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_large([2,2], 2) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a'], ['b'], ['b']], + [['a', 'b'], ['a'], ['b']], + [['a'], ['a'], ['b', 'b']], + [['a'], ['a'], ['b'], ['b']]] + + See Also + ======== + + enum_all, enum_small, enum_range + + """ + self.discarded = 0 + if lb >= sum(multiplicities): + return + self._initialize_enumeration(multiplicities) + self.decrement_part_large(self.top_part(), 0, lb) + while True: + good_partition = True + while self.spread_part_multiplicity(): + if not self.decrement_part_large(self.top_part(), 0, lb): + # Failure here should be rare/impossible + self.discarded += 1 + good_partition = False + break + + # M4 Visit a partition + if good_partition: + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part_large(self.top_part(), 1, lb): + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + + def enum_range(self, multiplicities, lb, ub): + + """Enumerate the partitions of a multiset with + ``lb < num(parts) <= ub``. + + In particular, if partitions with exactly ``k`` parts are + desired, call with ``(multiplicities, k - 1, k)``. This + method generalizes enum_all, enum_small, and enum_large. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_range([2,2], 1, 2) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a', 'b'], ['b']], + [['a', 'a'], ['b', 'b']], + [['a', 'b', 'b'], ['a']], + [['a', 'b'], ['a', 'b']]] + + """ + # combine the constraints of the _large and _small + # enumerations. + self.discarded = 0 + if ub <= 0 or lb >= sum(multiplicities): + return + self._initialize_enumeration(multiplicities) + self.decrement_part_large(self.top_part(), 0, lb) + while True: + good_partition = True + while self.spread_part_multiplicity(): + self.db_trace("spread 1") + if not self.decrement_part_large(self.top_part(), 0, lb): + # Failure here - possible in range case? + self.db_trace(" Discarding (large cons)") + self.discarded += 1 + good_partition = False + break + elif self.lpart >= ub: + self.discarded += 1 + good_partition = False + self.db_trace(" Discarding small cons") + self.lpart = ub - 2 + break + + # M4 Visit a partition + if good_partition: + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part_range(self.top_part(), lb, ub): + self.db_trace("Failed decrement, going to backtrack") + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + self.db_trace("Backtracked to") + self.db_trace("decrement ok, about to expand") + + def count_partitions_slow(self, multiplicities): + """Returns the number of partitions of a multiset whose elements + have the multiplicities given in ``multiplicities``. + + Primarily for comparison purposes. It follows the same path as + enumerate, and counts, rather than generates, the partitions. + + See Also + ======== + + count_partitions + Has the same calling interface, but is much faster. 
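        As a quick illustrative check, the slow and fast counters agree on a
        small multiset:

        >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
        >>> m = MultisetPartitionTraverser()
        >>> m.count_partitions_slow([2, 2]) == m.count_partitions([2, 2]) == 9
        True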
+ + """ + # number of partitions so far in the enumeration + self.pcount = 0 + self._initialize_enumeration(multiplicities) + while True: + while self.spread_part_multiplicity(): + pass + + # M4 Visit (count) a partition + self.pcount += 1 + + # M5 (Decrease v) + while not self.decrement_part(self.top_part()): + # M6 (Backtrack) + if self.lpart == 0: + return self.pcount + self.lpart -= 1 + + def count_partitions(self, multiplicities): + """Returns the number of partitions of a multiset whose components + have the multiplicities given in ``multiplicities``. + + For larger counts, this method is much faster than calling one + of the enumerators and counting the result. Uses dynamic + programming to cut down on the number of nodes actually + explored. The dictionary used in order to accelerate the + counting process is stored in the ``MultisetPartitionTraverser`` + object and persists across calls. If the user does not + expect to call ``count_partitions`` for any additional + multisets, the object should be cleared to save memory. On + the other hand, the cache built up from one count run can + significantly speed up subsequent calls to ``count_partitions``, + so it may be advantageous not to clear the object. + + Examples + ======== + + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> m.count_partitions([9,8,2]) + 288716 + >>> m.count_partitions([2,2]) + 9 + >>> del m + + Notes + ===== + + If one looks at the workings of Knuth's algorithm M [AOCP]_, it + can be viewed as a traversal of a binary tree of parts. A + part has (up to) two children, the left child resulting from + the spread operation, and the right child from the decrement + operation. The ordinary enumeration of multiset partitions is + an in-order traversal of this tree, and with the partitions + corresponding to paths from the root to the leaves. The + mapping from paths to partitions is a little complicated, + since the partition would contain only those parts which are + leaves or the parents of a spread link, not those which are + parents of a decrement link. + + For counting purposes, it is sufficient to count leaves, and + this can be done with a recursive in-order traversal. The + number of leaves of a subtree rooted at a particular part is a + function only of that part itself, so memoizing has the + potential to speed up the counting dramatically. + + This method follows a computational approach which is similar + to the hypothetical memoized recursive function, but with two + differences: + + 1) This method is iterative, borrowing its structure from the + other enumerations and maintaining an explicit stack of + parts which are in the process of being counted. (There + may be multisets which can be counted reasonably quickly by + this implementation, but which would overflow the default + Python recursion limit with a recursive implementation.) + + 2) Instead of using the part data structure directly, a more + compact key is constructed. This saves space, but more + importantly coalesces some parts which would remain + separate with physical keys. + + Unlike the enumeration functions, there is currently no _range + version of count_partitions. If someone wants to stretch + their brain, it should be possible to construct one by + memoizing with a histogram of counts rather than a single + count, and combining the histograms. 
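        As a rough illustration of the caching described above (the exact
        contents of ``dp_map`` are an implementation detail), a count run
        leaves cached entries behind that later calls can reuse:

        >>> from sympy.utilities.enumerative import MultisetPartitionTraverser
        >>> m = MultisetPartitionTraverser()
        >>> m.count_partitions([2, 2])
        9
        >>> len(m.dp_map) > 0
        True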
+ """ + # number of partitions so far in the enumeration + self.pcount = 0 + + # dp_stack is list of lists of (part_key, start_count) pairs + self.dp_stack = [] + + self._initialize_enumeration(multiplicities) + pkey = part_key(self.top_part()) + self.dp_stack.append([(pkey, 0), ]) + while True: + while self.spread_part_multiplicity(): + pkey = part_key(self.top_part()) + if pkey in self.dp_map: + # Already have a cached value for the count of the + # subtree rooted at this part. Add it to the + # running counter, and break out of the spread + # loop. The -1 below is to compensate for the + # leaf that this code path would otherwise find, + # and which gets incremented for below. + + self.pcount += (self.dp_map[pkey] - 1) + self.lpart -= 1 + break + else: + self.dp_stack.append([(pkey, self.pcount), ]) + + # M4 count a leaf partition + self.pcount += 1 + + # M5 (Decrease v) + while not self.decrement_part(self.top_part()): + # M6 (Backtrack) + for key, oldcount in self.dp_stack.pop(): + self.dp_map[key] = self.pcount - oldcount + if self.lpart == 0: + return self.pcount + self.lpart -= 1 + + # At this point have successfully decremented the part on + # the stack and it does not appear in the cache. It needs + # to be added to the list at the top of dp_stack + pkey = part_key(self.top_part()) + self.dp_stack[-1].append((pkey, self.pcount),) + + +def part_key(part): + """Helper for MultisetPartitionTraverser.count_partitions that + creates a key for ``part``, that only includes information which can + affect the count for that part. (Any irrelevant information just + reduces the effectiveness of dynamic programming.) + + Notes + ===== + + This member function is a candidate for future exploration. There + are likely symmetries that can be exploited to coalesce some + ``part_key`` values, and thereby save space and improve + performance. + + """ + # The component number is irrelevant for counting partitions, so + # leave it out of the memo key. + rval = [] + for ps in part: + rval.append(ps.u) + rval.append(ps.v) + return tuple(rval) diff --git a/phivenv/Lib/site-packages/sympy/utilities/exceptions.py b/phivenv/Lib/site-packages/sympy/utilities/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..f764a31371ed109d69102777bd9ec0dfb3d75380 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/exceptions.py @@ -0,0 +1,271 @@ +""" +General SymPy exceptions and warnings. +""" + +import warnings +import contextlib + +from textwrap import dedent + + +class SymPyDeprecationWarning(DeprecationWarning): + r""" + A warning for deprecated features of SymPy. + + See the :ref:`deprecation-policy` document for details on when and how + things should be deprecated in SymPy. + + Note that simply constructing this class will not cause a warning to be + issued. To do that, you must call the :func`sympy_deprecation_warning` + function. For this reason, it is not recommended to ever construct this + class directly. + + Explanation + =========== + + The ``SymPyDeprecationWarning`` class is a subclass of + ``DeprecationWarning`` that is used for all deprecations in SymPy. A + special subclass is used so that we can automatically augment the warning + message with additional metadata about the version the deprecation was + introduced in and a link to the documentation. This also allows users to + explicitly filter deprecation warnings from SymPy using ``warnings`` + filters (see :ref:`silencing-sympy-deprecation-warnings`). 
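    For instance, a user who wants to silence these warnings globally can
    install an ordinary ``warnings`` filter (a minimal sketch; see the
    documentation linked above for the recommended approaches)::

        import warnings
        from sympy.utilities.exceptions import SymPyDeprecationWarning

        warnings.filterwarnings("ignore", category=SymPyDeprecationWarning)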
+ + Additionally, ``SymPyDeprecationWarning`` is enabled to be shown by + default, unlike normal ``DeprecationWarning``\s, which are only shown by + default in interactive sessions. This ensures that deprecation warnings in + SymPy will actually be seen by users. + + See the documentation of :func:`sympy_deprecation_warning` for a + description of the parameters to this function. + + To mark a function as deprecated, you can use the :func:`@deprecated + ` decorator. + + See Also + ======== + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.exceptions.ignore_warnings + sympy.utilities.decorator.deprecated + sympy.testing.pytest.warns_deprecated_sympy + + """ + def __init__(self, message, *, deprecated_since_version, active_deprecations_target): + + super().__init__(message, deprecated_since_version, + active_deprecations_target) + self.message = message + if not isinstance(deprecated_since_version, str): + raise TypeError(f"'deprecated_since_version' should be a string, got {deprecated_since_version!r}") + self.deprecated_since_version = deprecated_since_version + self.active_deprecations_target = active_deprecations_target + if any(i in active_deprecations_target for i in '()='): + raise ValueError("active_deprecations_target be the part inside of the '(...)='") + + self.full_message = f""" + +{dedent(message).strip()} + +See https://docs.sympy.org/latest/explanation/active-deprecations.html#{active_deprecations_target} +for details. + +This has been deprecated since SymPy version {deprecated_since_version}. It +will be removed in a future version of SymPy. +""" + + def __str__(self): + return self.full_message + + def __repr__(self): + return f"{self.__class__.__name__}({self.message!r}, deprecated_since_version={self.deprecated_since_version!r}, active_deprecations_target={self.active_deprecations_target!r})" + + def __eq__(self, other): + return isinstance(other, SymPyDeprecationWarning) and self.args == other.args + + # Make pickling work. The by default, it tries to recreate the expression + # from its args, but this doesn't work because of our keyword-only + # arguments. + @classmethod + def _new(cls, message, deprecated_since_version, + active_deprecations_target): + return cls(message, deprecated_since_version=deprecated_since_version, active_deprecations_target=active_deprecations_target) + + def __reduce__(self): + return (self._new, (self.message, self.deprecated_since_version, self.active_deprecations_target)) + +# Python by default hides DeprecationWarnings, which we do not want. +warnings.simplefilter("once", SymPyDeprecationWarning) + +def sympy_deprecation_warning(message, *, deprecated_since_version, + active_deprecations_target, stacklevel=3): + r''' + Warn that a feature is deprecated in SymPy. + + See the :ref:`deprecation-policy` document for details on when and how + things should be deprecated in SymPy. + + To mark an entire function or class as deprecated, you can use the + :func:`@deprecated ` decorator. + + Parameters + ========== + + message : str + The deprecation message. This may span multiple lines and contain + code examples. Messages should be wrapped to 80 characters. The + message is automatically dedented and leading and trailing whitespace + stripped. Messages may include dynamic content based on the user + input, but avoid using ``str(expression)`` if an expression can be + arbitrary, as it might be huge and make the warning message + unreadable. 
+ + deprecated_since_version : str + The version of SymPy the feature has been deprecated since. For new + deprecations, this should be the version in `sympy/release.py + `_ + without the ``.dev``. If the next SymPy version ends up being + different from this, the release manager will need to update any + ``SymPyDeprecationWarning``\s using the incorrect version. This + argument is required and must be passed as a keyword argument. + (example: ``deprecated_since_version="1.10"``). + + active_deprecations_target : str + The Sphinx target corresponding to the section for the deprecation in + the :ref:`active-deprecations` document (see + ``doc/src/explanation/active-deprecations.md``). This is used to + automatically generate a URL to the page in the warning message. This + argument is required and must be passed as a keyword argument. + (example: ``active_deprecations_target="deprecated-feature-abc"``) + + stacklevel : int, default: 3 + The ``stacklevel`` parameter that is passed to ``warnings.warn``. If + you create a wrapper that calls this function, this should be + increased so that the warning message shows the user line of code that + produced the warning. Note that in some cases there will be multiple + possible different user code paths that could result in the warning. + In that case, just choose the smallest common stacklevel. + + Examples + ======== + + >>> from sympy.utilities.exceptions import sympy_deprecation_warning + >>> def is_this_zero(x, y=0): + ... """ + ... Determine if x = 0. + ... + ... Parameters + ... ========== + ... + ... x : Expr + ... The expression to check. + ... + ... y : Expr, optional + ... If provided, check if x = y. + ... + ... .. deprecated:: 1.1 + ... + ... The ``y`` argument to ``is_this_zero`` is deprecated. Use + ... ``is_this_zero(x - y)`` instead. + ... + ... """ + ... from sympy import simplify + ... + ... if y != 0: + ... sympy_deprecation_warning(""" + ... The y argument to is_zero() is deprecated. Use is_zero(x - y) instead.""", + ... deprecated_since_version="1.1", + ... active_deprecations_target='is-this-zero-y-deprecation') + ... return simplify(x - y) == 0 + >>> is_this_zero(0) + True + >>> is_this_zero(1, 1) # doctest: +SKIP + :1: SymPyDeprecationWarning: + + The y argument to is_zero() is deprecated. Use is_zero(x - y) instead. + + See https://docs.sympy.org/latest/explanation/active-deprecations.html#is-this-zero-y-deprecation + for details. + + This has been deprecated since SymPy version 1.1. It + will be removed in a future version of SymPy. + + is_this_zero(1, 1) + True + + See Also + ======== + + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.ignore_warnings + sympy.utilities.decorator.deprecated + sympy.testing.pytest.warns_deprecated_sympy + + ''' + w = SymPyDeprecationWarning(message, + deprecated_since_version=deprecated_since_version, + active_deprecations_target=active_deprecations_target) + warnings.warn(w, stacklevel=stacklevel) + + +@contextlib.contextmanager +def ignore_warnings(warningcls): + ''' + Context manager to suppress warnings during tests. + + .. note:: + + Do not use this with SymPyDeprecationWarning in the tests. + warns_deprecated_sympy() should be used instead. + + This function is useful for suppressing warnings during tests. The warns + function should be used to assert that a warning is raised. The + ignore_warnings function is useful in situation when the warning is not + guaranteed to be raised (e.g. 
on importing a module) or if the warning + comes from third-party code. + + This function is also useful to prevent the same or similar warnings from + being issue twice due to recursive calls. + + When the warning is coming (reliably) from SymPy the warns function should + be preferred to ignore_warnings. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> import warnings + + Here's a warning: + + >>> with warnings.catch_warnings(): # reset warnings in doctest + ... warnings.simplefilter('error') + ... warnings.warn('deprecated', UserWarning) + Traceback (most recent call last): + ... + UserWarning: deprecated + + Let's suppress it with ignore_warnings: + + >>> with warnings.catch_warnings(): # reset warnings in doctest + ... warnings.simplefilter('error') + ... with ignore_warnings(UserWarning): + ... warnings.warn('deprecated', UserWarning) + + (No warning emitted) + + See Also + ======== + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.decorator.deprecated + sympy.testing.pytest.warns_deprecated_sympy + + ''' + # Absorbs all warnings in warnrec + with warnings.catch_warnings(record=True) as warnrec: + # Make sure our warning doesn't get filtered + warnings.simplefilter("always", warningcls) + # Now run the test + yield + + # Reissue any warnings that we aren't testing for + for w in warnrec: + if not issubclass(w.category, warningcls): + warnings.warn_explicit(w.message, w.category, w.filename, w.lineno) diff --git a/phivenv/Lib/site-packages/sympy/utilities/iterables.py b/phivenv/Lib/site-packages/sympy/utilities/iterables.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5c1152f03b10f62f9dd04c842b6fc74b0b9242 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/iterables.py @@ -0,0 +1,3179 @@ +from collections import Counter, defaultdict, OrderedDict +from itertools import ( + chain, combinations, combinations_with_replacement, cycle, islice, + permutations, product, groupby +) +# For backwards compatibility +from itertools import product as cartes # noqa: F401 +from operator import gt + + + +# this is the logical location of these functions +from sympy.utilities.enumerative import ( + multiset_partitions_taocp, list_visitor, MultisetPartitionTraverser) + +from sympy.utilities.misc import as_int +from sympy.utilities.decorator import deprecated + + +def is_palindromic(s, i=0, j=None): + """ + Return True if the sequence is the same from left to right as it + is from right to left in the whole sequence (default) or in the + Python slice ``s[i: j]``; else False. + + Examples + ======== + + >>> from sympy.utilities.iterables import is_palindromic + >>> is_palindromic([1, 0, 1]) + True + >>> is_palindromic('abcbb') + False + >>> is_palindromic('abcbb', 1) + False + + Normal Python slicing is performed in place so there is no need to + create a slice of the sequence for testing: + + >>> is_palindromic('abcbb', 1, -1) + True + >>> is_palindromic('abcbb', -4, -1) + True + + See Also + ======== + + sympy.ntheory.digits.is_palindromic: tests integers + + """ + i, j, _ = slice(i, j).indices(len(s)) + m = (j - i)//2 + # if length is odd, middle element will be ignored + return all(s[i + k] == s[j - 1 - k] for k in range(m)) + + +def flatten(iterable, levels=None, cls=None): # noqa: F811 + """ + Recursively denest iterable containers. 
+ + >>> from sympy import flatten + + >>> flatten([1, 2, 3]) + [1, 2, 3] + >>> flatten([1, 2, [3]]) + [1, 2, 3] + >>> flatten([1, [2, 3], [4, 5]]) + [1, 2, 3, 4, 5] + >>> flatten([1.0, 2, (1, None)]) + [1.0, 2, 1, None] + + If you want to denest only a specified number of levels of + nested containers, then set ``levels`` flag to the desired + number of levels:: + + >>> ls = [[(-2, -1), (1, 2)], [(0, 0)]] + + >>> flatten(ls, levels=1) + [(-2, -1), (1, 2), (0, 0)] + + If cls argument is specified, it will only flatten instances of that + class, for example: + + >>> from sympy import Basic, S + >>> class MyOp(Basic): + ... pass + ... + >>> flatten([MyOp(S(1), MyOp(S(2), S(3)))], cls=MyOp) + [1, 2, 3] + + adapted from https://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks + """ + from sympy.tensor.array import NDimArray + if levels is not None: + if not levels: + return iterable + elif levels > 0: + levels -= 1 + else: + raise ValueError( + "expected non-negative number of levels, got %s" % levels) + + if cls is None: + def reducible(x): + return is_sequence(x, set) + else: + def reducible(x): + return isinstance(x, cls) + + result = [] + + for el in iterable: + if reducible(el): + if hasattr(el, 'args') and not isinstance(el, NDimArray): + el = el.args + result.extend(flatten(el, levels=levels, cls=cls)) + else: + result.append(el) + + return result + + +def unflatten(iter, n=2): + """Group ``iter`` into tuples of length ``n``. Raise an error if + the length of ``iter`` is not a multiple of ``n``. + """ + if n < 1 or len(iter) % n: + raise ValueError('iter length is not a multiple of %i' % n) + return list(zip(*(iter[i::n] for i in range(n)))) + + +def reshape(seq, how): + """Reshape the sequence according to the template in ``how``. + + Examples + ======== + + >>> from sympy.utilities import reshape + >>> seq = list(range(1, 9)) + + >>> reshape(seq, [4]) # lists of 4 + [[1, 2, 3, 4], [5, 6, 7, 8]] + + >>> reshape(seq, (4,)) # tuples of 4 + [(1, 2, 3, 4), (5, 6, 7, 8)] + + >>> reshape(seq, (2, 2)) # tuples of 4 + [(1, 2, 3, 4), (5, 6, 7, 8)] + + >>> reshape(seq, (2, [2])) # (i, i, [i, i]) + [(1, 2, [3, 4]), (5, 6, [7, 8])] + + >>> reshape(seq, ((2,), [2])) # etc.... + [((1, 2), [3, 4]), ((5, 6), [7, 8])] + + >>> reshape(seq, (1, [2], 1)) + [(1, [2, 3], 4), (5, [6, 7], 8)] + + >>> reshape(tuple(seq), ([[1], 1, (2,)],)) + (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],)) + + >>> reshape(tuple(seq), ([1], 1, (2,))) + (([1], 2, (3, 4)), ([5], 6, (7, 8))) + + >>> reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) + [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]] + + """ + m = sum(flatten(how)) + n, rem = divmod(len(seq), m) + if m < 0 or rem: + raise ValueError('template must sum to positive number ' + 'that divides the length of the sequence') + i = 0 + container = type(how) + rv = [None]*n + for k in range(len(rv)): + _rv = [] + for hi in how: + if isinstance(hi, int): + _rv.extend(seq[i: i + hi]) + i += hi + else: + n = sum(flatten(hi)) + hi_type = type(hi) + _rv.append(hi_type(reshape(seq[i: i + n], hi)[0])) + i += n + rv[k] = container(_rv) + return type(seq)(rv) + + +def group(seq, multiple=True): + """ + Splits a sequence into a list of lists of equal, adjacent elements. 
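    Since only *adjacent* equal elements are merged, sorting the sequence
    first recovers its multiset form (an illustrative sketch; compare the
    ``multiset`` function below):

    >>> from sympy.utilities.iterables import group, multiset
    >>> dict(group(sorted('mississippi'), multiple=False)) == multiset('mississippi')
    True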
+ + Examples + ======== + + >>> from sympy import group + + >>> group([1, 1, 1, 2, 2, 3]) + [[1, 1, 1], [2, 2], [3]] + >>> group([1, 1, 1, 2, 2, 3], multiple=False) + [(1, 3), (2, 2), (3, 1)] + >>> group([1, 1, 3, 2, 2, 1], multiple=False) + [(1, 2), (3, 1), (2, 2), (1, 1)] + + See Also + ======== + + multiset + + """ + if multiple: + return [(list(g)) for _, g in groupby(seq)] + return [(k, len(list(g))) for k, g in groupby(seq)] + + +def _iproduct2(iterable1, iterable2): + '''Cartesian product of two possibly infinite iterables''' + + it1 = iter(iterable1) + it2 = iter(iterable2) + + elems1 = [] + elems2 = [] + + sentinel = object() + def append(it, elems): + e = next(it, sentinel) + if e is not sentinel: + elems.append(e) + + n = 0 + append(it1, elems1) + append(it2, elems2) + + while n <= len(elems1) + len(elems2): + for m in range(n-len(elems1)+1, len(elems2)): + yield (elems1[n-m], elems2[m]) + n += 1 + append(it1, elems1) + append(it2, elems2) + + +def iproduct(*iterables): + ''' + Cartesian product of iterables. + + Generator of the Cartesian product of iterables. This is analogous to + itertools.product except that it works with infinite iterables and will + yield any item from the infinite product eventually. + + Examples + ======== + + >>> from sympy.utilities.iterables import iproduct + >>> sorted(iproduct([1,2], [3,4])) + [(1, 3), (1, 4), (2, 3), (2, 4)] + + With an infinite iterator: + + >>> from sympy import S + >>> (3,) in iproduct(S.Integers) + True + >>> (3, 4) in iproduct(S.Integers, S.Integers) + True + + .. seealso:: + + `itertools.product + `_ + ''' + if len(iterables) == 0: + yield () + return + elif len(iterables) == 1: + for e in iterables[0]: + yield (e,) + elif len(iterables) == 2: + yield from _iproduct2(*iterables) + else: + first, others = iterables[0], iterables[1:] + for ef, eo in _iproduct2(first, iproduct(*others)): + yield (ef,) + eo + + +def multiset(seq): + """Return the hashable sequence in multiset form with values being the + multiplicity of the item in the sequence. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset + >>> multiset('mississippi') + {'i': 4, 'm': 1, 'p': 2, 's': 4} + + See Also + ======== + + group + + """ + return dict(Counter(seq).items()) + + + + +def ibin(n, bits=None, str=False): + """Return a list of length ``bits`` corresponding to the binary value + of ``n`` with small bits to the right (last). If bits is omitted, the + length will be the number required to represent ``n``. If the bits are + desired in reversed order, use the ``[::-1]`` slice of the returned list. + + If a sequence of all bits-length lists starting from ``[0, 0,..., 0]`` + through ``[1, 1, ..., 1]`` are desired, pass a non-integer for bits, e.g. + ``'all'``. + + If the bit *string* is desired pass ``str=True``. + + Examples + ======== + + >>> from sympy.utilities.iterables import ibin + >>> ibin(2) + [1, 0] + >>> ibin(2, 4) + [0, 0, 1, 0] + + If all lists corresponding to 0 to 2**n - 1, pass a non-integer + for bits: + + >>> bits = 2 + >>> for i in ibin(2, 'all'): + ... 
print(i) + (0, 0) + (0, 1) + (1, 0) + (1, 1) + + If a bit string is desired of a given length, use str=True: + + >>> n = 123 + >>> bits = 10 + >>> ibin(n, bits, str=True) + '0001111011' + >>> ibin(n, bits, str=True)[::-1] # small bits left + '1101111000' + >>> list(ibin(3, 'all', str=True)) + ['000', '001', '010', '011', '100', '101', '110', '111'] + + """ + if n < 0: + raise ValueError("negative numbers are not allowed") + n = as_int(n) + + if bits is None: + bits = 0 + else: + try: + bits = as_int(bits) + except ValueError: + bits = -1 + else: + if n.bit_length() > bits: + raise ValueError( + "`bits` must be >= {}".format(n.bit_length())) + + if not str: + if bits >= 0: + return [1 if i == "1" else 0 for i in bin(n)[2:].rjust(bits, "0")] + else: + return variations(range(2), n, repetition=True) + else: + if bits >= 0: + return bin(n)[2:].rjust(bits, "0") + else: + return (bin(i)[2:].rjust(n, "0") for i in range(2**n)) + + +def variations(seq, n, repetition=False): + r"""Returns an iterator over the n-sized variations of ``seq`` (size N). + ``repetition`` controls whether items in ``seq`` can appear more than once; + + Examples + ======== + + ``variations(seq, n)`` will return `\frac{N!}{(N - n)!}` permutations without + repetition of ``seq``'s elements: + + >>> from sympy import variations + >>> list(variations([1, 2], 2)) + [(1, 2), (2, 1)] + + ``variations(seq, n, True)`` will return the `N^n` permutations obtained + by allowing repetition of elements: + + >>> list(variations([1, 2], 2, repetition=True)) + [(1, 1), (1, 2), (2, 1), (2, 2)] + + If you ask for more items than are in the set you get the empty set unless + you allow repetitions: + + >>> list(variations([0, 1], 3, repetition=False)) + [] + >>> list(variations([0, 1], 3, repetition=True))[:4] + [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)] + + .. seealso:: + + `itertools.permutations + `_, + `itertools.product + `_ + """ + if not repetition: + seq = tuple(seq) + if len(seq) < n: + return iter(()) # 0 length iterator + return permutations(seq, n) + else: + if n == 0: + return iter(((),)) # yields 1 empty tuple + else: + return product(seq, repeat=n) + + +def subsets(seq, k=None, repetition=False): + r"""Generates all `k`-subsets (combinations) from an `n`-element set, ``seq``. + + A `k`-subset of an `n`-element set is any subset of length exactly `k`. The + number of `k`-subsets of an `n`-element set is given by ``binomial(n, k)``, + whereas there are `2^n` subsets all together. If `k` is ``None`` then all + `2^n` subsets will be returned from shortest to longest. + + Examples + ======== + + >>> from sympy import subsets + + ``subsets(seq, k)`` will return the + `\frac{n!}{k!(n - k)!}` `k`-subsets (combinations) + without repetition, i.e. 
once an item has been removed, it can no + longer be "taken": + + >>> list(subsets([1, 2], 2)) + [(1, 2)] + >>> list(subsets([1, 2])) + [(), (1,), (2,), (1, 2)] + >>> list(subsets([1, 2, 3], 2)) + [(1, 2), (1, 3), (2, 3)] + + + ``subsets(seq, k, repetition=True)`` will return the + `\frac{(n - 1 + k)!}{k!(n - 1)!}` + combinations *with* repetition: + + >>> list(subsets([1, 2], 2, repetition=True)) + [(1, 1), (1, 2), (2, 2)] + + If you ask for more items than are in the set you get the empty set unless + you allow repetitions: + + >>> list(subsets([0, 1], 3, repetition=False)) + [] + >>> list(subsets([0, 1], 3, repetition=True)) + [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)] + + """ + if k is None: + if not repetition: + return chain.from_iterable((combinations(seq, k) + for k in range(len(seq) + 1))) + else: + return chain.from_iterable((combinations_with_replacement(seq, k) + for k in range(len(seq) + 1))) + else: + if not repetition: + return combinations(seq, k) + else: + return combinations_with_replacement(seq, k) + + +def filter_symbols(iterator, exclude): + """ + Only yield elements from `iterator` that do not occur in `exclude`. + + Parameters + ========== + + iterator : iterable + iterator to take elements from + + exclude : iterable + elements to exclude + + Returns + ======= + + iterator : iterator + filtered iterator + """ + exclude = set(exclude) + for s in iterator: + if s not in exclude: + yield s + +def numbered_symbols(prefix='x', cls=None, start=0, exclude=(), *args, **assumptions): + """ + Generate an infinite stream of Symbols consisting of a prefix and + increasing subscripts provided that they do not occur in ``exclude``. + + Parameters + ========== + + prefix : str, optional + The prefix to use. By default, this function will generate symbols of + the form "x0", "x1", etc. + + cls : class, optional + The class to use. By default, it uses ``Symbol``, but you can also use ``Wild`` + or ``Dummy``. + + start : int, optional + The start number. By default, it is 0. + + exclude : list, tuple, set of cls, optional + Symbols to be excluded. + + *args, **kwargs + Additional positional and keyword arguments are passed to the *cls* class. + + Returns + ======= + + sym : Symbol + The subscripted symbols. + """ + exclude = set(exclude or []) + if cls is None: + # We can't just make the default cls=Symbol because it isn't + # imported yet. + from sympy.core import Symbol + cls = Symbol + + while True: + name = '%s%s' % (prefix, start) + s = cls(name, *args, **assumptions) + if s not in exclude: + yield s + start += 1 + + +def capture(func): + """Return the printed output of func(). + + ``func`` should be a function without arguments that produces output with + print statements. + + >>> from sympy.utilities.iterables import capture + >>> from sympy import pprint + >>> from sympy.abc import x + >>> def foo(): + ... print('hello world!') + ... + >>> 'hello' in capture(foo) # foo, not foo() + True + >>> capture(lambda: pprint(2/x)) + '2\\n-\\nx\\n' + + """ + from io import StringIO + import sys + + stdout = sys.stdout + sys.stdout = file = StringIO() + try: + func() + finally: + sys.stdout = stdout + return file.getvalue() + + +def sift(seq, keyfunc, binary=False): + """ + Sift the sequence, ``seq`` according to ``keyfunc``. + + Returns + ======= + + When ``binary`` is ``False`` (default), the output is a dictionary + where elements of ``seq`` are stored in a list keyed to the value + of keyfunc for that element. 
If ``binary`` is True then a tuple + with lists ``T`` and ``F`` are returned where ``T`` is a list + containing elements of seq for which ``keyfunc`` was ``True`` and + ``F`` containing those elements for which ``keyfunc`` was ``False``; + a ValueError is raised if the ``keyfunc`` is not binary. + + Examples + ======== + + >>> from sympy.utilities import sift + >>> from sympy.abc import x, y + >>> from sympy import sqrt, exp, pi, Tuple + + >>> sift(range(5), lambda x: x % 2) + {0: [0, 2, 4], 1: [1, 3]} + + sift() returns a defaultdict() object, so any key that has no matches will + give []. + + >>> sift([x], lambda x: x.is_commutative) + {True: [x]} + >>> _[False] + [] + + Sometimes you will not know how many keys you will get: + + >>> sift([sqrt(x), exp(x), (y**x)**2], + ... lambda x: x.as_base_exp()[0]) + {E: [exp(x)], x: [sqrt(x)], y: [y**(2*x)]} + + Sometimes you expect the results to be binary; the + results can be unpacked by setting ``binary`` to True: + + >>> sift(range(4), lambda x: x % 2, binary=True) + ([1, 3], [0, 2]) + >>> sift(Tuple(1, pi), lambda x: x.is_rational, binary=True) + ([1], [pi]) + + A ValueError is raised if the predicate was not actually binary + (which is a good test for the logic where sifting is used and + binary results were expected): + + >>> unknown = exp(1) - pi # the rationality of this is unknown + >>> args = Tuple(1, pi, unknown) + >>> sift(args, lambda x: x.is_rational, binary=True) + Traceback (most recent call last): + ... + ValueError: keyfunc gave non-binary output + + The non-binary sifting shows that there were 3 keys generated: + + >>> set(sift(args, lambda x: x.is_rational).keys()) + {None, False, True} + + If you need to sort the sifted items it might be better to use + ``ordered`` which can economically apply multiple sort keys + to a sequence while sorting. + + See Also + ======== + + ordered + + """ + if not binary: + m = defaultdict(list) + for i in seq: + m[keyfunc(i)].append(i) + return m + sift = F, T = [], [] + for i in seq: + try: + sift[keyfunc(i)].append(i) + except (IndexError, TypeError): + raise ValueError('keyfunc gave non-binary output') + return T, F + + +def take(iter, n): + """Return ``n`` items from ``iter`` iterator. """ + return [ value for _, value in zip(range(n), iter) ] + + +def dict_merge(*dicts): + """Merge dictionaries into a single dictionary. """ + merged = {} + + for dict in dicts: + merged.update(dict) + + return merged + + +def common_prefix(*seqs): + """Return the subsequence that is a common start of sequences in ``seqs``. + + >>> from sympy.utilities.iterables import common_prefix + >>> common_prefix(list(range(3))) + [0, 1, 2] + >>> common_prefix(list(range(3)), list(range(4))) + [0, 1, 2] + >>> common_prefix([1, 2, 3], [1, 2, 5]) + [1, 2] + >>> common_prefix([1, 2, 3], [1, 3, 5]) + [1] + """ + if not all(seqs): + return [] + elif len(seqs) == 1: + return seqs[0] + i = 0 + for i in range(min(len(s) for s in seqs)): + if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))): + break + else: + i += 1 + return seqs[0][:i] + + +def common_suffix(*seqs): + """Return the subsequence that is a common ending of sequences in ``seqs``. 
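    An equivalent way to think of it (a small sketch): the common suffix is
    the reversed common prefix of the reversed sequences.

    >>> from sympy.utilities.iterables import common_prefix, common_suffix
    >>> common_suffix([1, 2, 3], [9, 2, 3]) == common_prefix([3, 2, 1], [3, 2, 9])[::-1]
    True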
+ + >>> from sympy.utilities.iterables import common_suffix + >>> common_suffix(list(range(3))) + [0, 1, 2] + >>> common_suffix(list(range(3)), list(range(4))) + [] + >>> common_suffix([1, 2, 3], [9, 2, 3]) + [2, 3] + >>> common_suffix([1, 2, 3], [9, 7, 3]) + [3] + """ + + if not all(seqs): + return [] + elif len(seqs) == 1: + return seqs[0] + i = 0 + for i in range(-1, -min(len(s) for s in seqs) - 1, -1): + if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))): + break + else: + i -= 1 + if i == -1: + return [] + else: + return seqs[0][i + 1:] + + +def prefixes(seq): + """ + Generate all prefixes of a sequence. + + Examples + ======== + + >>> from sympy.utilities.iterables import prefixes + + >>> list(prefixes([1,2,3,4])) + [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] + + """ + n = len(seq) + + for i in range(n): + yield seq[:i + 1] + + +def postfixes(seq): + """ + Generate all postfixes of a sequence. + + Examples + ======== + + >>> from sympy.utilities.iterables import postfixes + + >>> list(postfixes([1,2,3,4])) + [[4], [3, 4], [2, 3, 4], [1, 2, 3, 4]] + + """ + n = len(seq) + + for i in range(n): + yield seq[n - i - 1:] + + +def topological_sort(graph, key=None): + r""" + Topological sort of graph's vertices. + + Parameters + ========== + + graph : tuple[list, list[tuple[T, T]] + A tuple consisting of a list of vertices and a list of edges of + a graph to be sorted topologically. + + key : callable[T] (optional) + Ordering key for vertices on the same level. By default the natural + (e.g. lexicographic) ordering is used (in this case the base type + must implement ordering relations). + + Examples + ======== + + Consider a graph:: + + +---+ +---+ +---+ + | 7 |\ | 5 | | 3 | + +---+ \ +---+ +---+ + | _\___/ ____ _/ | + | / \___/ \ / | + V V V V | + +----+ +---+ | + | 11 | | 8 | | + +----+ +---+ | + | | \____ ___/ _ | + | \ \ / / \ | + V \ V V / V V + +---+ \ +---+ | +----+ + | 2 | | | 9 | | | 10 | + +---+ | +---+ | +----+ + \________/ + + where vertices are integers. This graph can be encoded using + elementary Python's data structures as follows:: + + >>> V = [2, 3, 5, 7, 8, 9, 10, 11] + >>> E = [(7, 11), (7, 8), (5, 11), (3, 8), (3, 10), + ... (11, 2), (11, 9), (11, 10), (8, 9)] + + To compute a topological sort for graph ``(V, E)`` issue:: + + >>> from sympy.utilities.iterables import topological_sort + + >>> topological_sort((V, E)) + [3, 5, 7, 8, 11, 2, 9, 10] + + If specific tie breaking approach is needed, use ``key`` parameter:: + + >>> topological_sort((V, E), key=lambda v: -v) + [7, 5, 11, 3, 10, 8, 9, 2] + + Only acyclic graphs can be sorted. If the input graph has a cycle, + then ``ValueError`` will be raised:: + + >>> topological_sort((V, E + [(10, 7)])) + Traceback (most recent call last): + ... + ValueError: cycle detected + + References + ========== + + .. 
[1] https://en.wikipedia.org/wiki/Topological_sorting + + """ + V, E = graph + + L = [] + S = set(V) + E = list(E) + + S.difference_update(u for v, u in E) + + if key is None: + def key(value): + return value + + S = sorted(S, key=key, reverse=True) + + while S: + node = S.pop() + L.append(node) + + for u, v in list(E): + if u == node: + E.remove((u, v)) + + for _u, _v in E: + if v == _v: + break + else: + kv = key(v) + + for i, s in enumerate(S): + ks = key(s) + + if kv > ks: + S.insert(i, v) + break + else: + S.append(v) + + if E: + raise ValueError("cycle detected") + else: + return L + + +def strongly_connected_components(G): + r""" + Strongly connected components of a directed graph in reverse topological + order. + + + Parameters + ========== + + G : tuple[list, list[tuple[T, T]] + A tuple consisting of a list of vertices and a list of edges of + a graph whose strongly connected components are to be found. + + + Examples + ======== + + Consider a directed graph (in dot notation):: + + digraph { + A -> B + A -> C + B -> C + C -> B + B -> D + } + + .. graphviz:: + + digraph { + A -> B + A -> C + B -> C + C -> B + B -> D + } + + where vertices are the letters A, B, C and D. This graph can be encoded + using Python's elementary data structures as follows:: + + >>> V = ['A', 'B', 'C', 'D'] + >>> E = [('A', 'B'), ('A', 'C'), ('B', 'C'), ('C', 'B'), ('B', 'D')] + + The strongly connected components of this graph can be computed as + + >>> from sympy.utilities.iterables import strongly_connected_components + + >>> strongly_connected_components((V, E)) + [['D'], ['B', 'C'], ['A']] + + This also gives the components in reverse topological order. + + Since the subgraph containing B and C has a cycle they must be together in + a strongly connected component. A and D are connected to the rest of the + graph but not in a cyclic manner so they appear as their own strongly + connected components. + + + Notes + ===== + + The vertices of the graph must be hashable for the data structures used. + If the vertices are unhashable replace them with integer indices. + + This function uses Tarjan's algorithm to compute the strongly connected + components in `O(|V|+|E|)` (linear) time. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Strongly_connected_component + .. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + + + See Also + ======== + + sympy.utilities.iterables.connected_components + + """ + # Map from a vertex to its neighbours + V, E = G + Gmap = {vi: [] for vi in V} + for v1, v2 in E: + Gmap[v1].append(v2) + return _strongly_connected_components(V, Gmap) + + +def _strongly_connected_components(V, Gmap): + """More efficient internal routine for strongly_connected_components""" + # + # Here V is an iterable of vertices and Gmap is a dict mapping each vertex + # to a list of neighbours e.g.: + # + # V = [0, 1, 2, 3] + # Gmap = {0: [2, 3], 1: [0]} + # + # For a large graph these data structures can often be created more + # efficiently then those expected by strongly_connected_components() which + # in this case would be + # + # V = [0, 1, 2, 3] + # Gmap = [(0, 2), (0, 3), (1, 0)] + # + # XXX: Maybe this should be the recommended function to use instead... 
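    # (Illustrative note: given the V and edge list shown just above,
    # strongly_connected_components() itself would build
    # Gmap = {0: [2, 3], 1: [0], 2: [], 3: []} before delegating here;
    # every vertex gets a key, even one with no outgoing edges.)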
+ # + + # Non-recursive Tarjan's algorithm: + lowlink = {} + indices = {} + stack = OrderedDict() + callstack = [] + components = [] + nomore = object() + + def start(v): + index = len(stack) + indices[v] = lowlink[v] = index + stack[v] = None + callstack.append((v, iter(Gmap[v]))) + + def finish(v1): + # Finished a component? + if lowlink[v1] == indices[v1]: + component = [stack.popitem()[0]] + while component[-1] is not v1: + component.append(stack.popitem()[0]) + components.append(component[::-1]) + v2, _ = callstack.pop() + if callstack: + v1, _ = callstack[-1] + lowlink[v1] = min(lowlink[v1], lowlink[v2]) + + for v in V: + if v in indices: + continue + start(v) + while callstack: + v1, it1 = callstack[-1] + v2 = next(it1, nomore) + # Finished children of v1? + if v2 is nomore: + finish(v1) + # Recurse on v2 + elif v2 not in indices: + start(v2) + elif v2 in stack: + lowlink[v1] = min(lowlink[v1], indices[v2]) + + # Reverse topological sort order: + return components + + +def connected_components(G): + r""" + Connected components of an undirected graph or weakly connected components + of a directed graph. + + + Parameters + ========== + + G : tuple[list, list[tuple[T, T]] + A tuple consisting of a list of vertices and a list of edges of + a graph whose connected components are to be found. + + + Examples + ======== + + + Given an undirected graph:: + + graph { + A -- B + C -- D + } + + .. graphviz:: + + graph { + A -- B + C -- D + } + + We can find the connected components using this function if we include + each edge in both directions:: + + >>> from sympy.utilities.iterables import connected_components + + >>> V = ['A', 'B', 'C', 'D'] + >>> E = [('A', 'B'), ('B', 'A'), ('C', 'D'), ('D', 'C')] + >>> connected_components((V, E)) + [['A', 'B'], ['C', 'D']] + + The weakly connected components of a directed graph can found the same + way. + + + Notes + ===== + + The vertices of the graph must be hashable for the data structures used. + If the vertices are unhashable replace them with integer indices. + + This function uses Tarjan's algorithm to compute the connected components + in `O(|V|+|E|)` (linear) time. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Component_%28graph_theory%29 + .. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + + + See Also + ======== + + sympy.utilities.iterables.strongly_connected_components + + """ + # Duplicate edges both ways so that the graph is effectively undirected + # and return the strongly connected components: + V, E = G + E_undirected = [] + for v1, v2 in E: + E_undirected.extend([(v1, v2), (v2, v1)]) + return strongly_connected_components((V, E_undirected)) + + +def rotate_left(x, y): + """ + Left rotates a list x by the number of steps specified + in y. + + Examples + ======== + + >>> from sympy.utilities.iterables import rotate_left + >>> a = [0, 1, 2] + >>> rotate_left(a, 1) + [1, 2, 0] + """ + if len(x) == 0: + return [] + y = y % len(x) + return x[y:] + x[:y] + + +def rotate_right(x, y): + """ + Right rotates a list x by the number of steps specified + in y. + + Examples + ======== + + >>> from sympy.utilities.iterables import rotate_right + >>> a = [0, 1, 2] + >>> rotate_right(a, 1) + [2, 0, 1] + """ + if len(x) == 0: + return [] + y = len(x) - y % len(x) + return x[y:] + x[:y] + + +def least_rotation(x, key=None): + ''' + Returns the number of steps of left rotation required to + obtain lexicographically minimal string/list/tuple, etc. 
+ + Examples + ======== + + >>> from sympy.utilities.iterables import least_rotation, rotate_left + >>> a = [3, 1, 5, 1, 2] + >>> least_rotation(a) + 3 + >>> rotate_left(a, _) + [1, 2, 3, 1, 5] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lexicographically_minimal_string_rotation + + ''' + from sympy.functions.elementary.miscellaneous import Id + if key is None: key = Id + S = x + x # Concatenate string to it self to avoid modular arithmetic + f = [-1] * len(S) # Failure function + k = 0 # Least rotation of string found so far + for j in range(1,len(S)): + sj = S[j] + i = f[j-k-1] + while i != -1 and sj != S[k+i+1]: + if key(sj) < key(S[k+i+1]): + k = j-i-1 + i = f[i] + if sj != S[k+i+1]: + if key(sj) < key(S[k]): + k = j + f[j-k] = -1 + else: + f[j-k] = i+1 + return k + + +def multiset_combinations(m, n, g=None): + """ + Return the unique combinations of size ``n`` from multiset ``m``. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_combinations + >>> from itertools import combinations + >>> [''.join(i) for i in multiset_combinations('baby', 3)] + ['abb', 'aby', 'bby'] + + >>> def count(f, s): return len(list(f(s, 3))) + + The number of combinations depends on the number of letters; the + number of unique combinations depends on how the letters are + repeated. + + >>> s1 = 'abracadabra' + >>> s2 = 'banana tree' + >>> count(combinations, s1), count(multiset_combinations, s1) + (165, 23) + >>> count(combinations, s2), count(multiset_combinations, s2) + (165, 54) + + """ + from sympy.core.sorting import ordered + if g is None: + if isinstance(m, dict): + if any(as_int(v) < 0 for v in m.values()): + raise ValueError('counts cannot be negative') + N = sum(m.values()) + if n > N: + return + g = [[k, m[k]] for k in ordered(m)] + else: + m = list(m) + N = len(m) + if n > N: + return + try: + m = multiset(m) + g = [(k, m[k]) for k in ordered(m)] + except TypeError: + m = list(ordered(m)) + g = [list(i) for i in group(m, multiple=False)] + del m + else: + # not checking counts since g is intended for internal use + N = sum(v for k, v in g) + if n > N or not n: + yield [] + else: + for i, (k, v) in enumerate(g): + if v >= n: + yield [k]*n + v = n - 1 + for v in range(min(n, v), 0, -1): + for j in multiset_combinations(None, n - v, g[i + 1:]): + rv = [k]*v + j + if len(rv) == n: + yield rv + +def multiset_permutations(m, size=None, g=None): + """ + Return the unique permutations of multiset ``m``. 
+ + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_permutations + >>> from sympy import factorial + >>> [''.join(i) for i in multiset_permutations('aab')] + ['aab', 'aba', 'baa'] + >>> factorial(len('banana')) + 720 + >>> len(list(multiset_permutations('banana'))) + 60 + """ + from sympy.core.sorting import ordered + if g is None: + if isinstance(m, dict): + if any(as_int(v) < 0 for v in m.values()): + raise ValueError('counts cannot be negative') + g = [[k, m[k]] for k in ordered(m)] + else: + m = list(ordered(m)) + g = [list(i) for i in group(m, multiple=False)] + del m + do = [gi for gi in g if gi[1] > 0] + SUM = sum(gi[1] for gi in do) + if not do or size is not None and (size > SUM or size < 1): + if not do and size is None or size == 0: + yield [] + return + elif size == 1: + for k, v in do: + yield [k] + elif len(do) == 1: + k, v = do[0] + v = v if size is None else (size if size <= v else 0) + yield [k for i in range(v)] + elif all(v == 1 for k, v in do): + for p in permutations([k for k, v in do], size): + yield list(p) + else: + size = size if size is not None else SUM + for i, (k, v) in enumerate(do): + do[i][1] -= 1 + for j in multiset_permutations(None, size - 1, do): + if j: + yield [k] + j + do[i][1] += 1 + + +def _partition(seq, vector, m=None): + """ + Return the partition of seq as specified by the partition vector. + + Examples + ======== + + >>> from sympy.utilities.iterables import _partition + >>> _partition('abcde', [1, 0, 1, 2, 0]) + [['b', 'e'], ['a', 'c'], ['d']] + + Specifying the number of bins in the partition is optional: + + >>> _partition('abcde', [1, 0, 1, 2, 0], 3) + [['b', 'e'], ['a', 'c'], ['d']] + + The output of _set_partitions can be passed as follows: + + >>> output = (3, [1, 0, 1, 2, 0]) + >>> _partition('abcde', *output) + [['b', 'e'], ['a', 'c'], ['d']] + + See Also + ======== + + combinatorics.partitions.Partition.from_rgs + + """ + if m is None: + m = max(vector) + 1 + elif isinstance(vector, int): # entered as m, vector + vector, m = m, vector + p = [[] for i in range(m)] + for i, v in enumerate(vector): + p[v].append(seq[i]) + return p + + +def _set_partitions(n): + """Cycle through all partitions of n elements, yielding the + current number of partitions, ``m``, and a mutable list, ``q`` + such that ``element[i]`` is in part ``q[i]`` of the partition. + + NOTE: ``q`` is modified in place and generally should not be changed + between function calls. + + Examples + ======== + + >>> from sympy.utilities.iterables import _set_partitions, _partition + >>> for m, q in _set_partitions(3): + ... print('%s %s %s' % (m, q, _partition('abc', q, m))) + 1 [0, 0, 0] [['a', 'b', 'c']] + 2 [0, 0, 1] [['a', 'b'], ['c']] + 2 [0, 1, 0] [['a', 'c'], ['b']] + 2 [0, 1, 1] [['a'], ['b', 'c']] + 3 [0, 1, 2] [['a'], ['b'], ['c']] + + Notes + ===== + + This algorithm is similar to, and solves the same problem as, + Algorithm 7.2.1.5H, from volume 4A of Knuth's The Art of Computer + Programming. Knuth uses the term "restricted growth string" where + this code refers to a "partition vector". In each case, the meaning is + the same: the value in the ith element of the vector specifies to + which part the ith set element is to be assigned. + + At the lowest level, this code implements an n-digit big-endian + counter (stored in the array q) which is incremented (with carries) to + get the next partition in the sequence. 
A special twist is that a + digit is constrained to be at most one greater than the maximum of all + the digits to the left of it. The array p maintains this maximum, so + that the code can efficiently decide when a digit can be incremented + in place or whether it needs to be reset to 0 and trigger a carry to + the next digit. The enumeration starts with all the digits 0 (which + corresponds to all the set elements being assigned to the same 0th + part), and ends with 0123...n, which corresponds to each set element + being assigned to a different, singleton, part. + + This routine was rewritten to use 0-based lists while trying to + preserve the beauty and efficiency of the original algorithm. + + References + ========== + + .. [1] Nijenhuis, Albert and Wilf, Herbert. (1978) Combinatorial Algorithms, + 2nd Ed, p 91, algorithm "nexequ". Available online from + https://www.math.upenn.edu/~wilf/website/CombAlgDownld.html (viewed + November 17, 2012). + + """ + p = [0]*n + q = [0]*n + nc = 1 + yield nc, q + while nc != n: + m = n + while 1: + m -= 1 + i = q[m] + if p[i] != 1: + break + q[m] = 0 + i += 1 + q[m] = i + m += 1 + nc += m - n + p[0] += n - m + if i == nc: + p[nc] = 0 + nc += 1 + p[i - 1] -= 1 + p[i] += 1 + yield nc, q + + +def multiset_partitions(multiset, m=None): + """ + Return unique partitions of the given multiset (in list form). + If ``m`` is None, all multisets will be returned, otherwise only + partitions with ``m`` parts will be returned. + + If ``multiset`` is an integer, a range [0, 1, ..., multiset - 1] + will be supplied. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_partitions + >>> list(multiset_partitions([1, 2, 3, 4], 2)) + [[[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]], + [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]], + [[1], [2, 3, 4]]] + >>> list(multiset_partitions([1, 2, 3, 4], 1)) + [[[1, 2, 3, 4]]] + + Only unique partitions are returned and these will be returned in a + canonical order regardless of the order of the input: + + >>> a = [1, 2, 2, 1] + >>> ans = list(multiset_partitions(a, 2)) + >>> a.sort() + >>> list(multiset_partitions(a, 2)) == ans + True + >>> a = range(3, 1, -1) + >>> (list(multiset_partitions(a)) == + ... list(multiset_partitions(sorted(a)))) + True + + If m is omitted then all partitions will be returned: + + >>> list(multiset_partitions([1, 1, 2])) + [[[1, 1, 2]], [[1, 1], [2]], [[1, 2], [1]], [[1], [1], [2]]] + >>> list(multiset_partitions([1]*3)) + [[[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]] + + Counting + ======== + + The number of partitions of a set is given by the bell number: + + >>> from sympy import bell + >>> len(list(multiset_partitions(5))) == bell(5) == 52 + True + + The number of partitions of length k from a set of size n is given by the + Stirling Number of the 2nd kind: + + >>> from sympy.functions.combinatorial.numbers import stirling + >>> stirling(5, 2) == len(list(multiset_partitions(5, 2))) == 15 + True + + These comments on counting apply to *sets*, not multisets. + + Notes + ===== + + When all the elements are the same in the multiset, the order + of the returned partitions is determined by the ``partitions`` + routine. If one is counting partitions then it is better to use + the ``nT`` function. 
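    For example (a sketch, assuming ``nT`` accepts the multiset in list form
    as described in its own documentation), the count from ``nT`` agrees with
    exhausting this generator:

    >>> from sympy.utilities.iterables import multiset_partitions
    >>> from sympy.functions.combinatorial.numbers import nT
    >>> nT([1, 1, 2]) == len(list(multiset_partitions([1, 1, 2]))) == 4
    True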
+ + See Also + ======== + + partitions + sympy.combinatorics.partitions.Partition + sympy.combinatorics.partitions.IntegerPartition + sympy.functions.combinatorial.numbers.nT + + """ + # This function looks at the supplied input and dispatches to + # several special-case routines as they apply. + if isinstance(multiset, int): + n = multiset + if m and m > n: + return + multiset = list(range(n)) + if m == 1: + yield [multiset[:]] + return + + # If m is not None, it can sometimes be faster to use + # MultisetPartitionTraverser.enum_range() even for inputs + # which are sets. Since the _set_partitions code is quite + # fast, this is only advantageous when the overall set + # partitions outnumber those with the desired number of parts + # by a large factor. (At least 60.) Such a switch is not + # currently implemented. + for nc, q in _set_partitions(n): + if m is None or nc == m: + rv = [[] for i in range(nc)] + for i in range(n): + rv[q[i]].append(multiset[i]) + yield rv + return + + if len(multiset) == 1 and isinstance(multiset, str): + multiset = [multiset] + + if not has_variety(multiset): + # Only one component, repeated n times. The resulting + # partitions correspond to partitions of integer n. + n = len(multiset) + if m and m > n: + return + if m == 1: + yield [multiset[:]] + return + x = multiset[:1] + for size, p in partitions(n, m, size=True): + if m is None or size == m: + rv = [] + for k in sorted(p): + rv.extend([x*k]*p[k]) + yield rv + else: + from sympy.core.sorting import ordered + multiset = list(ordered(multiset)) + n = len(multiset) + if m and m > n: + return + if m == 1: + yield [multiset[:]] + return + + # Split the information of the multiset into two lists - + # one of the elements themselves, and one (of the same length) + # giving the number of repeats for the corresponding element. + elements, multiplicities = zip(*group(multiset, False)) + + if len(elements) < len(multiset): + # General case - multiset with more than one distinct element + # and at least one element repeated more than once. + if m: + mpt = MultisetPartitionTraverser() + for state in mpt.enum_range(multiplicities, m-1, m): + yield list_visitor(state, elements) + else: + for state in multiset_partitions_taocp(multiplicities): + yield list_visitor(state, elements) + else: + # Set partitions case - no repeated elements. Pretty much + # same as int argument case above, with same possible, but + # currently unimplemented optimization for some cases when + # m is not None + for nc, q in _set_partitions(n): + if m is None or nc == m: + rv = [[] for i in range(nc)] + for i in range(n): + rv[q[i]].append(i) + yield [[multiset[j] for j in i] for i in rv] + + +def partitions(n, m=None, k=None, size=False): + """Generate all partitions of positive integer, n. + + Each partition is represented as a dictionary, mapping an integer + to the number of copies of that integer in the partition. For example, + the first partition of 4 returned is {4: 1}, "4: one of them". + + Parameters + ========== + n : int + m : int, optional + limits number of parts in partition (mnemonic: m, maximum parts) + k : int, optional + limits the numbers that are kept in the partition (mnemonic: k, keys) + size : bool, default: False + If ``True``, (M, P) is returned where M is the sum of the + multiplicities and P is the generated partition. + If ``False``, only the generated partition is returned. 
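    For instance (a small sketch mirroring the ``k=2`` example below), with
    ``size=True`` each partition is preceded by its number of parts:

    >>> from sympy.utilities.iterables import partitions
    >>> for c, p in partitions(6, k=2, size=True):  # doctest: +SKIP
    ...     print(c, p)
    3 {2: 3}
    4 {1: 2, 2: 2}
    5 {1: 4, 2: 1}
    6 {1: 6}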
+ + Examples + ======== + + >>> from sympy.utilities.iterables import partitions + + The numbers appearing in the partition (the key of the returned dict) + are limited with k: + + >>> for p in partitions(6, k=2): # doctest: +SKIP + ... print(p) + {2: 3} + {1: 2, 2: 2} + {1: 4, 2: 1} + {1: 6} + + The maximum number of parts in the partition (the sum of the values in + the returned dict) are limited with m (default value, None, gives + partitions from 1 through n): + + >>> for p in partitions(6, m=2): # doctest: +SKIP + ... print(p) + ... + {6: 1} + {1: 1, 5: 1} + {2: 1, 4: 1} + {3: 2} + + References + ========== + + .. [1] modified from Tim Peter's version to allow for k and m values: + https://code.activestate.com/recipes/218332-generator-for-integer-partitions/ + + See Also + ======== + + sympy.combinatorics.partitions.Partition + sympy.combinatorics.partitions.IntegerPartition + + """ + if (n <= 0 or + m is not None and m < 1 or + k is not None and k < 1 or + m and k and m*k < n): + # the empty set is the only way to handle these inputs + # and returning {} to represent it is consistent with + # the counting convention, e.g. nT(0) == 1. + if size: + yield 0, {} + else: + yield {} + return + + if m is None: + m = n + else: + m = min(m, n) + k = min(k or n, n) + + n, m, k = as_int(n), as_int(m), as_int(k) + q, r = divmod(n, k) + ms = {k: q} + keys = [k] # ms.keys(), from largest to smallest + if r: + ms[r] = 1 + keys.append(r) + room = m - q - bool(r) + if size: + yield sum(ms.values()), ms.copy() + else: + yield ms.copy() + + while keys != [1]: + # Reuse any 1's. + if keys[-1] == 1: + del keys[-1] + reuse = ms.pop(1) + room += reuse + else: + reuse = 0 + + while 1: + # Let i be the smallest key larger than 1. Reuse one + # instance of i. + i = keys[-1] + newcount = ms[i] = ms[i] - 1 + reuse += i + if newcount == 0: + del keys[-1], ms[i] + room += 1 + + # Break the remainder into pieces of size i-1. + i -= 1 + q, r = divmod(reuse, i) + need = q + bool(r) + if need > room: + if not keys: + return + continue + + ms[i] = q + keys.append(i) + if r: + ms[r] = 1 + keys.append(r) + break + room -= need + if size: + yield sum(ms.values()), ms.copy() + else: + yield ms.copy() + + +def ordered_partitions(n, m=None, sort=True): + """Generates ordered partitions of integer *n*. + + Parameters + ========== + n : int + m : int, optional + The default value gives partitions of all sizes else only + those with size m. In addition, if *m* is not None then + partitions are generated *in place* (see examples). + sort : bool, default: True + Controls whether partitions are + returned in sorted order when *m* is not None; when False, + the partitions are returned as fast as possible with elements + sorted, but when m|n the partitions will not be in + ascending lexicographical order. + + Examples + ======== + + >>> from sympy.utilities.iterables import ordered_partitions + + All partitions of 5 in ascending lexicographical: + + >>> for p in ordered_partitions(5): + ... print(p) + [1, 1, 1, 1, 1] + [1, 1, 1, 2] + [1, 1, 3] + [1, 2, 2] + [1, 4] + [2, 3] + [5] + + Only partitions of 5 with two parts: + + >>> for p in ordered_partitions(5, 2): + ... 
print(p) + [1, 4] + [2, 3] + + When ``m`` is given, a given list objects will be used more than + once for speed reasons so you will not see the correct partitions + unless you make a copy of each as it is generated: + + >>> [p for p in ordered_partitions(7, 3)] + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [2, 2, 2]] + >>> [list(p) for p in ordered_partitions(7, 3)] + [[1, 1, 5], [1, 2, 4], [1, 3, 3], [2, 2, 3]] + + When ``n`` is a multiple of ``m``, the elements are still sorted + but the partitions themselves will be *unordered* if sort is False; + the default is to return them in ascending lexicographical order. + + >>> for p in ordered_partitions(6, 2): + ... print(p) + [1, 5] + [2, 4] + [3, 3] + + But if speed is more important than ordering, sort can be set to + False: + + >>> for p in ordered_partitions(6, 2, sort=False): + ... print(p) + [1, 5] + [3, 3] + [2, 4] + + References + ========== + + .. [1] Generating Integer Partitions, [online], + Available: https://jeromekelleher.net/generating-integer-partitions.html + .. [2] Jerome Kelleher and Barry O'Sullivan, "Generating All + Partitions: A Comparison Of Two Encodings", [online], + Available: https://arxiv.org/pdf/0909.2331v2.pdf + """ + if n < 1 or m is not None and m < 1: + # the empty set is the only way to handle these inputs + # and returning {} to represent it is consistent with + # the counting convention, e.g. nT(0) == 1. + yield [] + return + + if m is None: + # The list `a`'s leading elements contain the partition in which + # y is the biggest element and x is either the same as y or the + # 2nd largest element; v and w are adjacent element indices + # to which x and y are being assigned, respectively. + a = [1]*n + y = -1 + v = n + while v > 0: + v -= 1 + x = a[v] + 1 + while y >= 2 * x: + a[v] = x + y -= x + v += 1 + w = v + 1 + while x <= y: + a[v] = x + a[w] = y + yield a[:w + 1] + x += 1 + y -= 1 + a[v] = x + y + y = a[v] - 1 + yield a[:w] + elif m == 1: + yield [n] + elif n == m: + yield [1]*n + else: + # recursively generate partitions of size m + for b in range(1, n//m + 1): + a = [b]*m + x = n - b*m + if not x: + if sort: + yield a + elif not sort and x <= m: + for ax in ordered_partitions(x, sort=False): + mi = len(ax) + a[-mi:] = [i + b for i in ax] + yield a + a[-mi:] = [b]*mi + else: + for mi in range(1, m): + for ax in ordered_partitions(x, mi, sort=True): + a[-mi:] = [i + b for i in ax] + yield a + a[-mi:] = [b]*mi + + +def binary_partitions(n): + """ + Generates the binary partition of *n*. + + A binary partition consists only of numbers that are + powers of two. Each step reduces a `2^{k+1}` to `2^k` and + `2^k`. Thus 16 is converted to 8 and 8. + + Examples + ======== + + >>> from sympy.utilities.iterables import binary_partitions + >>> for i in binary_partitions(5): + ... print(i) + ... + [4, 1] + [2, 2, 1] + [2, 1, 1, 1] + [1, 1, 1, 1, 1] + + References + ========== + + .. 
[1] TAOCP 4, section 7.2.1.5, problem 64 + + """ + from math import ceil, log2 + power = int(2**(ceil(log2(n)))) + acc = 0 + partition = [] + while power: + if acc + power <= n: + partition.append(power) + acc += power + power >>= 1 + + last_num = len(partition) - 1 - (n & 1) + while last_num >= 0: + yield partition + if partition[last_num] == 2: + partition[last_num] = 1 + partition.append(1) + last_num -= 1 + continue + partition.append(1) + partition[last_num] >>= 1 + x = partition[last_num + 1] = partition[last_num] + last_num += 1 + while x > 1: + if x <= len(partition) - last_num - 1: + del partition[-x + 1:] + last_num += 1 + partition[last_num] = x + else: + x >>= 1 + yield [1]*n + + +def has_dups(seq): + """Return True if there are any duplicate elements in ``seq``. + + Examples + ======== + + >>> from sympy import has_dups, Dict, Set + >>> has_dups((1, 2, 1)) + True + >>> has_dups(range(3)) + False + >>> all(has_dups(c) is False for c in (set(), Set(), dict(), Dict())) + True + """ + from sympy.core.containers import Dict + from sympy.sets.sets import Set + if isinstance(seq, (dict, set, Dict, Set)): + return False + unique = set() + try: + return any(True for s in seq if s in unique or unique.add(s)) + except TypeError: + return len(seq) != len(list(uniq(seq))) + + +def has_variety(seq): + """Return True if there are any different elements in ``seq``. + + Examples + ======== + + >>> from sympy import has_variety + + >>> has_variety((1, 2, 1)) + True + >>> has_variety((1, 1, 1)) + False + """ + for i, s in enumerate(seq): + if i == 0: + sentinel = s + else: + if s != sentinel: + return True + return False + + +def uniq(seq, result=None): + """ + Yield unique elements from ``seq`` as an iterator. The second + parameter ``result`` is used internally; it is not necessary + to pass anything for this. + + Note: changing the sequence during iteration will raise a + RuntimeError if the size of the sequence is known; if you pass + an iterator and advance the iterator you will change the + output of this routine but there will be no warning. + + Examples + ======== + + >>> from sympy.utilities.iterables import uniq + >>> dat = [1, 4, 1, 5, 4, 2, 1, 2] + >>> type(uniq(dat)) in (list, tuple) + False + + >>> list(uniq(dat)) + [1, 4, 5, 2] + >>> list(uniq(x for x in dat)) + [1, 4, 5, 2] + >>> list(uniq([[1], [2, 1], [1]])) + [[1], [2, 1]] + """ + try: + n = len(seq) + except TypeError: + n = None + def check(): + # check that size of seq did not change during iteration; + # if n == None the object won't support size changing, e.g. + # an iterator can't be changed + if n is not None and len(seq) != n: + raise RuntimeError('sequence changed size during iteration') + try: + seen = set() + result = result or [] + for i, s in enumerate(seq): + if not (s in seen or seen.add(s)): + yield s + check() + except TypeError: + if s not in result: + yield s + check() + result.append(s) + if hasattr(seq, '__getitem__'): + yield from uniq(seq[i + 1:], result) + else: + yield from uniq(seq, result) + + +def generate_bell(n): + """Return permutations of [0, 1, ..., n - 1] such that each permutation + differs from the last by the exchange of a single pair of neighbors. + The ``n!`` permutations are returned as an iterator. In order to obtain + the next permutation from a random starting permutation, use the + ``next_trotterjohnson`` method of the Permutation class (which generates + the same sequence in a different manner). 
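+
+ A compact check of the neighbor-swap property (each permutation differs
+ from the previous one in exactly two positions):
+
+ >>> from sympy.utilities.iterables import generate_bell
+ >>> perms = list(generate_bell(3))
+ >>> all(sum(x != y for x, y in zip(p, q)) == 2
+ ...     for p, q in zip(perms, perms[1:]))
+ True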
+ + Examples + ======== + + >>> from itertools import permutations + >>> from sympy.utilities.iterables import generate_bell + >>> from sympy import zeros, Matrix + + This is the sort of permutation used in the ringing of physical bells, + and does not produce permutations in lexicographical order. Rather, the + permutations differ from each other by exactly one inversion, and the + position at which the swapping occurs varies periodically in a simple + fashion. Consider the first few permutations of 4 elements generated + by ``permutations`` and ``generate_bell``: + + >>> list(permutations(range(4)))[:5] + [(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2)] + >>> list(generate_bell(4))[:5] + [(0, 1, 2, 3), (0, 1, 3, 2), (0, 3, 1, 2), (3, 0, 1, 2), (3, 0, 2, 1)] + + Notice how the 2nd and 3rd lexicographical permutations have 3 elements + out of place whereas each "bell" permutation always has only two + elements out of place relative to the previous permutation (and so the + signature (+/-1) of a permutation is opposite of the signature of the + previous permutation). + + How the position of inversion varies across the elements can be seen + by tracing out where the largest number appears in the permutations: + + >>> m = zeros(4, 24) + >>> for i, p in enumerate(generate_bell(4)): + ... m[:, i] = Matrix([j - 3 for j in list(p)]) # make largest zero + >>> m.print_nonzero('X') + [XXX XXXXXX XXXXXX XXX] + [XX XX XXXX XX XXXX XX XX] + [X XXXX XX XXXX XX XXXX X] + [ XXXXXX XXXXXX XXXXXX ] + + See Also + ======== + + sympy.combinatorics.permutations.Permutation.next_trotterjohnson + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Method_ringing + + .. [2] https://stackoverflow.com/questions/4856615/recursive-permutation/4857018 + + .. [3] https://web.archive.org/web/20160313023044/http://programminggeeks.com/bell-algorithm-for-permutation/ + + .. [4] https://en.wikipedia.org/wiki/Steinhaus%E2%80%93Johnson%E2%80%93Trotter_algorithm + + .. [5] Generating involutions, derangements, and relatives by ECO + Vincent Vajnovszki, DMTCS vol 1 issue 12, 2010 + + """ + n = as_int(n) + if n < 1: + raise ValueError('n must be a positive integer') + if n == 1: + yield (0,) + elif n == 2: + yield (0, 1) + yield (1, 0) + elif n == 3: + yield from [(0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)] + else: + m = n - 1 + op = [0] + [-1]*m + l = list(range(n)) + while True: + yield tuple(l) + # find biggest element with op + big = None, -1 # idx, value + for i in range(n): + if op[i] and l[i] > big[1]: + big = i, l[i] + i, _ = big + if i is None: + break # there are no ops left + # swap it with neighbor in the indicated direction + j = i + op[i] + l[i], l[j] = l[j], l[i] + op[i], op[j] = op[j], op[i] + # if it landed at the end or if the neighbor in the same + # direction is bigger then turn off op + if j == 0 or j == m or l[j + op[j]] > l[j]: + op[j] = 0 + # any element bigger to the left gets +1 op + for i in range(j): + if l[i] > l[j]: + op[i] = 1 + # any element bigger to the right gets -1 op + for i in range(j + 1, n): + if l[i] > l[j]: + op[i] = -1 + + +def generate_involutions(n): + """ + Generates involutions. + + An involution is a permutation that when multiplied + by itself equals the identity permutation. In this + implementation the involutions are generated using + Fixed Points. + + Alternatively, an involution can be considered as + a permutation that does not contain any cycles with + a length that is greater than two. 
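+
+ Equivalently, composing an involution with itself gives the identity,
+ which can be checked directly:
+
+ >>> from sympy.utilities.iterables import generate_involutions
+ >>> all(tuple(p[p[i]] for i in range(3)) == (0, 1, 2)
+ ...     for p in generate_involutions(3))
+ True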
+ + Examples + ======== + + >>> from sympy.utilities.iterables import generate_involutions + >>> list(generate_involutions(3)) + [(0, 1, 2), (0, 2, 1), (1, 0, 2), (2, 1, 0)] + >>> len(list(generate_involutions(4))) + 10 + + References + ========== + + .. [1] https://mathworld.wolfram.com/PermutationInvolution.html + + """ + idx = list(range(n)) + for p in permutations(idx): + for i in idx: + if p[p[i]] != i: + break + else: + yield p + + +def multiset_derangements(s): + """Generate derangements of the elements of s *in place*. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_derangements, uniq + + Because the derangements of multisets (not sets) are generated + in place, copies of the return value must be made if a collection + of derangements is desired or else all values will be the same: + + >>> list(uniq([i for i in multiset_derangements('1233')])) + [[None, None, None, None]] + >>> [i.copy() for i in multiset_derangements('1233')] + [['3', '3', '1', '2'], ['3', '3', '2', '1']] + >>> [''.join(i) for i in multiset_derangements('1233')] + ['3312', '3321'] + """ + from sympy.core.sorting import ordered + # create multiset dictionary of hashable elements or else + # remap elements to integers + try: + ms = multiset(s) + except TypeError: + # give each element a canonical integer value + key = dict(enumerate(ordered(uniq(s)))) + h = [] + for si in s: + for k in key: + if key[k] == si: + h.append(k) + break + for i in multiset_derangements(h): + yield [key[j] for j in i] + return + + mx = max(ms.values()) # max repetition of any element + n = len(s) # the number of elements + + ## special cases + + # 1) one element has more than half the total cardinality of s: no + # derangements are possible. + if mx*2 > n: + return + + # 2) all elements appear once: singletons + if len(ms) == n: + yield from _set_derangements(s) + return + + # find the first element that is repeated the most to place + # in the following two special cases where the selection + # is unambiguous: either there are two elements with multiplicity + # of mx or else there is only one with multiplicity mx + for M in ms: + if ms[M] == mx: + break + + inonM = [i for i in range(n) if s[i] != M] # location of non-M + iM = [i for i in range(n) if s[i] == M] # locations of M + rv = [None]*n + + # 3) half are the same + if 2*mx == n: + # M goes into non-M locations + for i in inonM: + rv[i] = M + # permutations of non-M go to M locations + for p in multiset_permutations([s[i] for i in inonM]): + for i, pi in zip(iM, p): + rv[i] = pi + yield rv + # clean-up (and encourages proper use of routine) + rv[:] = [None]*n + return + + # 4) single repeat covers all but 1 of the non-repeats: + # if there is one repeat then the multiset of the values + # of ms would be {mx: 1, 1: n - mx}, i.e. there would + # be n - mx + 1 values with the condition that n - 2*mx = 1 + if n - 2*mx == 1 and len(ms.values()) == n - mx + 1: + for i, i1 in enumerate(inonM): + ifill = inonM[:i] + inonM[i+1:] + for j in ifill: + rv[j] = M + for p in permutations([s[j] for j in ifill]): + rv[i1] = s[i1] + for j, pi in zip(iM, p): + rv[j] = pi + k = i1 + for j in iM: + rv[j], rv[k] = rv[k], rv[j] + yield rv + k = j + # clean-up (and encourages proper use of routine) + rv[:] = [None]*n + return + + ## general case is handled with 3 helpers: + # 1) `finish_derangements` will place the last two elements + # which have arbitrary multiplicities, e.g. 
for multiset + # {c: 3, a: 2, b: 2}, the last two elements are a and b + # 2) `iopen` will tell where a given element can be placed + # 3) `do` will recursively place elements into subsets of + # valid locations + + def finish_derangements(): + """Place the last two elements into the partially completed + derangement, and yield the results. + """ + + a = take[1][0] # penultimate element + a_ct = take[1][1] + b = take[0][0] # last element to be placed + b_ct = take[0][1] + + # split the indexes of the not-already-assigned elements of rv into + # three categories + forced_a = [] # positions which must have an a + forced_b = [] # positions which must have a b + open_free = [] # positions which could take either + for i in range(len(s)): + if rv[i] is None: + if s[i] == a: + forced_b.append(i) + elif s[i] == b: + forced_a.append(i) + else: + open_free.append(i) + + if len(forced_a) > a_ct or len(forced_b) > b_ct: + # No derangement possible + return + + for i in forced_a: + rv[i] = a + for i in forced_b: + rv[i] = b + for a_place in combinations(open_free, a_ct - len(forced_a)): + for a_pos in a_place: + rv[a_pos] = a + for i in open_free: + if rv[i] is None: # anything not in the subset is set to b + rv[i] = b + yield rv + # Clean up/undo the final placements + for i in open_free: + rv[i] = None + + # additional cleanup - clear forced_a, forced_b + for i in forced_a: + rv[i] = None + for i in forced_b: + rv[i] = None + + def iopen(v): + # return indices at which element v can be placed in rv: + # locations which are not already occupied if that location + # does not already contain v in the same location of s + return [i for i in range(n) if rv[i] is None and s[i] != v] + + def do(j): + if j == 1: + # handle the last two elements (regardless of multiplicity) + # with a special method + yield from finish_derangements() + else: + # place the mx elements of M into a subset of places + # into which it can be replaced + M, mx = take[j] + for i in combinations(iopen(M), mx): + # place M + for ii in i: + rv[ii] = M + # recursively place the next element + yield from do(j - 1) + # mark positions where M was placed as once again + # open for placement of other elements + for ii in i: + rv[ii] = None + + # process elements in order of canonically decreasing multiplicity + take = sorted(ms.items(), key=lambda x:(x[1], x[0])) + yield from do(len(take) - 1) + rv[:] = [None]*n + + +def random_derangement(t, choice=None, strict=True): + """Return a list of elements in which none are in the same positions + as they were originally. If an element fills more than half of the positions + then an error will be raised since no derangement is possible. To obtain + a derangement of as many items as possible--with some of the most numerous + remaining in their original positions--pass `strict=False`. To produce a + pseudorandom derangment, pass a pseudorandom selector like `choice` (see + below). 
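+
+ For example, ``'aaab'`` has no derangement, so the default ``strict=True``
+ raises an error while ``strict=False`` leaves some of the ``'a'``s in
+ place (a sketch; the actual result is random):
+
+ >>> from sympy.utilities.iterables import random_derangement
+ >>> random_derangement('aaab') # doctest: +SKIP
+ Traceback (most recent call last):
+ ...
+ ValueError: no derangement possible
+ >>> random_derangement('aaab', strict=False) # doctest: +SKIP
+ 'aaba'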
+ + Examples + ======== + + >>> from sympy.utilities.iterables import random_derangement + >>> t = 'SymPy: a CAS in pure Python' + >>> d = random_derangement(t) + >>> all(i != j for i, j in zip(d, t)) + True + + A predictable result can be obtained by using a pseudorandom + generator for the choice: + + >>> from sympy.core.random import seed, choice as c + >>> seed(1) + >>> d = [''.join(random_derangement(t, c)) for i in range(5)] + >>> assert len(set(d)) != 1 # we got different values + + By reseeding, the same sequence can be obtained: + + >>> seed(1) + >>> d2 = [''.join(random_derangement(t, c)) for i in range(5)] + >>> assert d == d2 + """ + if choice is None: + import secrets + choice = secrets.choice + def shuffle(rv): + '''Knuth shuffle''' + for i in range(len(rv) - 1, 0, -1): + x = choice(rv[:i + 1]) + j = rv.index(x) + rv[i], rv[j] = rv[j], rv[i] + def pick(rv, n): + '''shuffle rv and return the first n values + ''' + shuffle(rv) + return rv[:n] + ms = multiset(t) + tot = len(t) + ms = sorted(ms.items(), key=lambda x: x[1]) + # if there are not enough spaces for the most + # plentiful element to move to then some of them + # will have to stay in place + M, mx = ms[-1] + n = len(t) + xs = 2*mx - tot + if xs > 0: + if strict: + raise ValueError('no derangement possible') + opts = [i for (i, c) in enumerate(t) if c == ms[-1][0]] + pick(opts, xs) + stay = sorted(opts[:xs]) + rv = list(t) + for i in reversed(stay): + rv.pop(i) + rv = random_derangement(rv, choice) + for i in stay: + rv.insert(i, ms[-1][0]) + return ''.join(rv) if type(t) is str else rv + # the normal derangement calculated from here + if n == len(ms): + # approx 1/3 will succeed + rv = list(t) + while True: + shuffle(rv) + if all(i != j for i,j in zip(rv, t)): + break + else: + # general case + rv = [None]*n + while True: + j = 0 + while j > -len(ms): # do most numerous first + j -= 1 + e, c = ms[j] + opts = [i for i in range(n) if rv[i] is None and t[i] != e] + if len(opts) < c: + for i in range(n): + rv[i] = None + break # try again + pick(opts, c) + for i in range(c): + rv[opts[i]] = e + else: + return rv + return rv + + +def _set_derangements(s): + """ + yield derangements of items in ``s`` which are assumed to contain + no repeated elements + """ + if len(s) < 2: + return + if len(s) == 2: + yield [s[1], s[0]] + return + if len(s) == 3: + yield [s[1], s[2], s[0]] + yield [s[2], s[0], s[1]] + return + for p in permutations(s): + if not any(i == j for i, j in zip(p, s)): + yield list(p) + + +def generate_derangements(s): + """ + Return unique derangements of the elements of iterable ``s``. + + Examples + ======== + + >>> from sympy.utilities.iterables import generate_derangements + >>> list(generate_derangements([0, 1, 2])) + [[1, 2, 0], [2, 0, 1]] + >>> list(generate_derangements([0, 1, 2, 2])) + [[2, 2, 0, 1], [2, 2, 1, 0]] + >>> list(generate_derangements([0, 1, 1])) + [] + + See Also + ======== + + sympy.functions.combinatorial.factorials.subfactorial + + """ + if not has_dups(s): + yield from _set_derangements(s) + else: + for p in multiset_derangements(s): + yield list(p) + + +def necklaces(n, k, free=False): + """ + A routine to generate necklaces that may (free=True) or may not + (free=False) be turned over to be viewed. The "necklaces" returned + are comprised of ``n`` integers (beads) with ``k`` different + values (colors). Only unique necklaces are returned. + + Examples + ======== + + >>> from sympy.utilities.iterables import necklaces, bracelets + >>> def show(s, i): + ... 
return ''.join(s[j] for j in i) + + The "unrestricted necklace" is sometimes also referred to as a + "bracelet" (an object that can be turned over, a sequence that can + be reversed) and the term "necklace" is used to imply a sequence + that cannot be reversed. So ACB == ABC for a bracelet (rotate and + reverse) while the two are different for a necklace since rotation + alone cannot make the two sequences the same. + + (mnemonic: Bracelets can be viewed Backwards, but Not Necklaces.) + + >>> B = [show('ABC', i) for i in bracelets(3, 3)] + >>> N = [show('ABC', i) for i in necklaces(3, 3)] + >>> set(N) - set(B) + {'ACB'} + + >>> list(necklaces(4, 2)) + [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 1), + (0, 1, 0, 1), (0, 1, 1, 1), (1, 1, 1, 1)] + + >>> [show('.o', i) for i in bracelets(4, 2)] + ['....', '...o', '..oo', '.o.o', '.ooo', 'oooo'] + + References + ========== + + .. [1] https://mathworld.wolfram.com/Necklace.html + + .. [2] Frank Ruskey, Carla Savage, and Terry Min Yih Wang, + Generating necklaces, Journal of Algorithms 13 (1992), 414-430; + https://doi.org/10.1016/0196-6774(92)90047-G + + """ + # The FKM algorithm + if k == 0 and n > 0: + return + a = [0]*n + yield tuple(a) + if n == 0: + return + while True: + i = n - 1 + while a[i] == k - 1: + i -= 1 + if i == -1: + return + a[i] += 1 + for j in range(n - i - 1): + a[j + i + 1] = a[j] + if n % (i + 1) == 0 and (not free or all(a <= a[j::-1] + a[-1:j:-1] for j in range(n - 1))): + # No need to test j = n - 1. + yield tuple(a) + + +def bracelets(n, k): + """Wrapper to necklaces to return a free (unrestricted) necklace.""" + return necklaces(n, k, free=True) + + +def generate_oriented_forest(n): + """ + This algorithm generates oriented forests. + + An oriented graph is a directed graph having no symmetric pair of directed + edges. A forest is an acyclic graph, i.e., it has no cycles. A forest can + also be described as a disjoint union of trees, which are graphs in which + any two vertices are connected by exactly one simple path. + + Examples + ======== + + >>> from sympy.utilities.iterables import generate_oriented_forest + >>> list(generate_oriented_forest(4)) + [[0, 1, 2, 3], [0, 1, 2, 2], [0, 1, 2, 1], [0, 1, 2, 0], \ + [0, 1, 1, 1], [0, 1, 1, 0], [0, 1, 0, 1], [0, 1, 0, 0], [0, 0, 0, 0]] + + References + ========== + + .. [1] T. Beyer and S.M. Hedetniemi: constant time generation of + rooted trees, SIAM J. Computing Vol. 9, No. 4, November 1980 + + .. [2] https://stackoverflow.com/questions/1633833/oriented-forest-taocp-algorithm-in-python + + """ + P = list(range(-1, n)) + while True: + yield P[1:] + if P[n] > 0: + P[n] = P[P[n]] + else: + for p in range(n - 1, 0, -1): + if P[p] != 0: + target = P[p] - 1 + for q in range(p - 1, 0, -1): + if P[q] == target: + break + offset = p - q + for i in range(p, n + 1): + P[i] = P[i - offset] + break + else: + break + + +def minlex(seq, directed=True, key=None): + r""" + Return the rotation of the sequence in which the lexically smallest + elements appear first, e.g. `cba \rightarrow acb`. + + The sequence returned is a tuple, unless the input sequence is a string + in which case a string is returned. + + If ``directed`` is False then the smaller of the sequence and the + reversed sequence is returned, e.g. `cba \rightarrow abc`. + + If ``key`` is not None then it is used to extract a comparison key from each element in iterable. 
+ + Examples + ======== + + >>> from sympy.combinatorics.polyhedron import minlex + >>> minlex((1, 2, 0)) + (0, 1, 2) + >>> minlex((1, 0, 2)) + (0, 2, 1) + >>> minlex((1, 0, 2), directed=False) + (0, 1, 2) + + >>> minlex('11010011000', directed=True) + '00011010011' + >>> minlex('11010011000', directed=False) + '00011001011' + + >>> minlex(('bb', 'aaa', 'c', 'a')) + ('a', 'bb', 'aaa', 'c') + >>> minlex(('bb', 'aaa', 'c', 'a'), key=len) + ('c', 'a', 'bb', 'aaa') + + """ + from sympy.functions.elementary.miscellaneous import Id + if key is None: key = Id + best = rotate_left(seq, least_rotation(seq, key=key)) + if not directed: + rseq = seq[::-1] + rbest = rotate_left(rseq, least_rotation(rseq, key=key)) + best = min(best, rbest, key=key) + + # Convert to tuple, unless we started with a string. + return tuple(best) if not isinstance(seq, str) else best + + +def runs(seq, op=gt): + """Group the sequence into lists in which successive elements + all compare the same with the comparison operator, ``op``: + op(seq[i + 1], seq[i]) is True from all elements in a run. + + Examples + ======== + + >>> from sympy.utilities.iterables import runs + >>> from operator import ge + >>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2]) + [[0, 1, 2], [2], [1, 4], [3], [2], [2]] + >>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2], op=ge) + [[0, 1, 2, 2], [1, 4], [3], [2, 2]] + """ + cycles = [] + seq = iter(seq) + try: + run = [next(seq)] + except StopIteration: + return [] + while True: + try: + ei = next(seq) + except StopIteration: + break + if op(ei, run[-1]): + run.append(ei) + continue + else: + cycles.append(run) + run = [ei] + if run: + cycles.append(run) + return cycles + + +def sequence_partitions(l, n, /): + r"""Returns the partition of sequence $l$ into $n$ bins + + Explanation + =========== + + Given the sequence $l_1 \cdots l_m \in V^+$ where + $V^+$ is the Kleene plus of $V$ + + The set of $n$ partitions of $l$ is defined as: + + .. math:: + \{(s_1, \cdots, s_n) | s_1 \in V^+, \cdots, s_n \in V^+, + s_1 \cdots s_n = l_1 \cdots l_m\} + + Parameters + ========== + + l : Sequence[T] + A nonempty sequence of any Python objects + + n : int + A positive integer + + Yields + ====== + + out : list[Sequence[T]] + A list of sequences with concatenation equals $l$. + This should conform with the type of $l$. + + Examples + ======== + + >>> from sympy.utilities.iterables import sequence_partitions + >>> for out in sequence_partitions([1, 2, 3, 4], 2): + ... print(out) + [[1], [2, 3, 4]] + [[1, 2], [3, 4]] + [[1, 2, 3], [4]] + + Notes + ===== + + This is modified version of EnricoGiampieri's partition generator + from https://stackoverflow.com/questions/13131491/partition-n-items-into-k-bins-in-python-lazily + + See Also + ======== + + sequence_partitions_empty + """ + # Asserting l is nonempty is done only for sanity check + if n == 1 and l: + yield [l] + return + for i in range(1, len(l)): + for part in sequence_partitions(l[i:], n - 1): + yield [l[:i]] + part + + +def sequence_partitions_empty(l, n, /): + r"""Returns the partition of sequence $l$ into $n$ bins with + empty sequence + + Explanation + =========== + + Given the sequence $l_1 \cdots l_m \in V^*$ where + $V^*$ is the Kleene star of $V$ + + The set of $n$ partitions of $l$ is defined as: + + .. math:: + \{(s_1, \cdots, s_n) | s_1 \in V^*, \cdots, s_n \in V^*, + s_1 \cdots s_n = l_1 \cdots l_m\} + + There are more combinations than :func:`sequence_partitions` because + empty sequence can fill everywhere, so we try to provide different + utility for this. 
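+
+ For a sequence of length $m$ there are $\binom{m+n-1}{n-1}$ such
+ partitions, compared with $\binom{m-1}{n-1}$ for
+ :func:`sequence_partitions`; e.g. a length-4 sequence split into 2 bins
+ gives 5 partitions instead of 3:
+
+ >>> from sympy.utilities.iterables import (sequence_partitions,
+ ...     sequence_partitions_empty)
+ >>> len(list(sequence_partitions([1, 2, 3, 4], 2)))
+ 3
+ >>> len(list(sequence_partitions_empty([1, 2, 3, 4], 2)))
+ 5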
+ + Parameters + ========== + + l : Sequence[T] + A sequence of any Python objects (can be possibly empty) + + n : int + A positive integer + + Yields + ====== + + out : list[Sequence[T]] + A list of sequences with concatenation equals $l$. + This should conform with the type of $l$. + + Examples + ======== + + >>> from sympy.utilities.iterables import sequence_partitions_empty + >>> for out in sequence_partitions_empty([1, 2, 3, 4], 2): + ... print(out) + [[], [1, 2, 3, 4]] + [[1], [2, 3, 4]] + [[1, 2], [3, 4]] + [[1, 2, 3], [4]] + [[1, 2, 3, 4], []] + + See Also + ======== + + sequence_partitions + """ + if n < 1: + return + if n == 1: + yield [l] + return + for i in range(0, len(l) + 1): + for part in sequence_partitions_empty(l[i:], n - 1): + yield [l[:i]] + part + + +def kbins(l, k, ordered=None): + """ + Return sequence ``l`` partitioned into ``k`` bins. + + Examples + ======== + + The default is to give the items in the same order, but grouped + into k partitions without any reordering: + + >>> from sympy.utilities.iterables import kbins + >>> for p in kbins(list(range(5)), 2): + ... print(p) + ... + [[0], [1, 2, 3, 4]] + [[0, 1], [2, 3, 4]] + [[0, 1, 2], [3, 4]] + [[0, 1, 2, 3], [4]] + + The ``ordered`` flag is either None (to give the simple partition + of the elements) or is a 2 digit integer indicating whether the order of + the bins and the order of the items in the bins matters. Given:: + + A = [[0], [1, 2]] + B = [[1, 2], [0]] + C = [[2, 1], [0]] + D = [[0], [2, 1]] + + the following values for ``ordered`` have the shown meanings:: + + 00 means A == B == C == D + 01 means A == B + 10 means A == D + 11 means A == A + + >>> for ordered_flag in [None, 0, 1, 10, 11]: + ... print('ordered = %s' % ordered_flag) + ... for p in kbins(list(range(3)), 2, ordered=ordered_flag): + ... print(' %s' % p) + ... + ordered = None + [[0], [1, 2]] + [[0, 1], [2]] + ordered = 0 + [[0, 1], [2]] + [[0, 2], [1]] + [[0], [1, 2]] + ordered = 1 + [[0], [1, 2]] + [[0], [2, 1]] + [[1], [0, 2]] + [[1], [2, 0]] + [[2], [0, 1]] + [[2], [1, 0]] + ordered = 10 + [[0, 1], [2]] + [[2], [0, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[0], [1, 2]] + [[1, 2], [0]] + ordered = 11 + [[0], [1, 2]] + [[0, 1], [2]] + [[0], [2, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[1, 0], [2]] + [[1], [2, 0]] + [[1, 2], [0]] + [[2], [0, 1]] + [[2, 0], [1]] + [[2], [1, 0]] + [[2, 1], [0]] + + See Also + ======== + + partitions, multiset_partitions + + """ + if ordered is None: + yield from sequence_partitions(l, k) + elif ordered == 11: + for pl in multiset_permutations(l): + pl = list(pl) + yield from sequence_partitions(pl, k) + elif ordered == 00: + yield from multiset_partitions(l, k) + elif ordered == 10: + for p in multiset_partitions(l, k): + for perm in permutations(p): + yield list(perm) + elif ordered == 1: + for kgot, p in partitions(len(l), k, size=True): + if kgot != k: + continue + for li in multiset_permutations(l): + rv = [] + i = j = 0 + li = list(li) + for size, multiplicity in sorted(p.items()): + for m in range(multiplicity): + j = i + size + rv.append(li[i: j]) + i = j + yield rv + else: + raise ValueError( + 'ordered must be one of 00, 01, 10 or 11, not %s' % ordered) + + +def permute_signs(t): + """Return iterator in which the signs of non-zero elements + of t are permuted. 
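+
+ Each of the ``len(t) - t.count(0)`` non-zero entries contributes a factor
+ of two, so a tuple with two non-zero entries yields four sign patterns:
+
+ >>> from sympy.utilities.iterables import permute_signs
+ >>> len(list(permute_signs((0, 1, 2))))
+ 4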
+ + Examples + ======== + + >>> from sympy.utilities.iterables import permute_signs + >>> list(permute_signs((0, 1, 2))) + [(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2)] + """ + for signs in product(*[(1, -1)]*(len(t) - t.count(0))): + signs = list(signs) + yield type(t)([i*signs.pop() if i else i for i in t]) + + +def signed_permutations(t): + """Return iterator in which the signs of non-zero elements + of t and the order of the elements are permuted and all + returned values are unique. + + Examples + ======== + + >>> from sympy.utilities.iterables import signed_permutations + >>> list(signed_permutations((0, 1, 2))) + [(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2), (0, 2, 1), + (0, -2, 1), (0, 2, -1), (0, -2, -1), (1, 0, 2), (-1, 0, 2), + (1, 0, -2), (-1, 0, -2), (1, 2, 0), (-1, 2, 0), (1, -2, 0), + (-1, -2, 0), (2, 0, 1), (-2, 0, 1), (2, 0, -1), (-2, 0, -1), + (2, 1, 0), (-2, 1, 0), (2, -1, 0), (-2, -1, 0)] + """ + return (type(t)(i) for j in multiset_permutations(t) + for i in permute_signs(j)) + + +def rotations(s, dir=1): + """Return a generator giving the items in s as list where + each subsequent list has the items rotated to the left (default) + or right (``dir=-1``) relative to the previous list. + + Examples + ======== + + >>> from sympy import rotations + >>> list(rotations([1,2,3])) + [[1, 2, 3], [2, 3, 1], [3, 1, 2]] + >>> list(rotations([1,2,3], -1)) + [[1, 2, 3], [3, 1, 2], [2, 3, 1]] + """ + seq = list(s) + for i in range(len(seq)): + yield seq + seq = rotate_left(seq, dir) + + +def roundrobin(*iterables): + """roundrobin recipe taken from itertools documentation: + https://docs.python.org/3/library/itertools.html#itertools-recipes + + roundrobin('ABC', 'D', 'EF') --> A D E B F C + + Recipe credited to George Sakkis + """ + nexts = cycle(iter(it).__next__ for it in iterables) + + pending = len(iterables) + while pending: + try: + for nxt in nexts: + yield nxt() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + + + +class NotIterable: + """ + Use this as mixin when creating a class which is not supposed to + return true when iterable() is called on its instances because + calling list() on the instance, for example, would result in + an infinite loop. + """ + pass + + +def iterable(i, exclude=(str, dict, NotIterable)): + """ + Return a boolean indicating whether ``i`` is SymPy iterable. + True also indicates that the iterator is finite, e.g. you can + call list(...) on the instance. + + When SymPy is working with iterables, it is almost always assuming + that the iterable is not a string or a mapping, so those are excluded + by default. If you want a pure Python definition, make exclude=None. To + exclude multiple items, pass them as a tuple. + + You can also set the _iterable attribute to True or False on your class, + which will override the checks here, including the exclude test. + + As a rule of thumb, some SymPy functions use this to check if they should + recursively map over an object. If an object is technically iterable in + the Python sense but does not desire this behavior (e.g., because its + iteration is not finite, or because iteration might induce an unwanted + computation), it should disable it by setting the _iterable attribute to False. + + See also: is_sequence + + Examples + ======== + + >>> from sympy.utilities.iterables import iterable + >>> from sympy import Tuple + >>> things = [[1], (1,), set([1]), Tuple(1), (j for j in [1, 2]), {1:2}, '1', 1] + >>> for i in things: + ... 
print('%s %s' % (iterable(i), type(i))) + True <... 'list'> + True <... 'tuple'> + True <... 'set'> + True + True <... 'generator'> + False <... 'dict'> + False <... 'str'> + False <... 'int'> + + >>> iterable({}, exclude=None) + True + >>> iterable({}, exclude=str) + True + >>> iterable("no", exclude=str) + False + + """ + if hasattr(i, '_iterable'): + return i._iterable + try: + iter(i) + except TypeError: + return False + if exclude: + return not isinstance(i, exclude) + return True + + +def is_sequence(i, include=None): + """ + Return a boolean indicating whether ``i`` is a sequence in the SymPy + sense. If anything that fails the test below should be included as + being a sequence for your application, set 'include' to that object's + type; multiple types should be passed as a tuple of types. + + Note: although generators can generate a sequence, they often need special + handling to make sure their elements are captured before the generator is + exhausted, so these are not included by default in the definition of a + sequence. + + See also: iterable + + Examples + ======== + + >>> from sympy.utilities.iterables import is_sequence + >>> from types import GeneratorType + >>> is_sequence([]) + True + >>> is_sequence(set()) + False + >>> is_sequence('abc') + False + >>> is_sequence('abc', include=str) + True + >>> generator = (c for c in 'abc') + >>> is_sequence(generator) + False + >>> is_sequence(generator, include=(str, GeneratorType)) + True + + """ + return (hasattr(i, '__getitem__') and + iterable(i) or + bool(include) and + isinstance(i, include)) + + +@deprecated( + """ + Using postorder_traversal from the sympy.utilities.iterables submodule is + deprecated. + + Instead, use postorder_traversal from the top-level sympy namespace, like + + sympy.postorder_traversal + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved") +def postorder_traversal(node, keys=None): + from sympy.core.traversal import postorder_traversal as _postorder_traversal + return _postorder_traversal(node, keys=keys) + + +@deprecated( + """ + Using interactive_traversal from the sympy.utilities.iterables submodule + is deprecated. + + Instead, use interactive_traversal from the top-level sympy namespace, + like + + sympy.interactive_traversal + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved") +def interactive_traversal(expr): + from sympy.interactive.traversal import interactive_traversal as _interactive_traversal + return _interactive_traversal(expr) + + +@deprecated( + """ + Importing default_sort_key from sympy.utilities.iterables is deprecated. + Use from sympy import default_sort_key instead. + """, + deprecated_since_version="1.10", +active_deprecations_target="deprecated-sympy-core-compatibility", +) +def default_sort_key(*args, **kwargs): + from sympy import default_sort_key as _default_sort_key + return _default_sort_key(*args, **kwargs) + + +@deprecated( + """ + Importing default_sort_key from sympy.utilities.iterables is deprecated. + Use from sympy import default_sort_key instead. 
+ """, + deprecated_since_version="1.10", +active_deprecations_target="deprecated-sympy-core-compatibility", +) +def ordered(*args, **kwargs): + from sympy import ordered as _ordered + return _ordered(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/sympy/utilities/lambdify.py b/phivenv/Lib/site-packages/sympy/utilities/lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..6807fdaec10963a60ba88d6e1ec6b2951c19241e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/lambdify.py @@ -0,0 +1,1592 @@ +""" +This module provides convenient functions to transform SymPy expressions to +lambda functions which can be used to calculate numerical values very fast. +""" + +from __future__ import annotations +from typing import Any + +import builtins +import inspect +import keyword +import textwrap +import linecache +import weakref + +# Required despite static analysis claiming it is not used +from sympy.external import import_module # noqa:F401 +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import (is_sequence, iterable, + NotIterable, flatten) +from sympy.utilities.misc import filldedent + + +__doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']} + + +# Default namespaces, letting us define translations that can't be defined +# by simple variable maps, like I => 1j +MATH_DEFAULT: dict[str, Any] = {} +CMATH_DEFAULT: dict[str,Any] = {} +MPMATH_DEFAULT: dict[str, Any] = {} +NUMPY_DEFAULT: dict[str, Any] = {"I": 1j} +SCIPY_DEFAULT: dict[str, Any] = {"I": 1j} +CUPY_DEFAULT: dict[str, Any] = {"I": 1j} +JAX_DEFAULT: dict[str, Any] = {"I": 1j} +TENSORFLOW_DEFAULT: dict[str, Any] = {} +TORCH_DEFAULT: dict[str, Any] = {"I": 1j} +SYMPY_DEFAULT: dict[str, Any] = {} +NUMEXPR_DEFAULT: dict[str, Any] = {} + +# These are the namespaces the lambda functions will use. +# These are separate from the names above because they are modified +# throughout this file, whereas the defaults should remain unmodified. + +MATH = MATH_DEFAULT.copy() +CMATH = CMATH_DEFAULT.copy() +MPMATH = MPMATH_DEFAULT.copy() +NUMPY = NUMPY_DEFAULT.copy() +SCIPY = SCIPY_DEFAULT.copy() +CUPY = CUPY_DEFAULT.copy() +JAX = JAX_DEFAULT.copy() +TENSORFLOW = TENSORFLOW_DEFAULT.copy() +TORCH = TORCH_DEFAULT.copy() +SYMPY = SYMPY_DEFAULT.copy() +NUMEXPR = NUMEXPR_DEFAULT.copy() + + +# Mappings between SymPy and other modules function names. +MATH_TRANSLATIONS = { + "ceiling": "ceil", + "E": "e", + "ln": "log", +} + +CMATH_TRANSLATIONS: dict[str, str] = {} + +# NOTE: This dictionary is reused in Function._eval_evalf to allow subclasses +# of Function to automatically evalf. 
+MPMATH_TRANSLATIONS = { + "Abs": "fabs", + "elliptic_k": "ellipk", + "elliptic_f": "ellipf", + "elliptic_e": "ellipe", + "elliptic_pi": "ellippi", + "ceiling": "ceil", + "chebyshevt": "chebyt", + "chebyshevu": "chebyu", + "assoc_legendre": "legenp", + "E": "e", + "I": "j", + "ln": "log", + #"lowergamma":"lower_gamma", + "oo": "inf", + #"uppergamma":"upper_gamma", + "LambertW": "lambertw", + "MutableDenseMatrix": "matrix", + "ImmutableDenseMatrix": "matrix", + "conjugate": "conj", + "dirichlet_eta": "altzeta", + "Ei": "ei", + "Shi": "shi", + "Chi": "chi", + "Si": "si", + "Ci": "ci", + "RisingFactorial": "rf", + "FallingFactorial": "ff", + "betainc_regularized": "betainc", +} + +NUMPY_TRANSLATIONS: dict[str, str] = { + "Heaviside": "heaviside", +} +SCIPY_TRANSLATIONS: dict[str, str] = { + "jn" : "spherical_jn", + "yn" : "spherical_yn" +} +CUPY_TRANSLATIONS: dict[str, str] = {} +JAX_TRANSLATIONS: dict[str, str] = {} + +TENSORFLOW_TRANSLATIONS: dict[str, str] = {} +TORCH_TRANSLATIONS: dict[str, str] = {} + +NUMEXPR_TRANSLATIONS: dict[str, str] = {} + +# Available modules: +MODULES = { + "math": (MATH, MATH_DEFAULT, MATH_TRANSLATIONS, ("from math import *",)), + "cmath": (CMATH, CMATH_DEFAULT, CMATH_TRANSLATIONS, ("import cmath; from cmath import *",)), + "mpmath": (MPMATH, MPMATH_DEFAULT, MPMATH_TRANSLATIONS, ("from mpmath import *",)), + "numpy": (NUMPY, NUMPY_DEFAULT, NUMPY_TRANSLATIONS, ("import numpy; from numpy import *; from numpy.linalg import *",)), + "scipy": (SCIPY, SCIPY_DEFAULT, SCIPY_TRANSLATIONS, ("import scipy; import numpy; from scipy.special import *",)), + "cupy": (CUPY, CUPY_DEFAULT, CUPY_TRANSLATIONS, ("import cupy",)), + "jax": (JAX, JAX_DEFAULT, JAX_TRANSLATIONS, ("import jax",)), + "tensorflow": (TENSORFLOW, TENSORFLOW_DEFAULT, TENSORFLOW_TRANSLATIONS, ("import tensorflow",)), + "torch": (TORCH, TORCH_DEFAULT, TORCH_TRANSLATIONS, ("import torch",)), + "sympy": (SYMPY, SYMPY_DEFAULT, {}, ( + "from sympy.functions import *", + "from sympy.matrices import *", + "from sympy import Integral, pi, oo, nan, zoo, E, I",)), + "numexpr" : (NUMEXPR, NUMEXPR_DEFAULT, NUMEXPR_TRANSLATIONS, + ("import_module('numexpr')", )), +} + + +def _import(module, reload=False): + """ + Creates a global translation dictionary for module. + + The argument module has to be one of the following strings: "math","cmath" + "mpmath", "numpy", "sympy", "tensorflow", "jax". + These dictionaries map names of Python functions to their equivalent in + other modules. + """ + try: + namespace, namespace_default, translations, import_commands = MODULES[ + module] + except KeyError: + raise NameError( + "'%s' module cannot be used for lambdification" % module) + + # Clear namespace or exit + if namespace != namespace_default: + # The namespace was already generated, don't do it again if not forced. 
+ if reload: + namespace.clear() + namespace.update(namespace_default) + else: + return + + for import_command in import_commands: + if import_command.startswith('import_module'): + module = eval(import_command) + + if module is not None: + namespace.update(module.__dict__) + continue + else: + try: + exec(import_command, {}, namespace) + continue + except ImportError: + pass + + raise ImportError( + "Cannot import '%s' with '%s' command" % (module, import_command)) + + # Add translated names to namespace + for sympyname, translation in translations.items(): + namespace[sympyname] = namespace[translation] + + # For computing the modulus of a SymPy expression we use the builtin abs + # function, instead of the previously used fabs function for all + # translation modules. This is because the fabs function in the math + # module does not accept complex valued arguments. (see issue 9474). The + # only exception, where we don't use the builtin abs function is the + # mpmath translation module, because mpmath.fabs returns mpf objects in + # contrast to abs(). + if 'Abs' not in namespace: + namespace['Abs'] = abs + +# Used for dynamically generated filenames that are inserted into the +# linecache. +_lambdify_generated_counter = 1 + + +@doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,)) +def lambdify(args, expr, modules=None, printer=None, use_imps=True, + dummify=False, cse=False, docstring_limit=1000): + """Convert a SymPy expression into a function that allows for fast + numeric evaluation. + + .. warning:: + This function uses ``exec``, and thus should not be used on + unsanitized input. + + .. deprecated:: 1.7 + Passing a set for the *args* parameter is deprecated as sets are + unordered. Use an ordered iterable such as a list or tuple. + + Explanation + =========== + + For example, to convert the SymPy expression ``sin(x) + cos(x)`` to an + equivalent NumPy function that numerically evaluates it: + + >>> from sympy import sin, cos, symbols, lambdify + >>> import numpy as np + >>> x = symbols('x') + >>> expr = sin(x) + cos(x) + >>> expr + sin(x) + cos(x) + >>> f = lambdify(x, expr, 'numpy') + >>> a = np.array([1, 2]) + >>> f(a) + [1.38177329 0.49315059] + + The primary purpose of this function is to provide a bridge from SymPy + expressions to numerical libraries such as NumPy, SciPy, NumExpr, mpmath, + and tensorflow. In general, SymPy functions do not work with objects from + other libraries, such as NumPy arrays, and functions from numeric + libraries like NumPy or mpmath do not work on SymPy expressions. + ``lambdify`` bridges the two by converting a SymPy expression to an + equivalent numeric function. + + The basic workflow with ``lambdify`` is to first create a SymPy expression + representing whatever mathematical function you wish to evaluate. This + should be done using only SymPy functions and expressions. Then, use + ``lambdify`` to convert this to an equivalent function for numerical + evaluation. For instance, above we created ``expr`` using the SymPy symbol + ``x`` and SymPy functions ``sin`` and ``cos``, then converted it to an + equivalent NumPy function ``f``, and called it on a NumPy array ``a``. + + Parameters + ========== + + args : List[Symbol] + A variable or a list of variables whose nesting represents the + nesting of the arguments that will be passed to the function. + + Variables can be symbols, undefined functions, or matrix symbols. 
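+
+ For instance, an applied undefined function can itself serve as an
+ argument (a sketch; ``g`` is only a placeholder name):
+
+ >>> from sympy import Function, lambdify
+ >>> from sympy.abc import x
+ >>> g = Function('g')
+ >>> lam = lambdify([g(x), x], g(x) + x) # doctest: +SKIP
+ >>> lam(2, 3) # doctest: +SKIP
+ 5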
+ + >>> from sympy import Eq + >>> from sympy.abc import x, y, z + + The list of variables should match the structure of how the + arguments will be passed to the function. Simply enclose the + parameters as they will be passed in a list. + + To call a function like ``f(x)`` then ``[x]`` + should be the first argument to ``lambdify``; for this + case a single ``x`` can also be used: + + >>> f = lambdify(x, x + 1) + >>> f(1) + 2 + >>> f = lambdify([x], x + 1) + >>> f(1) + 2 + + To call a function like ``f(x, y)`` then ``[x, y]`` will + be the first argument of the ``lambdify``: + + >>> f = lambdify([x, y], x + y) + >>> f(1, 1) + 2 + + To call a function with a single 3-element tuple like + ``f((x, y, z))`` then ``[(x, y, z)]`` will be the first + argument of the ``lambdify``: + + >>> f = lambdify([(x, y, z)], Eq(z**2, x**2 + y**2)) + >>> f((3, 4, 5)) + True + + If two args will be passed and the first is a scalar but + the second is a tuple with two arguments then the items + in the list should match that structure: + + >>> f = lambdify([x, (y, z)], x + y + z) + >>> f(1, (2, 3)) + 6 + + expr : Expr + An expression, list of expressions, or matrix to be evaluated. + + Lists may be nested. + If the expression is a list, the output will also be a list. + + >>> f = lambdify(x, [x, [x + 1, x + 2]]) + >>> f(1) + [1, [2, 3]] + + If it is a matrix, an array will be returned (for the NumPy module). + + >>> from sympy import Matrix + >>> f = lambdify(x, Matrix([x, x + 1])) + >>> f(1) + [[1] + [2]] + + Note that the argument order here (variables then expression) is used + to emulate the Python ``lambda`` keyword. ``lambdify(x, expr)`` works + (roughly) like ``lambda x: expr`` + (see :ref:`lambdify-how-it-works` below). + + modules : str, optional + Specifies the numeric library to use. + + If not specified, *modules* defaults to: + + - ``["scipy", "numpy"]`` if SciPy is installed + - ``["numpy"]`` if only NumPy is installed + - ``["math","cmath", "mpmath", "sympy"]`` if neither is installed. + + That is, SymPy functions are replaced as far as possible by + either ``scipy`` or ``numpy`` functions if available, and Python's + standard library ``math`` and ``cmath``, or ``mpmath`` functions otherwise. + + *modules* can be one of the following types: + + - The strings ``"math"``, ``"cmath"``, ``"mpmath"``, ``"numpy"``, ``"numexpr"``, + ``"scipy"``, ``"sympy"``, or ``"tensorflow"`` or ``"jax"``. This uses the + corresponding printer and namespace mapping for that module. + - A module (e.g., ``math``). This uses the global namespace of the + module. If the module is one of the above known modules, it will + also use the corresponding printer and namespace mapping + (i.e., ``modules=numpy`` is equivalent to ``modules="numpy"``). + - A dictionary that maps names of SymPy functions to arbitrary + functions + (e.g., ``{'sin': custom_sin}``). + - A list that contains a mix of the arguments above, with higher + priority given to entries appearing first + (e.g., to use the NumPy module but override the ``sin`` function + with a custom version, you can use + ``[{'sin': custom_sin}, 'numpy']``). + + dummify : bool, optional + Whether or not the variables in the provided expression that are not + valid Python identifiers are substituted with dummy symbols. + + This allows for undefined functions like ``Function('f')(t)`` to be + supplied as arguments. By default, the variables are only dummified + if they are not valid Python identifiers. 
+ + Set ``dummify=True`` to replace all arguments with dummy symbols + (if ``args`` is not a string) - for example, to ensure that the + arguments do not redefine any built-in names. + + cse : bool, or callable, optional + Large expressions can be computed more efficiently when + common subexpressions are identified and precomputed before + being used multiple time. Finding the subexpressions will make + creation of the 'lambdify' function slower, however. + + When ``True``, ``sympy.simplify.cse`` is used, otherwise (the default) + the user may pass a function matching the ``cse`` signature. + + docstring_limit : int or None + When lambdifying large expressions, a significant proportion of the time + spent inside ``lambdify`` is spent producing a string representation of + the expression for use in the automatically generated docstring of the + returned function. For expressions containing hundreds or more nodes the + resulting docstring often becomes so long and dense that it is difficult + to read. To reduce the runtime of lambdify, the rendering of the full + expression inside the docstring can be disabled. + + When ``None``, the full expression is rendered in the docstring. When + ``0`` or a negative ``int``, an ellipsis is rendering in the docstring + instead of the expression. When a strictly positive ``int``, if the + number of nodes in the expression exceeds ``docstring_limit`` an + ellipsis is rendered in the docstring, otherwise a string representation + of the expression is rendered as normal. The default is ``1000``. + + Examples + ======== + + >>> from sympy.utilities.lambdify import implemented_function + >>> from sympy import sqrt, sin, Matrix + >>> from sympy import Function + >>> from sympy.abc import w, x, y, z + + >>> f = lambdify(x, x**2) + >>> f(2) + 4 + >>> f = lambdify((x, y, z), [z, y, x]) + >>> f(1,2,3) + [3, 2, 1] + >>> f = lambdify(x, sqrt(x)) + >>> f(4) + 2.0 + >>> f = lambdify((x, y), sin(x*y)**2) + >>> f(0, 5) + 0.0 + >>> row = lambdify((x, y), Matrix((x, x + y)).T, modules='sympy') + >>> row(1, 2) + Matrix([[1, 3]]) + + ``lambdify`` can be used to translate SymPy expressions into mpmath + functions. This may be preferable to using ``evalf`` (which uses mpmath on + the backend) in some cases. + + >>> f = lambdify(x, sin(x), 'mpmath') + >>> f(1) + 0.8414709848078965 + + Tuple arguments are handled and the lambdified function should + be called with the same type of arguments as were used to create + the function: + + >>> f = lambdify((x, (y, z)), x + y) + >>> f(1, (2, 4)) + 3 + + The ``flatten`` function can be used to always work with flattened + arguments: + + >>> from sympy.utilities.iterables import flatten + >>> args = w, (x, (y, z)) + >>> vals = 1, (2, (3, 4)) + >>> f = lambdify(flatten(args), w + x + y + z) + >>> f(*flatten(vals)) + 10 + + Functions present in ``expr`` can also carry their own numerical + implementations, in a callable attached to the ``_imp_`` attribute. This + can be used with undefined functions using the ``implemented_function`` + factory: + + >>> f = implemented_function(Function('f'), lambda x: x+1) + >>> func = lambdify(x, f(x)) + >>> func(4) + 5 + + ``lambdify`` always prefers ``_imp_`` implementations to implementations + in other namespaces, unless the ``use_imps`` input parameter is False. 
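+
+ For example, the ``_imp_`` attached above wins even when a numeric
+ module is requested (illustrative sketch):
+
+ >>> h = implemented_function(Function('h'), lambda x: 2*x)
+ >>> lambdify(x, h(x), 'numpy')(3) # doctest: +SKIP
+ 6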
+ + Usage with Tensorflow: + + >>> import tensorflow as tf + >>> from sympy import Max, sin, lambdify + >>> from sympy.abc import x + + >>> f = Max(x, sin(x)) + >>> func = lambdify(x, f, 'tensorflow') + + After tensorflow v2, eager execution is enabled by default. + If you want to get the compatible result across tensorflow v1 and v2 + as same as this tutorial, run this line. + + >>> tf.compat.v1.enable_eager_execution() + + If you have eager execution enabled, you can get the result out + immediately as you can use numpy. + + If you pass tensorflow objects, you may get an ``EagerTensor`` + object instead of value. + + >>> result = func(tf.constant(1.0)) + >>> print(result) + tf.Tensor(1.0, shape=(), dtype=float32) + >>> print(result.__class__) + + + You can use ``.numpy()`` to get the numpy value of the tensor. + + >>> result.numpy() + 1.0 + + >>> var = tf.Variable(2.0) + >>> result = func(var) # also works for tf.Variable and tf.Placeholder + >>> result.numpy() + 2.0 + + And it works with any shape array. + + >>> tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + >>> result = func(tensor) + >>> result.numpy() + [[1. 2.] + [3. 4.]] + + Notes + ===== + + - For functions involving large array calculations, numexpr can provide a + significant speedup over numpy. Please note that the available functions + for numexpr are more limited than numpy but can be expanded with + ``implemented_function`` and user defined subclasses of Function. If + specified, numexpr may be the only option in modules. The official list + of numexpr functions can be found at: + https://numexpr.readthedocs.io/en/latest/user_guide.html#supported-functions + + - In the above examples, the generated functions can accept scalar + values or numpy arrays as arguments. However, in some cases + the generated function relies on the input being a numpy array: + + >>> import numpy + >>> from sympy import Piecewise + >>> from sympy.testing.pytest import ignore_warnings + >>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "numpy") + + >>> with ignore_warnings(RuntimeWarning): + ... f(numpy.array([-1, 0, 1, 2])) + [-1. 0. 1. 0.5] + + >>> f(0) + Traceback (most recent call last): + ... + ZeroDivisionError: division by zero + + In such cases, the input should be wrapped in a numpy array: + + >>> with ignore_warnings(RuntimeWarning): + ... float(f(numpy.array([0]))) + 0.0 + + Or if numpy functionality is not required another module can be used: + + >>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "math") + >>> f(0) + 0 + + .. _lambdify-how-it-works: + + How it works + ============ + + When using this function, it helps a great deal to have an idea of what it + is doing. At its core, lambdify is nothing more than a namespace + translation, on top of a special printer that makes some corner cases work + properly. + + To understand lambdify, first we must properly understand how Python + namespaces work. Say we had two files. One called ``sin_cos_sympy.py``, + with + + .. code:: python + + # sin_cos_sympy.py + + from sympy.functions.elementary.trigonometric import (cos, sin) + + def sin_cos(x): + return sin(x) + cos(x) + + + and one called ``sin_cos_numpy.py`` with + + .. code:: python + + # sin_cos_numpy.py + + from numpy import sin, cos + + def sin_cos(x): + return sin(x) + cos(x) + + The two files define an identical function ``sin_cos``. However, in the + first file, ``sin`` and ``cos`` are defined as the SymPy ``sin`` and + ``cos``. In the second, they are defined as the NumPy versions. 
+ + If we were to import the first file and use the ``sin_cos`` function, we + would get something like + + >>> from sin_cos_sympy import sin_cos # doctest: +SKIP + >>> sin_cos(1) # doctest: +SKIP + cos(1) + sin(1) + + On the other hand, if we imported ``sin_cos`` from the second file, we + would get + + >>> from sin_cos_numpy import sin_cos # doctest: +SKIP + >>> sin_cos(1) # doctest: +SKIP + 1.38177329068 + + In the first case we got a symbolic output, because it used the symbolic + ``sin`` and ``cos`` functions from SymPy. In the second, we got a numeric + result, because ``sin_cos`` used the numeric ``sin`` and ``cos`` functions + from NumPy. But notice that the versions of ``sin`` and ``cos`` that were + used was not inherent to the ``sin_cos`` function definition. Both + ``sin_cos`` definitions are exactly the same. Rather, it was based on the + names defined at the module where the ``sin_cos`` function was defined. + + The key point here is that when function in Python references a name that + is not defined in the function, that name is looked up in the "global" + namespace of the module where that function is defined. + + Now, in Python, we can emulate this behavior without actually writing a + file to disk using the ``exec`` function. ``exec`` takes a string + containing a block of Python code, and a dictionary that should contain + the global variables of the module. It then executes the code "in" that + dictionary, as if it were the module globals. The following is equivalent + to the ``sin_cos`` defined in ``sin_cos_sympy.py``: + + >>> import sympy + >>> module_dictionary = {'sin': sympy.sin, 'cos': sympy.cos} + >>> exec(''' + ... def sin_cos(x): + ... return sin(x) + cos(x) + ... ''', module_dictionary) + >>> sin_cos = module_dictionary['sin_cos'] + >>> sin_cos(1) + cos(1) + sin(1) + + and similarly with ``sin_cos_numpy``: + + >>> import numpy + >>> module_dictionary = {'sin': numpy.sin, 'cos': numpy.cos} + >>> exec(''' + ... def sin_cos(x): + ... return sin(x) + cos(x) + ... ''', module_dictionary) + >>> sin_cos = module_dictionary['sin_cos'] + >>> sin_cos(1) + 1.38177329068 + + So now we can get an idea of how ``lambdify`` works. The name "lambdify" + comes from the fact that we can think of something like ``lambdify(x, + sin(x) + cos(x), 'numpy')`` as ``lambda x: sin(x) + cos(x)``, where + ``sin`` and ``cos`` come from the ``numpy`` namespace. This is also why + the symbols argument is first in ``lambdify``, as opposed to most SymPy + functions where it comes after the expression: to better mimic the + ``lambda`` keyword. + + ``lambdify`` takes the input expression (like ``sin(x) + cos(x)``) and + + 1. Converts it to a string + 2. Creates a module globals dictionary based on the modules that are + passed in (by default, it uses the NumPy module) + 3. Creates the string ``"def func({vars}): return {expr}"``, where ``{vars}`` is the + list of variables separated by commas, and ``{expr}`` is the string + created in step 1., then ``exec``s that string with the module globals + namespace and returns ``func``. + + In fact, functions returned by ``lambdify`` support inspection. So you can + see exactly how they are defined by using ``inspect.getsource``, or ``??`` if you + are using IPython or the Jupyter notebook. + + >>> f = lambdify(x, sin(x) + cos(x)) + >>> import inspect + >>> print(inspect.getsource(f)) + def _lambdifygenerated(x): + return sin(x) + cos(x) + + This shows us the source code of the function, but not the namespace it + was defined in. 
We can inspect that by looking at the ``__globals__`` + attribute of ``f``: + + >>> f.__globals__['sin'] + + >>> f.__globals__['cos'] + + >>> f.__globals__['sin'] is numpy.sin + True + + This shows us that ``sin`` and ``cos`` in the namespace of ``f`` will be + ``numpy.sin`` and ``numpy.cos``. + + Note that there are some convenience layers in each of these steps, but at + the core, this is how ``lambdify`` works. Step 1 is done using the + ``LambdaPrinter`` printers defined in the printing module (see + :mod:`sympy.printing.lambdarepr`). This allows different SymPy expressions + to define how they should be converted to a string for different modules. + You can change which printer ``lambdify`` uses by passing a custom printer + in to the ``printer`` argument. + + Step 2 is augmented by certain translations. There are default + translations for each module, but you can provide your own by passing a + list to the ``modules`` argument. For instance, + + >>> def mysin(x): + ... print('taking the sin of', x) + ... return numpy.sin(x) + ... + >>> f = lambdify(x, sin(x), [{'sin': mysin}, 'numpy']) + >>> f(1) + taking the sin of 1 + 0.8414709848078965 + + The globals dictionary is generated from the list by merging the + dictionary ``{'sin': mysin}`` and the module dictionary for NumPy. The + merging is done so that earlier items take precedence, which is why + ``mysin`` is used above instead of ``numpy.sin``. + + If you want to modify the way ``lambdify`` works for a given function, it + is usually easiest to do so by modifying the globals dictionary as such. + In more complicated cases, it may be necessary to create and pass in a + custom printer. + + Finally, step 3 is augmented with certain convenience operations, such as + the addition of a docstring. + + Understanding how ``lambdify`` works can make it easier to avoid certain + gotchas when using it. For instance, a common mistake is to create a + lambdified function for one module (say, NumPy), and pass it objects from + another (say, a SymPy expression). + + For instance, say we create + + >>> from sympy.abc import x + >>> f = lambdify(x, x + 1, 'numpy') + + Now if we pass in a NumPy array, we get that array plus 1 + + >>> import numpy + >>> a = numpy.array([1, 2]) + >>> f(a) + [2 3] + + But what happens if you make the mistake of passing in a SymPy expression + instead of a NumPy array: + + >>> f(x + 1) + x + 2 + + This worked, but it was only by accident. Now take a different lambdified + function: + + >>> from sympy import sin + >>> g = lambdify(x, x + sin(x), 'numpy') + + This works as expected on NumPy arrays: + + >>> g(a) + [1.84147098 2.90929743] + + But if we try to pass in a SymPy expression, it fails + + >>> g(x + 1) + Traceback (most recent call last): + ... + TypeError: loop of ufunc does not support argument 0 of type Add which has + no callable sin method + + Now, let's look at what happened. The reason this fails is that ``g`` + calls ``numpy.sin`` on the input expression, and ``numpy.sin`` does not + know how to operate on a SymPy object. **As a general rule, NumPy + functions do not know how to operate on SymPy expressions, and SymPy + functions do not know how to operate on NumPy arrays. This is why lambdify + exists: to provide a bridge between SymPy and NumPy.** + + However, why is it that ``f`` did work? That's because ``f`` does not call + any functions, it only adds 1. 
So the resulting function that is created, + ``def _lambdifygenerated(x): return x + 1`` does not depend on the globals + namespace it is defined in. Thus it works, but only by accident. A future + version of ``lambdify`` may remove this behavior. + + Be aware that certain implementation details described here may change in + future versions of SymPy. The API of passing in custom modules and + printers will not change, but the details of how a lambda function is + created may change. However, the basic idea will remain the same, and + understanding it will be helpful to understanding the behavior of + lambdify. + + **In general: you should create lambdified functions for one module (say, + NumPy), and only pass it input types that are compatible with that module + (say, NumPy arrays).** Remember that by default, if the ``module`` + argument is not provided, ``lambdify`` creates functions using the NumPy + and SciPy namespaces. + """ + from sympy.core.symbol import Symbol + from sympy.core.expr import Expr + + # If the user hasn't specified any modules, use what is available. + if modules is None: + try: + _import("scipy") + except ImportError: + try: + _import("numpy") + except ImportError: + # Use either numpy (if available) or python.math where possible. + # XXX: This leads to different behaviour on different systems and + # might be the reason for irreproducible errors. + modules = ["math", "mpmath", "sympy"] + else: + modules = ["numpy"] + else: + modules = ["numpy", "scipy"] + + # Get the needed namespaces. + namespaces = [] + # First find any function implementations + if use_imps: + namespaces.append(_imp_namespace(expr)) + # Check for dict before iterating + if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'): + namespaces.append(modules) + else: + # consistency check + if _module_present('numexpr', modules) and len(modules) > 1: + raise TypeError("numexpr must be the only item in 'modules'") + namespaces += list(modules) + # fill namespace with first having highest priority + namespace = {} + for m in namespaces[::-1]: + buf = _get_namespace(m) + namespace.update(buf) + + if hasattr(expr, "atoms"): + #Try if you can extract symbols from the expression. + #Move on if expr.atoms in not implemented. 
+ syms = expr.atoms(Symbol) + for term in syms: + namespace.update({str(term): term}) + + if printer is None: + if _module_present('mpmath', namespaces): + from sympy.printing.pycode import MpmathPrinter as Printer # type: ignore + elif _module_present('scipy', namespaces): + from sympy.printing.numpy import SciPyPrinter as Printer # type: ignore + elif _module_present('numpy', namespaces): + from sympy.printing.numpy import NumPyPrinter as Printer # type: ignore + elif _module_present('cupy', namespaces): + from sympy.printing.numpy import CuPyPrinter as Printer # type: ignore + elif _module_present('jax', namespaces): + from sympy.printing.numpy import JaxPrinter as Printer # type: ignore + elif _module_present('numexpr', namespaces): + from sympy.printing.lambdarepr import NumExprPrinter as Printer # type: ignore + elif _module_present('tensorflow', namespaces): + from sympy.printing.tensorflow import TensorflowPrinter as Printer # type: ignore + elif _module_present('torch', namespaces): + from sympy.printing.pytorch import TorchPrinter as Printer # type: ignore + elif _module_present('sympy', namespaces): + from sympy.printing.pycode import SymPyPrinter as Printer # type: ignore + elif _module_present('cmath', namespaces): + from sympy.printing.pycode import CmathPrinter as Printer # type: ignore + else: + from sympy.printing.pycode import PythonCodePrinter as Printer # type: ignore + user_functions = {} + for m in namespaces[::-1]: + if isinstance(m, dict): + for k in m: + user_functions[k] = k + printer = Printer({'fully_qualified_modules': False, 'inline': True, + 'allow_unknown_functions': True, + 'user_functions': user_functions}) + + if isinstance(args, set): + sympy_deprecation_warning( + """ +Passing the function arguments to lambdify() as a set is deprecated. This +leads to unpredictable results since sets are unordered. Instead, use a list +or tuple for the function arguments. + """, + deprecated_since_version="1.6.3", + active_deprecations_target="deprecated-lambdify-arguments-set", + ) + + # Get the names of the args, for creating a docstring + iterable_args = (args,) if isinstance(args, Expr) else args + names = [] + + # Grab the callers frame, for getting the names by inspection (if needed) + callers_local_vars = inspect.currentframe().f_back.f_locals.items() # type: ignore + for n, var in enumerate(iterable_args): + if hasattr(var, 'name'): + names.append(var.name) + else: + # It's an iterable. Try to get name by inspection of calling frame. + name_list = [var_name for var_name, var_val in callers_local_vars + if var_val is var] + if len(name_list) == 1: + names.append(name_list[0]) + else: + # Cannot infer name with certainty. arg_# will have to do. + names.append('arg_' + str(n)) + + # Create the function definition code and execute it + funcname = '_lambdifygenerated' + if _module_present('tensorflow', namespaces): + funcprinter = _TensorflowEvaluatorPrinter(printer, dummify) + else: + funcprinter = _EvaluatorPrinter(printer, dummify) + + if cse == True: + from sympy.simplify.cse_main import cse as _cse + cses, _expr = _cse(expr, list=False) + elif callable(cse): + cses, _expr = cse(expr) + else: + cses, _expr = (), expr + funcstr = funcprinter.doprint(funcname, iterable_args, _expr, cses=cses) + + # Collect the module imports from the code printers. 
+    imp_mod_lines = []
+    for mod, keys in (getattr(printer, 'module_imports', None) or {}).items():
+        for k in keys:
+            if k not in namespace:
+                ln = "from %s import %s" % (mod, k)
+                try:
+                    exec(ln, {}, namespace)
+                except ImportError:
+                    # Tensorflow 2.0 has issues with importing a specific
+                    # function from its submodule.
+                    # https://github.com/tensorflow/tensorflow/issues/33022
+                    ln = "%s = %s.%s" % (k, mod, k)
+                    exec(ln, {}, namespace)
+                imp_mod_lines.append(ln)
+
+    # Provide lambda expression with builtins, and compatible implementation of range
+    namespace.update({'builtins':builtins, 'range':range})
+
+    funclocals = {}
+    global _lambdify_generated_counter
+    filename = '<lambdifygenerated-%s>' % _lambdify_generated_counter
+    _lambdify_generated_counter += 1
+    c = compile(funcstr, filename, 'exec')
+    exec(c, namespace, funclocals)
+    # mtime has to be None or else linecache.checkcache will remove it
+    linecache.cache[filename] = (len(funcstr), None, funcstr.splitlines(True), filename) # type: ignore
+
+    # Remove the entry from the linecache when the object is garbage collected
+    def cleanup_linecache(filename):
+        def _cleanup():
+            if filename in linecache.cache:
+                del linecache.cache[filename]
+        return _cleanup
+
+    func = funclocals[funcname]
+
+    weakref.finalize(func, cleanup_linecache(filename))
+
+    # Apply the docstring
+    sig = "func({})".format(", ".join(str(i) for i in names))
+    sig = textwrap.fill(sig, subsequent_indent=' '*8)
+    if _too_large_for_docstring(expr, docstring_limit):
+        expr_str = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
+        src_str = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
+    else:
+        expr_str = str(expr)
+        if len(expr_str) > 78:
+            expr_str = textwrap.wrap(expr_str, 75)[0] + '...'
+        src_str = funcstr
+    func.__doc__ = (
+        "Created with lambdify. Signature:\n\n"
+        "{sig}\n\n"
+        "Expression:\n\n"
+        "{expr}\n\n"
+        "Source code:\n\n"
+        "{src}\n\n"
+        "Imported modules:\n\n"
+        "{imp_mods}"
+        ).format(sig=sig, expr=expr_str, src=src_str, imp_mods='\n'.join(imp_mod_lines))
+    return func
+
+def _module_present(modname, modlist):
+    if modname in modlist:
+        return True
+    for m in modlist:
+        if hasattr(m, '__name__') and m.__name__ == modname:
+            return True
+    return False
+
+def _get_namespace(m):
+    """
+    This is used by _lambdify to parse its arguments.
+    """
+    if isinstance(m, str):
+        _import(m)
+        return MODULES[m][0]
+    elif isinstance(m, dict):
+        return m
+    elif hasattr(m, "__dict__"):
+        return m.__dict__
+    else:
+        raise TypeError("Argument must be either a string, dict or module but it is: %s" % m)
+
+
+def _recursive_to_string(doprint, arg):
+    """Functions in lambdify accept both SymPy types and non-SymPy types such as python
+    lists and tuples.
This method ensures that we only call the doprint method of the + printer with SymPy types (so that the printer safely can use SymPy-methods).""" + from sympy.matrices.matrixbase import MatrixBase + from sympy.core.basic import Basic + + if isinstance(arg, (Basic, MatrixBase)): + return doprint(arg) + elif iterable(arg): + if isinstance(arg, list): + left, right = "[", "]" + elif isinstance(arg, tuple): + left, right = "(", ",)" + if not arg: + return "()" + else: + raise NotImplementedError("unhandled type: %s, %s" % (type(arg), arg)) + return left +', '.join(_recursive_to_string(doprint, e) for e in arg) + right + elif isinstance(arg, str): + return arg + else: + return doprint(arg) + + +def lambdastr(args, expr, printer=None, dummify=None): + """ + Returns a string that can be evaluated to a lambda function. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.utilities.lambdify import lambdastr + >>> lambdastr(x, x**2) + 'lambda x: (x**2)' + >>> lambdastr((x,y,z), [z,y,x]) + 'lambda x,y,z: ([z, y, x])' + + Although tuples may not appear as arguments to lambda in Python 3, + lambdastr will create a lambda function that will unpack the original + arguments so that nested arguments can be handled: + + >>> lambdastr((x, (y, z)), x + y) + 'lambda _0,_1: (lambda x,y,z: (x + y))(_0,_1[0],_1[1])' + """ + # Transforming everything to strings. + from sympy.matrices import DeferredVector + from sympy.core.basic import Basic + from sympy.core.function import (Derivative, Function) + from sympy.core.symbol import (Dummy, Symbol) + from sympy.core.sympify import sympify + + if printer is not None: + if inspect.isfunction(printer): + lambdarepr = printer + else: + if inspect.isclass(printer): + lambdarepr = lambda expr: printer().doprint(expr) + else: + lambdarepr = lambda expr: printer.doprint(expr) + else: + #XXX: This has to be done here because of circular imports + from sympy.printing.lambdarepr import lambdarepr + + def sub_args(args, dummies_dict): + if isinstance(args, str): + return args + elif isinstance(args, DeferredVector): + return str(args) + elif iterable(args): + dummies = flatten([sub_args(a, dummies_dict) for a in args]) + return ",".join(str(a) for a in dummies) + else: + # replace these with Dummy symbols + if isinstance(args, (Function, Symbol, Derivative)): + dummies = Dummy() + dummies_dict.update({args : dummies}) + return str(dummies) + else: + return str(args) + + def sub_expr(expr, dummies_dict): + expr = sympify(expr) + # dict/tuple are sympified to Basic + if isinstance(expr, Basic): + expr = expr.xreplace(dummies_dict) + # list is not sympified to Basic + elif isinstance(expr, list): + expr = [sub_expr(a, dummies_dict) for a in expr] + return expr + + # Transform args + def isiter(l): + return iterable(l, exclude=(str, DeferredVector, NotIterable)) + + def flat_indexes(iterable): + n = 0 + + for el in iterable: + if isiter(el): + for ndeep in flat_indexes(el): + yield (n,) + ndeep + else: + yield (n,) + + n += 1 + + if dummify is None: + dummify = any(isinstance(a, Basic) and + a.atoms(Function, Derivative) for a in ( + args if isiter(args) else [args])) + + if isiter(args) and any(isiter(i) for i in args): + dum_args = [str(Dummy(str(i))) for i in range(len(args))] + + indexed_args = ','.join([ + dum_args[ind[0]] + ''.join(["[%s]" % k for k in ind[1:]]) + for ind in flat_indexes(args)]) + + lstr = lambdastr(flatten(args), expr, printer=printer, dummify=dummify) + + return 'lambda %s: (%s)(%s)' % (','.join(dum_args), lstr, indexed_args) + + 
dummies_dict = {} + if dummify: + args = sub_args(args, dummies_dict) + else: + if isinstance(args, str): + pass + elif iterable(args, exclude=DeferredVector): + args = ",".join(str(a) for a in args) + + # Transform expr + if dummify: + if isinstance(expr, str): + pass + else: + expr = sub_expr(expr, dummies_dict) + expr = _recursive_to_string(lambdarepr, expr) + return "lambda %s: (%s)" % (args, expr) + +class _EvaluatorPrinter: + def __init__(self, printer=None, dummify=False): + self._dummify = dummify + + #XXX: This has to be done here because of circular imports + from sympy.printing.lambdarepr import LambdaPrinter + + if printer is None: + printer = LambdaPrinter() + + if inspect.isfunction(printer): + self._exprrepr = printer + else: + if inspect.isclass(printer): + printer = printer() + + self._exprrepr = printer.doprint + + #if hasattr(printer, '_print_Symbol'): + # symbolrepr = printer._print_Symbol + + #if hasattr(printer, '_print_Dummy'): + # dummyrepr = printer._print_Dummy + + # Used to print the generated function arguments in a standard way + self._argrepr = LambdaPrinter().doprint + + def doprint(self, funcname, args, expr, *, cses=()): + """ + Returns the function definition code as a string. + """ + from sympy.core.symbol import Dummy + + funcbody = [] + + if not iterable(args): + args = [args] + + if cses: + cses = list(cses) + subvars, subexprs = zip(*cses) + exprs = [expr] + list(subexprs) + argstrs, exprs = self._preprocess(args, exprs, cses=cses) + expr, subexprs = exprs[0], exprs[1:] + cses = zip(subvars, subexprs) + else: + argstrs, expr = self._preprocess(args, expr) + + # Generate argument unpacking and final argument list + funcargs = [] + unpackings = [] + + for argstr in argstrs: + if iterable(argstr): + funcargs.append(self._argrepr(Dummy())) + unpackings.extend(self._print_unpacking(argstr, funcargs[-1])) + else: + funcargs.append(argstr) + + funcsig = 'def {}({}):'.format(funcname, ', '.join(funcargs)) + + # Wrap input arguments before unpacking + funcbody.extend(self._print_funcargwrapping(funcargs)) + + funcbody.extend(unpackings) + + for s, e in cses: + if e is None: + funcbody.append('del {}'.format(self._exprrepr(s))) + else: + funcbody.append('{} = {}'.format(self._exprrepr(s), self._exprrepr(e))) + + # Subs may appear in expressions generated by .diff() + subs_assignments = [] + expr = self._handle_Subs(expr, out=subs_assignments) + for lhs, rhs in subs_assignments: + funcbody.append('{} = {}'.format(self._exprrepr(lhs), self._exprrepr(rhs))) + + str_expr = _recursive_to_string(self._exprrepr, expr) + + if '\n' in str_expr: + str_expr = '({})'.format(str_expr) + funcbody.append('return {}'.format(str_expr)) + + funclines = [funcsig] + funclines.extend([' ' + line for line in funcbody]) + + return '\n'.join(funclines) + '\n' + + @classmethod + def _is_safe_ident(cls, ident): + return isinstance(ident, str) and ident.isidentifier() \ + and not keyword.iskeyword(ident) + + def _preprocess(self, args, expr, cses=(), _dummies_dict=None): + """Preprocess args, expr to replace arguments that do not map + to valid Python identifiers. + + Returns string form of args, and updated expr. + """ + from sympy.core.basic import Basic + from sympy.core.sorting import ordered + from sympy.core.function import (Derivative, Function) + from sympy.core.symbol import Dummy, uniquely_named_symbol + from sympy.matrices import DeferredVector + from sympy.core.expr import Expr + + # Args of type Dummy can cause name collisions with args + # of type Symbol. 
Force dummify of everything in this + # situation. + dummify = self._dummify or any( + isinstance(arg, Dummy) for arg in flatten(args)) + + argstrs = [None]*len(args) + if _dummies_dict is None: + _dummies_dict = {} + + def update_dummies(arg, dummy): + _dummies_dict[arg] = dummy + for repl, sub in cses: + arg = arg.xreplace({sub: repl}) + _dummies_dict[arg] = dummy + + for arg, i in reversed(list(ordered(zip(args, range(len(args)))))): + if iterable(arg): + s, expr = self._preprocess(arg, expr, cses=cses, _dummies_dict=_dummies_dict) + elif isinstance(arg, DeferredVector): + s = str(arg) + elif isinstance(arg, Basic) and arg.is_symbol: + s = str(arg) + if dummify or not self._is_safe_ident(s): + dummy = Dummy() + if isinstance(expr, Expr): + dummy = uniquely_named_symbol( + dummy.name, expr, modify=lambda s: '_' + s) + s = self._argrepr(dummy) + update_dummies(arg, dummy) + expr = self._subexpr(expr, _dummies_dict) + elif dummify or isinstance(arg, (Function, Derivative)): + dummy = Dummy() + s = self._argrepr(dummy) + update_dummies(arg, dummy) + expr = self._subexpr(expr, _dummies_dict) + else: + s = str(arg) + argstrs[i] = s + return argstrs, expr + + def _subexpr(self, expr, dummies_dict): + from sympy.matrices import DeferredVector + from sympy.core.sympify import sympify + + expr = sympify(expr) + xreplace = getattr(expr, 'xreplace', None) + if xreplace is not None: + expr = xreplace(dummies_dict) + else: + if isinstance(expr, DeferredVector): + pass + elif isinstance(expr, dict): + k = [self._subexpr(sympify(a), dummies_dict) for a in expr.keys()] + v = [self._subexpr(sympify(a), dummies_dict) for a in expr.values()] + expr = dict(zip(k, v)) + elif isinstance(expr, tuple): + expr = tuple(self._subexpr(sympify(a), dummies_dict) for a in expr) + elif isinstance(expr, list): + expr = [self._subexpr(sympify(a), dummies_dict) for a in expr] + return expr + + def _print_funcargwrapping(self, args): + """Generate argument wrapping code. + + args is the argument list of the generated function (strings). + + Return value is a list of lines of code that will be inserted at + the beginning of the function definition. + """ + return [] + + def _print_unpacking(self, unpackto, arg): + """Generate argument unpacking code. + + arg is the function argument to be unpacked (a string), and + unpackto is a list or nested lists of the variable names (strings) to + unpack to. + """ + def unpack_lhs(lvalues): + return '[{}]'.format(', '.join( + unpack_lhs(val) if iterable(val) else val for val in lvalues)) + + return ['{} = {}'.format(unpack_lhs(unpackto), arg)] + + def _handle_Subs(self, expr, out): + """Any instance of Subs is extracted and returned as assignment pairs.""" + from sympy.core.basic import Basic + from sympy.core.function import Subs + from sympy.core.symbol import Dummy + from sympy.matrices.matrixbase import MatrixBase + + def _replace(ex, variables, point): + safe = {} + for lhs, rhs in zip(variables, point): + dummy = Dummy() + safe[lhs] = dummy + out.append((dummy, rhs)) + return ex.xreplace(safe) + + if isinstance(expr, (Basic, MatrixBase)): + expr = expr.replace(Subs, _replace) + elif iterable(expr): + expr = type(expr)([self._handle_Subs(e, out) for e in expr]) + return expr + +class _TensorflowEvaluatorPrinter(_EvaluatorPrinter): + def _print_unpacking(self, lvalues, rvalue): + """Generate argument unpacking code. + + This method is used when the input value is not iterable, + but can be indexed (see issue #14655). 
+ """ + + def flat_indexes(elems): + n = 0 + + for el in elems: + if iterable(el): + for ndeep in flat_indexes(el): + yield (n,) + ndeep + else: + yield (n,) + + n += 1 + + indexed = ', '.join('{}[{}]'.format(rvalue, ']['.join(map(str, ind))) + for ind in flat_indexes(lvalues)) + + return ['[{}] = [{}]'.format(', '.join(flatten(lvalues)), indexed)] + +def _imp_namespace(expr, namespace=None): + """ Return namespace dict with function implementations + + We need to search for functions in anything that can be thrown at + us - that is - anything that could be passed as ``expr``. Examples + include SymPy expressions, as well as tuples, lists and dicts that may + contain SymPy expressions. + + Parameters + ---------- + expr : object + Something passed to lambdify, that will generate valid code from + ``str(expr)``. + namespace : None or mapping + Namespace to fill. None results in new empty dict + + Returns + ------- + namespace : dict + dict with keys of implemented function names within ``expr`` and + corresponding values being the numerical implementation of + function + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.utilities.lambdify import implemented_function, _imp_namespace + >>> from sympy import Function + >>> f = implemented_function(Function('f'), lambda x: x+1) + >>> g = implemented_function(Function('g'), lambda x: x*10) + >>> namespace = _imp_namespace(f(g(x))) + >>> sorted(namespace.keys()) + ['f', 'g'] + """ + # Delayed import to avoid circular imports + from sympy.core.function import FunctionClass + if namespace is None: + namespace = {} + # tuples, lists, dicts are valid expressions + if is_sequence(expr): + for arg in expr: + _imp_namespace(arg, namespace) + return namespace + elif isinstance(expr, dict): + for key, val in expr.items(): + # functions can be in dictionary keys + _imp_namespace(key, namespace) + _imp_namespace(val, namespace) + return namespace + # SymPy expressions may be Functions themselves + func = getattr(expr, 'func', None) + if isinstance(func, FunctionClass): + imp = getattr(func, '_imp_', None) + if imp is not None: + name = expr.func.__name__ + if name in namespace and namespace[name] != imp: + raise ValueError('We found more than one ' + 'implementation with name ' + '"%s"' % name) + namespace[name] = imp + # and / or they may take Functions as arguments + if hasattr(expr, 'args'): + for arg in expr.args: + _imp_namespace(arg, namespace) + return namespace + + +def implemented_function(symfunc, implementation): + """ Add numerical ``implementation`` to function ``symfunc``. + + ``symfunc`` can be an ``UndefinedFunction`` instance, or a name string. + In the latter case we create an ``UndefinedFunction`` instance with that + name. + + Be aware that this is a quick workaround, not a general method to create + special symbolic functions. If you want to create a symbolic function to be + used by all the machinery of SymPy you should subclass the ``Function`` + class. + + Parameters + ---------- + symfunc : ``str`` or ``UndefinedFunction`` instance + If ``str``, then create new ``UndefinedFunction`` with this as + name. If ``symfunc`` is an Undefined function, create a new function + with the same name and the implemented function attached. 
+ implementation : callable + numerical implementation to be called by ``evalf()`` or ``lambdify`` + + Returns + ------- + afunc : sympy.FunctionClass instance + function with attached implementation + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.utilities.lambdify import implemented_function + >>> from sympy import lambdify + >>> f = implemented_function('f', lambda x: x+1) + >>> lam_f = lambdify(x, f(x)) + >>> lam_f(4) + 5 + """ + # Delayed import to avoid circular imports + from sympy.core.function import UndefinedFunction + # if name, create function to hold implementation + kwargs = {} + if isinstance(symfunc, UndefinedFunction): + kwargs = symfunc._kwargs + symfunc = symfunc.__name__ + if isinstance(symfunc, str): + # Keyword arguments to UndefinedFunction are added as attributes to + # the created class. + symfunc = UndefinedFunction( + symfunc, _imp_=staticmethod(implementation), **kwargs) + elif not isinstance(symfunc, UndefinedFunction): + raise ValueError(filldedent(''' + symfunc should be either a string or + an UndefinedFunction instance.''')) + return symfunc + + +def _too_large_for_docstring(expr, limit): + """Decide whether an ``Expr`` is too large to be fully rendered in a + ``lambdify`` docstring. + + This is a fast alternative to ``count_ops``, which can become prohibitively + slow for large expressions, because in this instance we only care whether + ``limit`` is exceeded rather than counting the exact number of nodes in the + expression. + + Parameters + ========== + expr : ``Expr``, (nested) ``list`` of ``Expr``, or ``Matrix`` + The same objects that can be passed to the ``expr`` argument of + ``lambdify``. + limit : ``int`` or ``None`` + The threshold above which an expression contains too many nodes to be + usefully rendered in the docstring. If ``None`` then there is no limit. + + Returns + ======= + bool + ``True`` if the number of nodes in the expression exceeds the limit, + ``False`` otherwise. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.utilities.lambdify import _too_large_for_docstring + >>> expr = x + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + False + >>> _too_large_for_docstring(expr, 1) + False + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + Does this split it? + + >>> expr = [x, y, z] + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + False + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + >>> expr = [x, [y], z, [[x+y], [x*y*z, [x+y+z]]]] + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + False + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + >>> expr = ((x + y + z)**5).expand() + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + True + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + >>> from sympy import Matrix + >>> expr = Matrix([[(x + y + z), ((x + y + z)**2).expand(), + ... 
((x + y + z)**3).expand(), ((x + y + z)**4).expand()]]) + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 1000) + False + >>> _too_large_for_docstring(expr, 100) + True + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + """ + # Must be imported here to avoid a circular import error + from sympy.core.traversal import postorder_traversal + + if limit is None: + return False + + i = 0 + for _ in postorder_traversal(expr): + i += 1 + if i > limit: + return True + return False diff --git a/phivenv/Lib/site-packages/sympy/utilities/magic.py b/phivenv/Lib/site-packages/sympy/utilities/magic.py new file mode 100644 index 0000000000000000000000000000000000000000..e853a0ad9a85bc252dcb24e8a1ecbfca422ac3fd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/magic.py @@ -0,0 +1,12 @@ +"""Functions that involve magic. """ + +def pollute(names, objects): + """Pollute the global namespace with symbols -> objects mapping. """ + from inspect import currentframe + frame = currentframe().f_back.f_back + + try: + for name, obj in zip(names, objects): + frame.f_globals[name] = obj + finally: + del frame # break cyclic dependencies as stated in inspect docs diff --git a/phivenv/Lib/site-packages/sympy/utilities/matchpy_connector.py b/phivenv/Lib/site-packages/sympy/utilities/matchpy_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..35aa013294b93bbcfe4a1bf4ec96b629ea5a468f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/matchpy_connector.py @@ -0,0 +1,340 @@ +""" +The objects in this module allow the usage of the MatchPy pattern matching +library on SymPy expressions. +""" +import re +from typing import List, Callable, NamedTuple, Any, Dict + +from sympy.core.sympify import _sympify +from sympy.external import import_module +from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma) +from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch +from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec +from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import (Equality, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.integrals.integrals import Integral +from sympy.printing.repr import srepr +from sympy.utilities.decorator import doctest_depends_on + + +matchpy = import_module("matchpy") + + +__doctest_requires__ = {('*',): ['matchpy']} + + +if matchpy: + from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation + from matchpy.expressions.functions import op_iter, create_operation_expression, op_len + + Operation.register(Integral) + Operation.register(Pow) + OneIdentityOperation.register(Pow) + + Operation.register(Add) + OneIdentityOperation.register(Add) + CommutativeOperation.register(Add) + AssociativeOperation.register(Add) + + Operation.register(Mul) + OneIdentityOperation.register(Mul) + CommutativeOperation.register(Mul) + AssociativeOperation.register(Mul) + + Operation.register(Equality) + CommutativeOperation.register(Equality) + Operation.register(Unequality) 
+ CommutativeOperation.register(Unequality) + + Operation.register(exp) + Operation.register(log) + Operation.register(gamma) + Operation.register(uppergamma) + Operation.register(fresnels) + Operation.register(fresnelc) + Operation.register(erf) + Operation.register(Ei) + Operation.register(erfc) + Operation.register(erfi) + Operation.register(sin) + Operation.register(cos) + Operation.register(tan) + Operation.register(cot) + Operation.register(csc) + Operation.register(sec) + Operation.register(sinh) + Operation.register(cosh) + Operation.register(tanh) + Operation.register(coth) + Operation.register(csch) + Operation.register(sech) + Operation.register(asin) + Operation.register(acos) + Operation.register(atan) + Operation.register(acot) + Operation.register(acsc) + Operation.register(asec) + Operation.register(asinh) + Operation.register(acosh) + Operation.register(atanh) + Operation.register(acoth) + Operation.register(acsch) + Operation.register(asech) + + @op_iter.register(Integral) # type: ignore + def _(operation): + return iter((operation._args[0],) + operation._args[1]) + + @op_iter.register(Basic) # type: ignore + def _(operation): + return iter(operation._args) + + @op_len.register(Integral) # type: ignore + def _(operation): + return 1 + len(operation._args[1]) + + @op_len.register(Basic) # type: ignore + def _(operation): + return len(operation._args) + + @create_operation_expression.register(Basic) + def sympy_op_factory(old_operation, new_operands, variable_name=True): + return type(old_operation)(*new_operands) + + +if matchpy: + from matchpy import Wildcard +else: + class Wildcard: # type: ignore + def __init__(self, min_length, fixed_size, variable_name, optional): + self.min_count = min_length + self.fixed_size = fixed_size + self.variable_name = variable_name + self.optional = optional + + +@doctest_depends_on(modules=('matchpy',)) +class _WildAbstract(Wildcard, Symbol): + min_length: int # abstract field required in subclasses + fixed_size: bool # abstract field required in subclasses + + def __init__(self, variable_name=None, optional=None, **assumptions): + min_length = self.min_length + fixed_size = self.fixed_size + if optional is not None: + optional = _sympify(optional) + Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional) + + def __getstate__(self): + return { + "min_length": self.min_length, + "fixed_size": self.fixed_size, + "min_count": self.min_count, + "variable_name": self.variable_name, + "optional": self.optional, + } + + def __new__(cls, variable_name=None, optional=None, **assumptions): + cls._sanitize(assumptions, cls) + return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions) + + def __getnewargs__(self): + return self.variable_name, self.optional + + @staticmethod + def __xnew__(cls, variable_name=None, optional=None, **assumptions): + obj = Symbol.__xnew__(cls, variable_name, **assumptions) + return obj + + def _hashable_content(self): + if self.optional: + return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional) + else: + return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name) + + def __copy__(self) -> '_WildAbstract': + return type(self)(variable_name=self.variable_name, optional=self.optional) + + def __repr__(self): + return str(self) + + def __str__(self): + return self.name + + +@doctest_depends_on(modules=('matchpy',)) +class WildDot(_WildAbstract): + min_length = 1 + fixed_size = True + + 
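Since ``WildDot`` defined above carries no usage example in this module, here is a rough editor-added sketch of how it participates in matchpy pattern matching; it assumes the optional matchpy dependency is installed, and the variable names are illustrative only.

    # Added illustration: a WildDot behaves as both a SymPy Symbol and a
    # matchpy Wildcard, so it can appear directly inside a matchpy Pattern.
    from sympy import Symbol, sin
    from sympy.utilities.matchpy_connector import WildDot
    from matchpy import Pattern, match

    x = Symbol('x')
    a_ = WildDot('a_')
    substitution = next(iter(match(sin(x**2), Pattern(sin(a_)))))
    print(substitution['a_'])   # x**2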
+@doctest_depends_on(modules=('matchpy',)) +class WildPlus(_WildAbstract): + min_length = 1 + fixed_size = False + + +@doctest_depends_on(modules=('matchpy',)) +class WildStar(_WildAbstract): + min_length = 0 + fixed_size = False + + +def _get_srepr(expr): + s = srepr(expr) + s = re.sub(r"WildDot\('(\w+)'\)", r"\1", s) + s = re.sub(r"WildPlus\('(\w+)'\)", r"*\1", s) + s = re.sub(r"WildStar\('(\w+)'\)", r"*\1", s) + return s + + +class ReplacementInfo(NamedTuple): + replacement: Any + info: Any + + +@doctest_depends_on(modules=('matchpy',)) +class Replacer: + """ + Replacer object to perform multiple pattern matching and subexpression + replacements in SymPy expressions. + + Examples + ======== + + Example to construct a simple first degree equation solver: + + >>> from sympy.utilities.matchpy_connector import WildDot, Replacer + >>> from sympy import Equality, Symbol + >>> x = Symbol("x") + >>> a_ = WildDot("a_", optional=1) + >>> b_ = WildDot("b_", optional=0) + + The lines above have defined two wildcards, ``a_`` and ``b_``, the + coefficients of the equation `a x + b = 0`. The optional values specified + indicate which expression to return in case no match is found, they are + necessary in equations like `a x = 0` and `x + b = 0`. + + Create two constraints to make sure that ``a_`` and ``b_`` will not match + any expression containing ``x``: + + >>> from matchpy import CustomConstraint + >>> free_x_a = CustomConstraint(lambda a_: not a_.has(x)) + >>> free_x_b = CustomConstraint(lambda b_: not b_.has(x)) + + Now create the rule replacer with the constraints: + + >>> replacer = Replacer(common_constraints=[free_x_a, free_x_b]) + + Add the matching rule: + + >>> replacer.add(Equality(a_*x + b_, 0), -b_/a_) + + Let's try it: + + >>> replacer.replace(Equality(3*x + 4, 0)) + -4/3 + + Notice that it will not match equations expressed with other patterns: + + >>> eq = Equality(3*x, 4) + >>> replacer.replace(eq) + Eq(3*x, 4) + + In order to extend the matching patterns, define another one (we also need + to clear the cache, because the previous result has already been memorized + and the pattern matcher will not iterate again if given the same expression) + + >>> replacer.add(Equality(a_*x, b_), b_/a_) + >>> replacer._matcher.clear() + >>> replacer.replace(eq) + 4/3 + """ + + def __init__(self, common_constraints: list = [], lambdify: bool = False, info: bool = False): + self._matcher = matchpy.ManyToOneMatcher() + self._common_constraint = common_constraints + self._lambdify = lambdify + self._info = info + self._wildcards: Dict[str, Wildcard] = {} + + def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]: + exec("from sympy import *") + return eval(lambda_str, locals()) + + def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]: + wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)] + lambdaargs = ', '.join(wilds) + fullexpr = _get_srepr(constraint_expr) + condition = condition_template.format(fullexpr) + return matchpy.CustomConstraint( + self._get_lambda(f"lambda {lambdaargs}: ({condition})")) + + def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]: + return self._get_custom_constraint(constraint_expr, "({}) != False") + + def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]: + return self._get_custom_constraint(constraint_expr, "({}) == True") + + def add(self, expr: Expr, replacement, conditions_true: List[Expr] = [], + conditions_nonfalse: List[Expr] = [], 
info: Any = None) -> None: + expr = _sympify(expr) + replacement = _sympify(replacement) + constraints = self._common_constraint[:] + constraint_conditions_true = [ + self._get_custom_constraint_true(cond) for cond in conditions_true] + constraint_conditions_nonfalse = [ + self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse] + constraints.extend(constraint_conditions_true) + constraints.extend(constraint_conditions_nonfalse) + pattern = matchpy.Pattern(expr, *constraints) + if self._lambdify: + lambda_str = f"lambda {', '.join((x.name for x in expr.atoms(_WildAbstract)))}: {_get_srepr(replacement)}" + lambda_expr = self._get_lambda(lambda_str) + replacement = lambda_expr + else: + self._wildcards.update({str(i): i for i in expr.atoms(Wildcard)}) + if self._info: + replacement = ReplacementInfo(replacement, info) + self._matcher.add(pattern, replacement) + + def replace(self, expression, max_count: int = -1): + # This method partly rewrites the .replace method of ManyToOneReplacer + # in MatchPy. + # License: https://github.com/HPAC/matchpy/blob/master/LICENSE + infos = [] + replaced = True + replace_count = 0 + while replaced and (max_count < 0 or replace_count < max_count): + replaced = False + for subexpr, pos in matchpy.preorder_iter_with_position(expression): + try: + replacement_data, subst = next(iter(self._matcher.match(subexpr))) + if self._info: + replacement = replacement_data.replacement + infos.append(replacement_data.info) + else: + replacement = replacement_data + + if self._lambdify: + result = replacement(**subst) + else: + result = replacement.xreplace({self._wildcards[k]: v for k, v in subst.items()}) + + expression = matchpy.functions.replace(expression, pos, result) + replaced = True + break + except StopIteration: + pass + replace_count += 1 + if self._info: + return expression, infos + else: + return expression diff --git a/phivenv/Lib/site-packages/sympy/utilities/mathml/__init__.py b/phivenv/Lib/site-packages/sympy/utilities/mathml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eded44ee3c0f34ad1324765ba06ee9d6eb5e9899 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/mathml/__init__.py @@ -0,0 +1,122 @@ +"""Module with some functions for MathML, like transforming MathML +content in MathML presentation. + +To use this module, you will need lxml. +""" + +from pathlib import Path + +from sympy.utilities.decorator import doctest_depends_on + + +__doctest_requires__ = {('apply_xsl', 'c2p'): ['lxml']} + + +def add_mathml_headers(s): + return """""" + s + "" + + +def _read_binary(pkgname, filename): + import sys + + if sys.version_info >= (3, 10): + # files was added in Python 3.9 but only seems to work here in 3.10+ + from importlib.resources import files + return files(pkgname).joinpath(filename).read_bytes() + else: + # read_binary was deprecated in Python 3.11 + from importlib.resources import read_binary + return read_binary(pkgname, filename) + + +def _read_xsl(xsl): + # Previously these values were allowed: + if xsl == 'mathml/data/simple_mmlctop.xsl': + xsl = 'simple_mmlctop.xsl' + elif xsl == 'mathml/data/mmlctop.xsl': + xsl = 'mmlctop.xsl' + elif xsl == 'mathml/data/mmltex.xsl': + xsl = 'mmltex.xsl' + + if xsl in ['simple_mmlctop.xsl', 'mmlctop.xsl', 'mmltex.xsl']: + xslbytes = _read_binary('sympy.utilities.mathml.data', xsl) + else: + xslbytes = Path(xsl).read_bytes() + + return xslbytes + + +@doctest_depends_on(modules=('lxml',)) +def apply_xsl(mml, xsl): + """Apply a xsl to a MathML string. 
+
+    Parameters
+    ==========
+
+    mml
+        A string with MathML code.
+    xsl
+        A string giving the name of an xsl (xml stylesheet) file which can be
+        found in sympy/utilities/mathml/data. The following files are supplied
+        with SymPy:
+
+        - mmlctop.xsl
+        - mmltex.xsl
+        - simple_mmlctop.xsl
+
+        Alternatively, a full path to an xsl file can be given.
+
+    Examples
+    ========
+
+    >>> from sympy.utilities.mathml import apply_xsl
+    >>> xsl = 'simple_mmlctop.xsl'
+    >>> mml = '<apply> <plus/> <ci>a</ci> <ci>b</ci> </apply>'
+    >>> res = apply_xsl(mml,xsl)
+    >>> print(res)
+    <?xml version="1.0"?>
+    <mrow xmlns="http://www.w3.org/1998/Math/MathML">
+      <mi>a</mi>
+      <mo> + </mo>
+      <mi>b</mi>
+    </mrow>
+
+    """
+    from lxml import etree
+
+    parser = etree.XMLParser(resolve_entities=False)
+    ac = etree.XSLTAccessControl.DENY_ALL
+
+    s = etree.XML(_read_xsl(xsl), parser=parser)
+    transform = etree.XSLT(s, access_control=ac)
+    doc = etree.XML(mml, parser=parser)
+    result = transform(doc)
+    s = str(result)
+    return s
+
+
+@doctest_depends_on(modules=('lxml',))
+def c2p(mml, simple=False):
+    """Transforms a document in content MathML (like the one that SymPy produces)
+    into a document in presentation MathML, which is more suitable for printing
+    and more widely accepted.
+
+    Examples
+    ========
+
+    >>> from sympy.utilities.mathml import c2p
+    >>> mml = '<apply> <exp/> <cn>2</cn> </apply>'
+    >>> c2p(mml,simple=True) != c2p(mml,simple=False)
+    True
+
+    """
+
+    if not mml.startswith('<math'):
+        mml = add_mathml_headers(mml)
+
+    if simple:
+        return apply_xsl(mml, 'simple_mmlctop.xsl')
+
+    return apply_xsl(mml, 'mmlctop.xsl')
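For context, a small editor-added sketch of how these helpers are typically driven from SymPy's own MathML printer; it assumes the optional lxml dependency is installed and is not part of the vendored file.

    # Added illustration: feed SymPy's content-MathML printer output to c2p
    # to obtain presentation MathML (requires lxml).
    from sympy import symbols, mathml
    from sympy.utilities.mathml import c2p

    x = symbols('x')
    content = mathml(x**2 + x)                # content MathML, no <math> wrapper
    presentation = c2p(content, simple=True)  # wrapped and converted
    print(presentation)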
[Garbled XSLT stylesheet content removed; only the surviving diff headers for the added data files are kept below.]

diff --git a/phivenv/Lib/site-packages/sympy/utilities/mathml/data/mmltex.xsl b/phivenv/Lib/site-packages/sympy/utilities/mathml/data/mmltex.xsl
new file mode 100644
index 0000000000000000000000000000000000000000..5e6b85e02efd5196fe76b6ce4e10def27b9a8497
--- /dev/null
+++ b/phivenv/Lib/site-packages/sympy/utilities/mathml/data/mmltex.xsl
@@ -0,0 +1,2360 @@
[XSLT stylesheet body not recoverable]

diff --git a/phivenv/Lib/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl b/phivenv/Lib/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl
new file mode 100644
index 0000000000000000000000000000000000000000..0cd73bccc24c963ed9c6c4121cc410997b94261c
--- /dev/null
+++ b/phivenv/Lib/site-packages/sympy/utilities/mathml/data/simple_mmlctop.xsl
@@ -0,0 +1,3166 @@
[XSLT stylesheet body not recoverable]
+ + + + + + + + + + - + + + + + + + + &#x2062; + &#x2148; + + + + + + + + + + + + + - + + + + + + + + &#x2062; + &#x2148; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Polar + &#x2062; + + + + + + + + + + + + + + + Polar + &#x2062; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + [ + + + ] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + -1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x03BB; + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x2218; + + + + + + + + + + + + + + + + + + &#x2218; + + + + + + + + + + + + + + + + id + + + id + + + + + + + + + + + + + + domain + + + codomain + + + image + + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + + + + + { + + + + + + + + if + + + + + + + + + + + otherwise + + + + + + + + + + + + + + + + + + + + + + + &#x230A; + + + + + + + + + + + + + + + + + + + &#x230B; + + + + + + + + + + + + e + + + + + + + + + + + + + + + + + ! + + + + + + + + + + + + + + + + max + + + min + + + + + + + max + + + min + + + + + + + + + + + + + | + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + + + + + + + + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x2062; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + gcd + + + lcm + + + + + + gcd + + + lcm + + + + + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x2227; + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x2228; + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x22BB; + &#x2061; + + + + + + + + + + + + + + + + + + + + + + + &#x00AC; + &#x2061; + + + + + + + + + + + + + + + &#x00AC; + &#x2061; + + + + + + + + + + + + + + + + + + + &#x2200; + + + + + + + + + + + + : + + + + , + + + + + + + + + + + + + + &#x2203; + + + + + + + + + + + + : + + + + , + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x00AF; + + + + + + + + + + + + + + + + + &#x211C; + + + &#x2111; + + + &#x2061; + + + + + + + + + + + + + + + + + &#x230A; + + + &#x2308; + + + + + + &#x230B; + + + &#x2309; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + &#x2260; + + + &#x2248; + + + &#x2223; + + + + + &#x2198; + + + &#x2197; + + + &#x2192; + + + + + &#x21D2; + + + 
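For context (an illustrative sketch, not part of the vendored files): the data/*.xsl stylesheets in this diff are the files that ``sympy.utilities.mathml`` loads when converting SymPy's content MathML. A minimal usage sketch, assuming the lxml package that ``apply_xsl`` imports is installed:

    # Illustrative sketch only; not part of the vendored diff. Requires lxml.
    from sympy import Symbol, mathml, sin
    from sympy.utilities.mathml import c2p

    x = Symbol('x')
    content = mathml(sin(x) + x)      # content MathML markup for sin(x) + x
    # c2p applies mmlctop.xsl (or simple_mmlctop.xsl when simple=True)
    # to rewrite content MathML into presentation MathML.
    print(c2p(content, simple=True))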
diff --git a/phivenv/Lib/site-packages/sympy/utilities/memoization.py b/phivenv/Lib/site-packages/sympy/utilities/memoization.py new file mode 100644 index 0000000000000000000000000000000000000000..b638dfe244628096108ea689b664782f6538b7b8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/memoization.py @@ -0,0 +1,76 @@ +from functools import wraps + + +def recurrence_memo(initial): + """ + Memo decorator for sequences defined by recurrence + + Examples + ======== + + >>> from sympy.utilities.memoization import recurrence_memo + >>> @recurrence_memo([1]) # 0! = 1 + ...
def factorial(n, prev): + ... return n * prev[-1] + >>> factorial(4) + 24 + >>> factorial(3) # use cache values + 6 + >>> factorial.cache_length() # cache length can be obtained + 5 + >>> factorial.fetch_item(slice(2, 4)) + [2, 6] + + """ + cache = initial + + def decorator(f): + @wraps(f) + def g(n): + L = len(cache) + if n < L: + return cache[n] + for i in range(L, n + 1): + cache.append(f(i, cache)) + return cache[-1] + g.cache_length = lambda: len(cache) + g.fetch_item = lambda x: cache[x] + return g + return decorator + + +def assoc_recurrence_memo(base_seq): + """ + Memo decorator for associated sequences defined by recurrence starting from base + + base_seq(n) -- callable to get base sequence elements + + XXX works only for Pn0 = base_seq(0) cases + XXX works only for m <= n cases + """ + + cache = [] + + def decorator(f): + @wraps(f) + def g(n, m): + L = len(cache) + if n < L: + return cache[n][m] + + for i in range(L, n + 1): + # get base sequence + F_i0 = base_seq(i) + F_i_cache = [F_i0] + cache.append(F_i_cache) + + # XXX only works for m <= n cases + # generate assoc sequence + for j in range(1, i + 1): + F_ij = f(i, j, cache) + F_i_cache.append(F_ij) + + return cache[n][m] + + return g + return decorator diff --git a/phivenv/Lib/site-packages/sympy/utilities/misc.py b/phivenv/Lib/site-packages/sympy/utilities/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..741b983f03272890b22f2545f67b141767507634 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/misc.py @@ -0,0 +1,564 @@ +"""Miscellaneous stuff that does not really fit anywhere else.""" + +from __future__ import annotations + +import operator +import sys +import os +import re as _re +import struct +from textwrap import fill, dedent + + +class Undecidable(ValueError): + # an error to be raised when a decision cannot be made definitively + # where a definitive answer is needed + pass + + +def filldedent(s, w=70, **kwargs): + """ + Strips leading and trailing empty lines from a copy of ``s``, then dedents, + fills and returns it. + + Empty line stripping serves to deal with docstrings like this one that + start with a newline after the initial triple quote, inserting an empty + line at the beginning of the string. + + Additional keyword arguments will be passed to ``textwrap.fill()``. + + See Also + ======== + strlines, rawlines + + """ + return '\n' + fill(dedent(str(s)).strip('\n'), width=w, **kwargs) + + +def strlines(s, c=64, short=False): + """Return a cut-and-pastable string that, when printed, is + equivalent to the input. The lines will be surrounded by + parentheses and no line will be longer than c (default 64) + characters. If the line contains newlines characters, the + `rawlines` result will be returned. If ``short`` is True + (default is False) then if there is one line it will be + returned without bounding parentheses. + + Examples + ======== + + >>> from sympy.utilities.misc import strlines + >>> q = 'this is a long string that should be broken into shorter lines' + >>> print(strlines(q, 40)) + ( + 'this is a long string that should be b' + 'roken into shorter lines' + ) + >>> q == ( + ... 'this is a long string that should be b' + ... 'roken into shorter lines' + ... 
) + True + + See Also + ======== + filldedent, rawlines + """ + if not isinstance(s, str): + raise ValueError('expecting string input') + if '\n' in s: + return rawlines(s) + q = '"' if repr(s).startswith('"') else "'" + q = (q,)*2 + if '\\' in s: # use r-string + m = '(\nr%s%%s%s\n)' % q + j = '%s\nr%s' % q + c -= 3 + else: + m = '(\n%s%%s%s\n)' % q + j = '%s\n%s' % q + c -= 2 + out = [] + while s: + out.append(s[:c]) + s=s[c:] + if short and len(out) == 1: + return (m % out[0]).splitlines()[1] # strip bounding (\n...\n) + return m % j.join(out) + + +def rawlines(s): + """Return a cut-and-pastable string that, when printed, is equivalent + to the input. Use this when there is more than one line in the + string. The string returned is formatted so it can be indented + nicely within tests; in some cases it is wrapped in the dedent + function which has to be imported from textwrap. + + Examples + ======== + + Note: because there are characters in the examples below that need + to be escaped because they are themselves within a triple quoted + docstring, expressions below look more complicated than they would + be if they were printed in an interpreter window. + + >>> from sympy.utilities.misc import rawlines + >>> from sympy import TableForm + >>> s = str(TableForm([[1, 10]], headings=(None, ['a', 'bee']))) + >>> print(rawlines(s)) + ( + 'a bee\\n' + '-----\\n' + '1 10 ' + ) + >>> print(rawlines('''this + ... that''')) + dedent('''\\ + this + that''') + + >>> print(rawlines('''this + ... that + ... ''')) + dedent('''\\ + this + that + ''') + + >>> s = \"\"\"this + ... is a triple ''' + ... \"\"\" + >>> print(rawlines(s)) + dedent(\"\"\"\\ + this + is a triple ''' + \"\"\") + + >>> print(rawlines('''this + ... that + ... ''')) + ( + 'this\\n' + 'that\\n' + ' ' + ) + + See Also + ======== + filldedent, strlines + """ + lines = s.split('\n') + if len(lines) == 1: + return repr(lines[0]) + triple = ["'''" in s, '"""' in s] + if any(li.endswith(' ') for li in lines) or '\\' in s or all(triple): + rv = [] + # add on the newlines + trailing = s.endswith('\n') + last = len(lines) - 1 + for i, li in enumerate(lines): + if i != last or trailing: + rv.append(repr(li + '\n')) + else: + rv.append(repr(li)) + return '(\n %s\n)' % '\n '.join(rv) + else: + rv = '\n '.join(lines) + if triple[0]: + return 'dedent("""\\\n %s""")' % rv + else: + return "dedent('''\\\n %s''')" % rv + +ARCH = str(struct.calcsize('P') * 8) + "-bit" + + +# XXX: PyPy does not support hash randomization +HASH_RANDOMIZATION = getattr(sys.flags, 'hash_randomization', False) + +_debug_tmp: list[str] = [] +_debug_iter = 0 + +def debug_decorator(func): + """If SYMPY_DEBUG is True, it will print a nice execution tree with + arguments and results of all decorated functions, else do nothing. + """ + from sympy import SYMPY_DEBUG + + if not SYMPY_DEBUG: + return func + + def maketree(f, *args, **kw): + global _debug_tmp, _debug_iter + oldtmp = _debug_tmp + _debug_tmp = [] + _debug_iter += 1 + + def tree(subtrees): + def indent(s, variant=1): + x = s.split("\n") + r = "+-%s\n" % x[0] + for a in x[1:]: + if a == "": + continue + if variant == 1: + r += "| %s\n" % a + else: + r += " %s\n" % a + return r + if len(subtrees) == 0: + return "" + f = [] + for a in subtrees[:-1]: + f.append(indent(a)) + f.append(indent(subtrees[-1], 2)) + return ''.join(f) + + # If there is a bug and the algorithm enters an infinite loop, enable the + # following lines. 
It will print the names and parameters of all major functions + # that are called, *before* they are called + #from functools import reduce + #print("%s%s %s%s" % (_debug_iter, reduce(lambda x, y: x + y, \ + # map(lambda x: '-', range(1, 2 + _debug_iter))), f.__name__, args)) + + r = f(*args, **kw) + + _debug_iter -= 1 + s = "%s%s = %s\n" % (f.__name__, args, r) + if _debug_tmp != []: + s += tree(_debug_tmp) + _debug_tmp = oldtmp + _debug_tmp.append(s) + if _debug_iter == 0: + print(_debug_tmp[0]) + _debug_tmp = [] + return r + + def decorated(*args, **kwargs): + return maketree(func, *args, **kwargs) + + return decorated + + +def debug(*args): + """ + Print ``*args`` if SYMPY_DEBUG is True, else do nothing. + """ + from sympy import SYMPY_DEBUG + if SYMPY_DEBUG: + print(*args, file=sys.stderr) + + +def debugf(string, args): + """ + Print ``string%args`` if SYMPY_DEBUG is True, else do nothing. This is + intended for debug messages using formatted strings. + """ + from sympy import SYMPY_DEBUG + if SYMPY_DEBUG: + print(string%args, file=sys.stderr) + + +def find_executable(executable, path=None): + """Try to find 'executable' in the directories listed in 'path' (a + string listing directories separated by 'os.pathsep'; defaults to + os.environ['PATH']). Returns the complete filename or None if not + found + """ + from .exceptions import sympy_deprecation_warning + sympy_deprecation_warning( + """ + sympy.utilities.misc.find_executable() is deprecated. Use the standard + library shutil.which() function instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-find-executable", + ) + if path is None: + path = os.environ['PATH'] + paths = path.split(os.pathsep) + extlist = [''] + if os.name == 'os2': + (base, ext) = os.path.splitext(executable) + # executable files on OS/2 can have an arbitrary extension, but + # .exe is automatically appended if no dot is present in the name + if not ext: + executable = executable + ".exe" + elif sys.platform == 'win32': + pathext = os.environ['PATHEXT'].lower().split(os.pathsep) + (base, ext) = os.path.splitext(executable) + if ext.lower() not in pathext: + extlist = pathext + for ext in extlist: + execname = executable + ext + if os.path.isfile(execname): + return execname + else: + for p in paths: + f = os.path.join(p, execname) + if os.path.isfile(f): + return f + + return None + + +def func_name(x, short=False): + """Return function name of `x` (if defined) else the `type(x)`. + If short is True and there is a shorter alias for the result, + return the alias. 
+ + Examples + ======== + + >>> from sympy.utilities.misc import func_name + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> func_name(Matrix.eye(3)) + 'MutableDenseMatrix' + >>> func_name(x < 1) + 'StrictLessThan' + >>> func_name(x < 1, short=True) + 'Lt' + """ + alias = { + 'GreaterThan': 'Ge', + 'StrictGreaterThan': 'Gt', + 'LessThan': 'Le', + 'StrictLessThan': 'Lt', + 'Equality': 'Eq', + 'Unequality': 'Ne', + } + typ = type(x) + if str(typ).startswith(">> from sympy.utilities.misc import _replace + >>> f = _replace(dict(foo='bar', d='t')) + >>> f('food') + 'bart' + >>> f = _replace({}) + >>> f('food') + 'food' + """ + if not reps: + return lambda x: x + D = lambda match: reps[match.group(0)] + pattern = _re.compile("|".join( + [_re.escape(k) for k, v in reps.items()]), _re.MULTILINE) + return lambda string: pattern.sub(D, string) + + +def replace(string, *reps): + """Return ``string`` with all keys in ``reps`` replaced with + their corresponding values, longer strings first, irrespective + of the order they are given. ``reps`` may be passed as tuples + or a single mapping. + + Examples + ======== + + >>> from sympy.utilities.misc import replace + >>> replace('foo', {'oo': 'ar', 'f': 'b'}) + 'bar' + >>> replace("spamham sha", ("spam", "eggs"), ("sha","md5")) + 'eggsham md5' + + There is no guarantee that a unique answer will be + obtained if keys in a mapping overlap (i.e. are the same + length and have some identical sequence at the + beginning/end): + + >>> reps = [ + ... ('ab', 'x'), + ... ('bc', 'y')] + >>> replace('abc', *reps) in ('xc', 'ay') + True + + References + ========== + + .. [1] https://stackoverflow.com/questions/6116978/how-to-replace-multiple-substrings-of-a-string + """ + if len(reps) == 1: + kv = reps[0] + if isinstance(kv, dict): + reps = kv + else: + return string.replace(*kv) + else: + reps = dict(reps) + return _replace(reps)(string) + + +def translate(s, a, b=None, c=None): + """Return ``s`` where characters have been replaced or deleted. + + SYNTAX + ====== + + translate(s, None, deletechars): + all characters in ``deletechars`` are deleted + translate(s, map [,deletechars]): + all characters in ``deletechars`` (if provided) are deleted + then the replacements defined by map are made; if the keys + of map are strings then the longer ones are handled first. + Multicharacter deletions should have a value of ''. 
+ translate(s, oldchars, newchars, deletechars) + all characters in ``deletechars`` are deleted + then each character in ``oldchars`` is replaced with the + corresponding character in ``newchars`` + + Examples + ======== + + >>> from sympy.utilities.misc import translate + >>> abc = 'abc' + >>> translate(abc, None, 'a') + 'bc' + >>> translate(abc, {'a': 'x'}, 'c') + 'xb' + >>> translate(abc, {'abc': 'x', 'a': 'y'}) + 'x' + + >>> translate('abcd', 'ac', 'AC', 'd') + 'AbC' + + There is no guarantee that a unique answer will be + obtained if keys in a mapping overlap are the same + length and have some identical sequences at the + beginning/end: + + >>> translate(abc, {'ab': 'x', 'bc': 'y'}) in ('xc', 'ay') + True + """ + + mr = {} + if a is None: + if c is not None: + raise ValueError('c should be None when a=None is passed, instead got %s' % c) + if b is None: + return s + c = b + a = b = '' + else: + if isinstance(a, dict): + short = {} + for k in list(a.keys()): + if len(k) == 1 and len(a[k]) == 1: + short[k] = a.pop(k) + mr = a + c = b + if short: + a, b = [''.join(i) for i in list(zip(*short.items()))] + else: + a = b = '' + elif len(a) != len(b): + raise ValueError('oldchars and newchars have different lengths') + + if c: + val = str.maketrans('', '', c) + s = s.translate(val) + s = replace(s, mr) + n = str.maketrans(a, b) + return s.translate(n) + + +def ordinal(num): + """Return ordinal number string of num, e.g. 1 becomes 1st. + """ + # modified from https://codereview.stackexchange.com/questions/41298/producing-ordinal-numbers + n = as_int(num) + k = abs(n) % 100 + if 11 <= k <= 13: + suffix = 'th' + elif k % 10 == 1: + suffix = 'st' + elif k % 10 == 2: + suffix = 'nd' + elif k % 10 == 3: + suffix = 'rd' + else: + suffix = 'th' + return str(n) + suffix + + +def as_int(n, strict=True): + """ + Convert the argument to a builtin integer. + + The return value is guaranteed to be equal to the input. ValueError is + raised if the input has a non-integral value. When ``strict`` is True, this + uses `__index__ `_ + and when it is False it uses ``int``. + + + Examples + ======== + + >>> from sympy.utilities.misc import as_int + >>> from sympy import sqrt, S + + The function is primarily concerned with sanitizing input for + functions that need to work with builtin integers, so anything that + is unambiguously an integer should be returned as an int: + + >>> as_int(S(3)) + 3 + + Floats, being of limited precision, are not assumed to be exact and + will raise an error unless the ``strict`` flag is False. This + precision issue becomes apparent for large floating point numbers: + + >>> big = 1e23 + >>> type(big) is float + True + >>> big == int(big) + True + >>> as_int(big) + Traceback (most recent call last): + ... + ValueError: ... is not an integer + >>> as_int(big, strict=False) + 99999999999999991611392 + + Input that might be a complex representation of an integer value is + also rejected by default: + + >>> one = sqrt(3 + 2*sqrt(2)) - sqrt(2) + >>> int(one) == 1 + True + >>> as_int(one) + Traceback (most recent call last): + ... + ValueError: ... 
is not an integer + """ + if strict: + try: + if isinstance(n, bool): + raise TypeError + return operator.index(n) + except TypeError: + raise ValueError('%s is not an integer' % (n,)) + else: + try: + result = int(n) + except TypeError: + raise ValueError('%s is not an integer' % (n,)) + if n - result: + raise ValueError('%s is not an integer' % (n,)) + return result diff --git a/phivenv/Lib/site-packages/sympy/utilities/pkgdata.py b/phivenv/Lib/site-packages/sympy/utilities/pkgdata.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf2065759362ee09a252d4736bd612b8d271e72 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/pkgdata.py @@ -0,0 +1,33 @@ +# This module is deprecated and will be removed. + +import sys +import os +from io import StringIO + +from sympy.utilities.decorator import deprecated + + +@deprecated( + """ + The sympy.utilities.pkgdata module and its get_resource function are + deprecated. Use the stdlib importlib.resources module instead. + """, + deprecated_since_version="1.12", + active_deprecations_target="pkgdata", +) +def get_resource(identifier, pkgname=__name__): + + mod = sys.modules[pkgname] + fn = getattr(mod, '__file__', None) + if fn is None: + raise OSError("%r has no __file__!") + path = os.path.join(os.path.dirname(fn), identifier) + loader = getattr(mod, '__loader__', None) + if loader is not None: + try: + data = loader.get_data(path) + except (OSError, AttributeError): + pass + else: + return StringIO(data.decode('utf-8')) + return open(os.path.normpath(path), 'rb') diff --git a/phivenv/Lib/site-packages/sympy/utilities/pytest.py b/phivenv/Lib/site-packages/sympy/utilities/pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..75494c8e987c7b89bddf18febe09fdc3df2b194e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/pytest.py @@ -0,0 +1,12 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.pytest has been renamed to sympy.testing.pytest. +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.pytest submodule is deprecated. Use sympy.testing.pytest instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.testing.pytest import * # noqa:F401,F403 diff --git a/phivenv/Lib/site-packages/sympy/utilities/randtest.py b/phivenv/Lib/site-packages/sympy/utilities/randtest.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd3472ed8a89198d9e78e125f16dae1a56f6bad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/randtest.py @@ -0,0 +1,12 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.randtest has been renamed to sympy.core.random. +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.randtest submodule is deprecated. Use sympy.core.random instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.core.random import * # noqa:F401,F403 diff --git a/phivenv/Lib/site-packages/sympy/utilities/runtests.py b/phivenv/Lib/site-packages/sympy/utilities/runtests.py new file mode 100644 index 0000000000000000000000000000000000000000..bb9f760f070be528e8cd51ab38e00bc944dda9ac --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/runtests.py @@ -0,0 +1,13 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.runtests has been renamed to sympy.testing.runtests. 
+""" + +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.runtests submodule is deprecated. Use sympy.testing.runtests instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.testing.runtests import * # noqa: F401,F403 diff --git a/phivenv/Lib/site-packages/sympy/utilities/source.py b/phivenv/Lib/site-packages/sympy/utilities/source.py new file mode 100644 index 0000000000000000000000000000000000000000..71692b4aaad07df70b63d7e1eaa86c402ad03c4f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/source.py @@ -0,0 +1,40 @@ +""" +This module adds several functions for interactive source code inspection. +""" + + +def get_class(lookup_view): + """ + Convert a string version of a class name to the object. + + For example, get_class('sympy.core.Basic') will return + class Basic located in module sympy.core + """ + if isinstance(lookup_view, str): + mod_name, func_name = get_mod_func(lookup_view) + if func_name != '': + lookup_view = getattr( + __import__(mod_name, {}, {}, ['*']), func_name) + if not callable(lookup_view): + raise AttributeError( + "'%s.%s' is not a callable." % (mod_name, func_name)) + return lookup_view + + +def get_mod_func(callback): + """ + splits the string path to a class into a string path to the module + and the name of the class. + + Examples + ======== + + >>> from sympy.utilities.source import get_mod_func + >>> get_mod_func('sympy.core.basic.Basic') + ('sympy.core.basic', 'Basic') + + """ + dot = callback.rfind('.') + if dot == -1: + return callback, '' + return callback[:dot], callback[dot + 1:] diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__init__.py b/phivenv/Lib/site-packages/sympy/utilities/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a08322dd3e8d46c11349ada3aae46567685d53a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_autowrap.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_autowrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7916695c58c5515fdb34a505f6b1ffd37b10dd47 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_autowrap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7670f0fbff29b030397d867e9952ce1a3be6773d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_julia.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_julia.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed823a64d71b4d2ae1d2bd0ed3f3c345b3748a40 Binary files /dev/null and 
b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_julia.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_octave.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_octave.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce0e8d12aec24baaed1bd11f793d7f3cb4babce4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_octave.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_rust.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_rust.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebdcae89c87bae502c29670e458455be1871a8cc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_codegen_rust.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_decorator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_decorator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89b8429c709f9a33e1631602bd06dd4db7916576 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_decorator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_deprecated.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_deprecated.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36c583c90073813a71ee868d07065dcfaa3e4772 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_deprecated.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_enumerative.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_enumerative.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb3f66cb18a9d850287a71ad12a143a897d58b72 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_enumerative.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_exceptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_exceptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25e24c3b4f573bc8dd2bb4ebdab45720ae4516cf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_exceptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_iterables.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_iterables.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b609b74b6c6cf6a95e5e363cf0fea7e3ce95c27e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_iterables.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_lambdify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_lambdify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7c96b718cc40ea3dc890e66d257fb768db636aa Binary files /dev/null and 
b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_lambdify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_matchpy_connector.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_matchpy_connector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80bcb6dbadacce92ada13058d6b83f41a32b4142 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_matchpy_connector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_mathml.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_mathml.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb2b795cb847866f6ba4ade35bfd7f321f9ff434 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_mathml.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_misc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3f093134e7d057cea4237560aebc28bc72fd37f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_misc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_pickling.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_pickling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2ce5206db29886f5c7ae1919571874765c66015 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_pickling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_source.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_source.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab14ac5aa765410a9eb14d4a6c610b0803e6b9f9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_source.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_timeutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_timeutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc36443e134f61d78c958bed2b3f28d39935292 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_timeutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_xxe.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_xxe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cb3c2a139cd1d9d5670aeed3fb90c38fb53263b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/utilities/tests/__pycache__/test_xxe.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_autowrap.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_autowrap.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a33d77adb46710cfa3cfeb1ea39402e35c76cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_autowrap.py @@ -0,0 +1,467 @@ +# Tests that require installed backends go into +# sympy/test_external/test_autowrap + +import os +import 
tempfile +import shutil +from io import StringIO +from pathlib import Path + +from sympy.core import symbols, Eq +from sympy.utilities.autowrap import (autowrap, binary_function, + CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper) +from sympy.utilities.codegen import ( + CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine +) +from sympy.testing.pytest import raises +from sympy.testing.tmpfiles import TmpFileManager + + +def get_string(dump_fn, routines, prefix="file", **kwargs): + """Wrapper for dump_fn. dump_fn writes its results to a stream object and + this wrapper returns the contents of that stream as a string. This + auxiliary function is used by many tests below. + + The header and the empty lines are not generator to facilitate the + testing of the output. + """ + output = StringIO() + dump_fn(routines, output, prefix, **kwargs) + source = output.getvalue() + output.close() + return source + + +def test_cython_wrapper_scalar_function(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = CythonCodeWrapper(CCodeGen()) + source = get_string(code_gen.dump_pyx, [routine]) + + expected = ( + "cdef extern from 'file.h':\n" + " double test(double x, double y, double z)\n" + "\n" + "def test_c(double x, double y, double z):\n" + "\n" + " return test(x, y, z)") + assert source == expected + + +def test_cython_wrapper_outarg(): + from sympy.core.relational import Equality + x, y, z = symbols('x,y,z') + code_gen = CythonCodeWrapper(C99CodeGen()) + + routine = make_routine("test", Equality(z, x + y)) + source = get_string(code_gen.dump_pyx, [routine]) + expected = ( + "cdef extern from 'file.h':\n" + " void test(double x, double y, double *z)\n" + "\n" + "def test_c(double x, double y):\n" + "\n" + " cdef double z = 0\n" + " test(x, y, &z)\n" + " return z") + assert source == expected + + +def test_cython_wrapper_inoutarg(): + from sympy.core.relational import Equality + x, y, z = symbols('x,y,z') + code_gen = CythonCodeWrapper(C99CodeGen()) + routine = make_routine("test", Equality(z, x + y + z)) + source = get_string(code_gen.dump_pyx, [routine]) + expected = ( + "cdef extern from 'file.h':\n" + " void test(double x, double y, double *z)\n" + "\n" + "def test_c(double x, double y, double z):\n" + "\n" + " test(x, y, &z)\n" + " return z") + assert source == expected + + +def test_cython_wrapper_compile_flags(): + from sympy.core.relational import Equality + x, y, z = symbols('x,y,z') + routine = make_routine("test", Equality(z, x + y)) + + code_gen = CythonCodeWrapper(CCodeGen()) + + expected = """\ +from setuptools import setup +from setuptools import Extension +from Cython.Build import cythonize +cy_opts = {'compiler_directives': {'language_level': '3'}} + +ext_mods = [Extension( + 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'], + include_dirs=[], + library_dirs=[], + libraries=[], + extra_compile_args=['-std=c99'], + extra_link_args=[] +)] +setup(ext_modules=cythonize(ext_mods, **cy_opts)) +""" % {'num': CodeWrapper._module_counter} + + temp_dir = tempfile.mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + setup_file_path = os.path.join(temp_dir, 'setup.py') + + code_gen._prepare_files(routine, build_dir=temp_dir) + setup_text = Path(setup_file_path).read_text() + assert setup_text == expected + + code_gen = CythonCodeWrapper(CCodeGen(), + include_dirs=['/usr/local/include', '/opt/booger/include'], + library_dirs=['/user/local/lib'], + libraries=['thelib', 'nilib'], + 
extra_compile_args=['-slow-math'], + extra_link_args=['-lswamp', '-ltrident'], + cythonize_options={'compiler_directives': {'boundscheck': False}} + ) + expected = """\ +from setuptools import setup +from setuptools import Extension +from Cython.Build import cythonize +cy_opts = {'compiler_directives': {'boundscheck': False}} + +ext_mods = [Extension( + 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'], + include_dirs=['/usr/local/include', '/opt/booger/include'], + library_dirs=['/user/local/lib'], + libraries=['thelib', 'nilib'], + extra_compile_args=['-slow-math', '-std=c99'], + extra_link_args=['-lswamp', '-ltrident'] +)] +setup(ext_modules=cythonize(ext_mods, **cy_opts)) +""" % {'num': CodeWrapper._module_counter} + + code_gen._prepare_files(routine, build_dir=temp_dir) + setup_text = Path(setup_file_path).read_text() + assert setup_text == expected + + expected = """\ +from setuptools import setup +from setuptools import Extension +from Cython.Build import cythonize +cy_opts = {'compiler_directives': {'boundscheck': False}} +import numpy as np + +ext_mods = [Extension( + 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'], + include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()], + library_dirs=['/user/local/lib'], + libraries=['thelib', 'nilib'], + extra_compile_args=['-slow-math', '-std=c99'], + extra_link_args=['-lswamp', '-ltrident'] +)] +setup(ext_modules=cythonize(ext_mods, **cy_opts)) +""" % {'num': CodeWrapper._module_counter} + + code_gen._need_numpy = True + code_gen._prepare_files(routine, build_dir=temp_dir) + setup_text = Path(setup_file_path).read_text() + assert setup_text == expected + + TmpFileManager.cleanup() + +def test_cython_wrapper_unique_dummyvars(): + from sympy.core.relational import Equality + from sympy.core.symbol import Dummy + x, y, z = Dummy('x'), Dummy('y'), Dummy('z') + x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]] + expr = Equality(z, x + y) + routine = make_routine("test", expr) + code_gen = CythonCodeWrapper(CCodeGen()) + source = get_string(code_gen.dump_pyx, [routine]) + expected_template = ( + "cdef extern from 'file.h':\n" + " void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n" + "\n" + "def test_c(double x_{x_id}, double y_{y_id}):\n" + "\n" + " cdef double z_{z_id} = 0\n" + " test(x_{x_id}, y_{y_id}, &z_{z_id})\n" + " return z_{z_id}") + expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id) + assert source == expected + +def test_autowrap_dummy(): + x, y, z = symbols('x y z') + + # Uses DummyWrapper to test that codegen works as expected + + f = autowrap(x + y, backend='dummy') + assert f() == str(x + y) + assert f.args == "x, y" + assert f.returns == "nameless" + f = autowrap(Eq(z, x + y), backend='dummy') + assert f() == str(x + y) + assert f.args == "x, y" + assert f.returns == "z" + f = autowrap(Eq(z, x + y + z), backend='dummy') + assert f() == str(x + y + z) + assert f.args == "x, y, z" + assert f.returns == "z" + + +def test_autowrap_args(): + x, y, z = symbols('x y z') + + raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y), + backend='dummy', args=[x])) + f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x]) + assert f() == str(x + y) + assert f.args == "y, x" + assert f.returns == "z" + + raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z), + backend='dummy', args=[x, y])) + f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z]) + assert f() == str(x + y + z) + assert 
f.args == "y, x, z" + assert f.returns == "z" + + f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z)) + assert f() == str(x + y + z) + assert f.args == "y, x, z" + assert f.returns == "z" + +def test_autowrap_store_files(): + x, y = symbols('x y') + tmp = tempfile.mkdtemp() + TmpFileManager.tmp_folder(tmp) + + f = autowrap(x + y, backend='dummy', tempdir=tmp) + assert f() == str(x + y) + assert os.access(tmp, os.F_OK) + + TmpFileManager.cleanup() + +def test_autowrap_store_files_issue_gh12939(): + x, y = symbols('x y') + tmp = './tmp' + saved_cwd = os.getcwd() + temp_cwd = tempfile.mkdtemp() + try: + os.chdir(temp_cwd) + f = autowrap(x + y, backend='dummy', tempdir=tmp) + assert f() == str(x + y) + assert os.access(tmp, os.F_OK) + finally: + os.chdir(saved_cwd) + shutil.rmtree(temp_cwd) + + +def test_binary_function(): + x, y = symbols('x y') + f = binary_function('f', x + y, backend='dummy') + assert f._imp_() == str(x + y) + + +def test_ufuncify_source(): + x, y, z = symbols('x,y,z') + code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify")) + routine = make_routine("test", x + y + z) + source = get_string(code_wrapper.dump_c, [routine]) + expected = """\ +#include "Python.h" +#include "math.h" +#include "numpy/ndarraytypes.h" +#include "numpy/ufuncobject.h" +#include "numpy/halffloat.h" +#include "file.h" + +static PyMethodDef wrapper_module_%(num)sMethods[] = { + {NULL, NULL, 0, NULL} +}; + +#ifdef NPY_1_19_API_VERSION +static void test_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data) +#else +static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data) +#endif +{ + npy_intp i; + npy_intp n = dimensions[0]; + char *in0 = args[0]; + char *in1 = args[1]; + char *in2 = args[2]; + char *out0 = args[3]; + npy_intp in0_step = steps[0]; + npy_intp in1_step = steps[1]; + npy_intp in2_step = steps[2]; + npy_intp out0_step = steps[3]; + for (i = 0; i < n; i++) { + *((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2); + in0 += in0_step; + in1 += in1_step; + in2 += in2_step; + out0 += out0_step; + } +} +PyUFuncGenericFunction test_funcs[1] = {&test_ufunc}; +static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE}; +static void *test_data[1] = {NULL}; + +#if PY_VERSION_HEX >= 0x03000000 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "wrapper_module_%(num)s", + NULL, + -1, + wrapper_module_%(num)sMethods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = PyModule_Create(&moduledef); + if (!m) { + return NULL; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "test", ufunc0); + Py_DECREF(ufunc0); + return m; +} +#else +PyMODINIT_FUNC initwrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods); + if (m == NULL) { + return; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "test", ufunc0); + Py_DECREF(ufunc0); +} +#endif""" % {'num': CodeWrapper._module_counter} + assert source == expected + + +def 
test_ufuncify_source_multioutput(): + x, y, z = symbols('x,y,z') + var_symbols = (x, y, z) + expr = x + y**3 + 10*z**2 + code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify")) + routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))] + source = get_string(code_wrapper.dump_c, routines, funcname='multitest') + expected = """\ +#include "Python.h" +#include "math.h" +#include "numpy/ndarraytypes.h" +#include "numpy/ufuncobject.h" +#include "numpy/halffloat.h" +#include "file.h" + +static PyMethodDef wrapper_module_%(num)sMethods[] = { + {NULL, NULL, 0, NULL} +}; + +#ifdef NPY_1_19_API_VERSION +static void multitest_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data) +#else +static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data) +#endif +{ + npy_intp i; + npy_intp n = dimensions[0]; + char *in0 = args[0]; + char *in1 = args[1]; + char *in2 = args[2]; + char *out0 = args[3]; + char *out1 = args[4]; + char *out2 = args[5]; + npy_intp in0_step = steps[0]; + npy_intp in1_step = steps[1]; + npy_intp in2_step = steps[2]; + npy_intp out0_step = steps[3]; + npy_intp out1_step = steps[4]; + npy_intp out2_step = steps[5]; + for (i = 0; i < n; i++) { + *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2); + *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2); + *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2); + in0 += in0_step; + in1 += in1_step; + in2 += in2_step; + out0 += out0_step; + out1 += out1_step; + out2 += out2_step; + } +} +PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc}; +static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE}; +static void *multitest_data[1] = {NULL}; + +#if PY_VERSION_HEX >= 0x03000000 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "wrapper_module_%(num)s", + NULL, + -1, + wrapper_module_%(num)sMethods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = PyModule_Create(&moduledef); + if (!m) { + return NULL; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "multitest", ufunc0); + Py_DECREF(ufunc0); + return m; +} +#else +PyMODINIT_FUNC initwrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods); + if (m == NULL) { + return; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "multitest", ufunc0); + Py_DECREF(ufunc0); +} +#endif""" % {'num': CodeWrapper._module_counter} + assert source == expected diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccc6f9a90fb0a0bec39cea22420da8091ede740 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen.py @@ -0,0 +1,1632 @@ +from io import StringIO + +from sympy.core 
import symbols, Eq, pi, Catalan, Lambda, Dummy +from sympy.core.relational import Equality +from sympy.core.symbol import Symbol +from sympy.functions.special.error_functions import erf +from sympy.integrals.integrals import Integral +from sympy.matrices import Matrix, MatrixSymbol +from sympy.utilities.codegen import ( + codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument, + CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument, + InOutArgument) +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function + +#FIXME: Fails due to circular import in with core +# from sympy import codegen + + +def get_string(dump_fn, routines, prefix="file", header=False, empty=False): + """Wrapper for dump_fn. dump_fn writes its results to a stream object and + this wrapper returns the contents of that stream as a string. This + auxiliary function is used by many tests below. + + The header and the empty lines are not generated to facilitate the + testing of the output. + """ + output = StringIO() + dump_fn(routines, output, prefix, header, empty) + source = output.getvalue() + output.close() + return source + + +def test_Routine_argument_order(): + a, x, y, z = symbols('a x y z') + expr = (x + y)*z + raises(CodeGenArgumentListError, lambda: make_routine("test", expr, + argument_sequence=[z, x])) + raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a, + expr), argument_sequence=[z, x, y])) + r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y]) + assert [ arg.name for arg in r.arguments ] == [z, x, a, y] + assert [ type(arg) for arg in r.arguments ] == [ + InputArgument, InputArgument, OutputArgument, InputArgument ] + r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y]) + assert [ type(arg) for arg in r.arguments ] == [ + InOutArgument, InputArgument, InputArgument ] + + from sympy.tensor import IndexedBase, Idx + A, B = map(IndexedBase, ['A', 'B']) + m = symbols('m', integer=True) + i = Idx('i', m) + r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m]) + assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m] + + expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3)) + r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y]) + assert [ arg.name for arg in r.arguments ] == [z, x, a, y] + + +def test_empty_c_code(): + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, []) + assert source == "#include \"file.h\"\n#include \n" + + +def test_empty_c_code_with_comment(): + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [], header=True) + assert source[:82] == ( + "/******************************************************************************\n *" + ) + # " Code generated with SymPy 0.7.2-git " + assert source[158:] == ( "*\n" + " * *\n" + " * See http://www.sympy.org/ for more information. 
*\n" + " * *\n" + " * This file is part of 'project' *\n" + " ******************************************************************************/\n" + "#include \"file.h\"\n" + "#include \n" + ) + + +def test_empty_c_header(): + code_gen = C99CodeGen() + source = get_string(code_gen.dump_h, []) + assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n" + + +def test_simple_c_code(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test(double x, double y, double z) {\n" + " double test_result;\n" + " test_result = z*(x + y);\n" + " return test_result;\n" + "}\n" + ) + assert source == expected + + +def test_c_code_reserved_words(): + x, y, z = symbols('if, typedef, while') + expr = (x + y) * z + routine = make_routine("test", expr) + code_gen = C99CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test(double if_, double typedef_, double while_) {\n" + " double test_result;\n" + " test_result = while_*(if_ + typedef_);\n" + " return test_result;\n" + "}\n" + ) + assert source == expected + + +def test_numbersymbol_c_code(): + routine = make_routine("test", pi**Catalan) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test() {\n" + " double test_result;\n" + " double const Catalan = %s;\n" + " test_result = pow(M_PI, Catalan);\n" + " return test_result;\n" + "}\n" + ) % Catalan.evalf(17) + assert source == expected + + +def test_c_code_argument_order(): + x, y, z = symbols('x,y,z') + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y]) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test(double z, double x, double y) {\n" + " double test_result;\n" + " test_result = x + y;\n" + " return test_result;\n" + "}\n" + ) + assert source == expected + + +def test_simple_c_header(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_h, [routine]) + expected = ( + "#ifndef PROJECT__FILE__H\n" + "#define PROJECT__FILE__H\n" + "double test(double x, double y, double z);\n" + "#endif\n" + ) + assert source == expected + + +def test_simple_c_codegen(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + expected = [ + ("file.c", + "#include \"file.h\"\n" + "#include \n" + "double test(double x, double y, double z) {\n" + " double test_result;\n" + " test_result = z*(x + y);\n" + " return test_result;\n" + "}\n"), + ("file.h", + "#ifndef PROJECT__FILE__H\n" + "#define PROJECT__FILE__H\n" + "double test(double x, double y, double z);\n" + "#endif\n") + ] + result = codegen(("test", expr), "C", "file", header=False, empty=False) + assert result == expected + + +def test_multiple_results_c(): + x, y, z = symbols('x,y,z') + expr1 = (x + y)*z + expr2 = (x - y)*z + routine = make_routine( + "test", + [expr1, expr2] + ) + code_gen = C99CodeGen() + raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine])) + + +def test_no_results_c(): + raises(ValueError, lambda: make_routine("test", [])) + + +def test_ansi_math1_codegen(): + # not included: log10 + from sympy.functions.elementary.complexes import Abs + from 
sympy.functions.elementary.exponential import log + from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh) + from sympy.functions.elementary.integers import (ceiling, floor) + from sympy.functions.elementary.miscellaneous import sqrt + from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan) + x = symbols('x') + name_expr = [ + ("test_fabs", Abs(x)), + ("test_acos", acos(x)), + ("test_asin", asin(x)), + ("test_atan", atan(x)), + ("test_ceil", ceiling(x)), + ("test_cos", cos(x)), + ("test_cosh", cosh(x)), + ("test_floor", floor(x)), + ("test_log", log(x)), + ("test_ln", log(x)), + ("test_sin", sin(x)), + ("test_sinh", sinh(x)), + ("test_sqrt", sqrt(x)), + ("test_tan", tan(x)), + ("test_tanh", tanh(x)), + ] + result = codegen(name_expr, "C89", "file", header=False, empty=False) + assert result[0][0] == "file.c" + assert result[0][1] == ( + '#include "file.h"\n#include \n' + 'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n' + 'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n' + 'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n' + 'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n' + 'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n' + 'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n' + 'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n' + 'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n' + 'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n' + 'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n' + 'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n' + 'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n' + 'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n' + 'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n' + 'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n' + ) + assert result[1][0] == "file.h" + assert result[1][1] == ( + '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n' + 'double test_fabs(double x);\ndouble test_acos(double x);\n' + 'double test_asin(double x);\ndouble test_atan(double x);\n' + 'double test_ceil(double x);\ndouble test_cos(double x);\n' + 'double test_cosh(double x);\ndouble test_floor(double x);\n' + 'double test_log(double x);\ndouble test_ln(double x);\n' + 'double test_sin(double x);\ndouble test_sinh(double x);\n' + 'double test_sqrt(double x);\ndouble test_tan(double x);\n' + 'double test_tanh(double x);\n#endif\n' + ) + + +def test_ansi_math2_codegen(): + # not included: frexp, ldexp, modf, fmod + from sympy.functions.elementary.trigonometric import atan2 + x, y = symbols('x,y') + name_expr = [ + ("test_atan2", atan2(x, y)), + 
("test_pow", x**y), + ] + result = codegen(name_expr, "C89", "file", header=False, empty=False) + assert result[0][0] == "file.c" + assert result[0][1] == ( + '#include "file.h"\n#include \n' + 'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n' + 'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n' + ) + assert result[1][0] == "file.h" + assert result[1][1] == ( + '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n' + 'double test_atan2(double x, double y);\n' + 'double test_pow(double x, double y);\n' + '#endif\n' + ) + + +def test_complicated_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + x, y, z = symbols('x,y,z') + name_expr = [ + ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()), + ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))), + ] + result = codegen(name_expr, "C89", "file", header=False, empty=False) + assert result[0][0] == "file.c" + assert result[0][1] == ( + '#include "file.h"\n#include \n' + 'double test1(double x, double y, double z) {\n' + ' double test1_result;\n' + ' test1_result = ' + 'pow(sin(x), 7) + ' + '7*pow(sin(x), 6)*cos(y) + ' + '7*pow(sin(x), 6)*tan(z) + ' + '21*pow(sin(x), 5)*pow(cos(y), 2) + ' + '42*pow(sin(x), 5)*cos(y)*tan(z) + ' + '21*pow(sin(x), 5)*pow(tan(z), 2) + ' + '35*pow(sin(x), 4)*pow(cos(y), 3) + ' + '105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + ' + '105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + ' + '35*pow(sin(x), 4)*pow(tan(z), 3) + ' + '35*pow(sin(x), 3)*pow(cos(y), 4) + ' + '140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + ' + '210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + ' + '140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + ' + '35*pow(sin(x), 3)*pow(tan(z), 4) + ' + '21*pow(sin(x), 2)*pow(cos(y), 5) + ' + '105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + ' + '210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + ' + '210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + ' + '105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + ' + '21*pow(sin(x), 2)*pow(tan(z), 5) + ' + '7*sin(x)*pow(cos(y), 6) + ' + '42*sin(x)*pow(cos(y), 5)*tan(z) + ' + '105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + ' + '140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + ' + '105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + ' + '42*sin(x)*cos(y)*pow(tan(z), 5) + ' + '7*sin(x)*pow(tan(z), 6) + ' + 'pow(cos(y), 7) + ' + '7*pow(cos(y), 6)*tan(z) + ' + '21*pow(cos(y), 5)*pow(tan(z), 2) + ' + '35*pow(cos(y), 4)*pow(tan(z), 3) + ' + '35*pow(cos(y), 3)*pow(tan(z), 4) + ' + '21*pow(cos(y), 2)*pow(tan(z), 5) + ' + '7*cos(y)*pow(tan(z), 6) + ' + 'pow(tan(z), 7);\n' + ' return test1_result;\n' + '}\n' + 'double test2(double x, double y, double z) {\n' + ' double test2_result;\n' + ' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n' + ' return test2_result;\n' + '}\n' + ) + assert result[1][0] == "file.h" + assert result[1][1] == ( + '#ifndef PROJECT__FILE__H\n' + '#define PROJECT__FILE__H\n' + 'double test1(double x, double y, double z);\n' + 'double test2(double x, double y, double z);\n' + '#endif\n' + ) + + +def test_loops_c(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False) + + assert f1 == 
'file.c' + expected = ( + '#include "file.h"\n' + '#include \n' + 'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n' + ' for (int i=0; i\n' + 'void test_dummies(int m_%(mno)i, double *x, double *y) {\n' + ' for (int i_%(ino)i=0; i_%(ino)i\n' + 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n' + ' for (int i=o; i<%(upperi)s; i++){\n' + ' y[i] = 0;\n' + ' }\n' + ' for (int i=o; i<%(upperi)s; i++){\n' + ' for (int j=0; j\n' + 'double foo(double x, double *y) {\n' + ' (*y) = sin(x);\n' + ' double foo_result;\n' + ' foo_result = cos(x);\n' + ' return foo_result;\n' + '}\n' + ) + assert result[0][1] == expected + + +def test_output_arg_c_reserved_words(): + from sympy.core.relational import Equality + from sympy.functions.elementary.trigonometric import (cos, sin) + x, y, z = symbols("if, while, z") + r = make_routine("foo", [Equality(y, sin(x)), cos(x)]) + c = C89CodeGen() + result = c.write([r], "test", header=False, empty=False) + assert result[0][0] == "test.c" + expected = ( + '#include "test.h"\n' + '#include \n' + 'double foo(double if_, double *while_) {\n' + ' (*while_) = sin(if_);\n' + ' double foo_result;\n' + ' foo_result = cos(if_);\n' + ' return foo_result;\n' + '}\n' + ) + assert result[0][1] == expected + + +def test_multidim_c_argument_cse(): + A_sym = MatrixSymbol('A', 3, 3) + b_sym = MatrixSymbol('b', 3, 1) + A = Matrix(A_sym) + b = Matrix(b_sym) + c = A*b + cgen = CCodeGen(project="test", cse=True) + r = cgen.routine("c", c) + r.arguments[-1].result_var = "out" + r.arguments[-1]._name = "out" + code = get_string(cgen.dump_c, [r], prefix="test") + expected = ( + '#include "test.h"\n' + "#include \n" + "void c(double *A, double *b, double *out) {\n" + " out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n" + " out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n" + " out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n" + "}\n" + ) + assert code == expected + + +def test_ccode_results_named_ordered(): + x, y, z = symbols('x,y,z') + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(A, Matrix([[1, 2, x]])) + expr2 = Equality(C, (x + y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + expected = ( + '#include "test.h"\n' + '#include \n' + 'void test(double x, double *C, double z, double y, double *A, double *B) {\n' + ' (*C) = z*(x + y);\n' + ' A[0] = 1;\n' + ' A[1] = 2;\n' + ' A[2] = x;\n' + ' (*B) = 2*x;\n' + '}\n' + ) + + result = codegen(name_expr, "c", "test", header=False, empty=False, + argument_sequence=(x, C, z, y, A, B)) + source = result[0][1] + assert source == expected + + +def test_ccode_matrixsymbol_slice(): + A = MatrixSymbol('A', 5, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 5, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result = codegen(name_expr, "c99", "test", header=False, empty=False) + source = result[0][1] + expected = ( + '#include "test.h"\n' + '#include \n' + 'void test(double *A, double *B, double *C, double *D) {\n' + ' B[0] = A[0];\n' + ' B[1] = A[1];\n' + ' B[2] = A[2];\n' + ' C[0] = A[3];\n' + ' C[1] = A[4];\n' + ' C[2] = A[5];\n' + ' D[0] = A[2];\n' + ' D[1] = A[5];\n' + ' D[2] = A[8];\n' + ' D[3] = A[11];\n' + ' D[4] = A[14];\n' + '}\n' + ) + assert source == expected + +def test_ccode_cse(): + a, b, c, d = symbols('a b c d') + e = MatrixSymbol('e', 3, 1) + name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))]) + generator = CCodeGen(cse=True) 
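+ # Note: cse=True makes CCodeGen run common-subexpression elimination, which is why the expected C below hoists the repeated products a*b and c*d into the temporaries x0 and x1 instead of recomputing them for each entry of e.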
+ result = codegen(name_expr, code_gen=generator, header=False, empty=False) + source = result[0][1] + expected = ( + '#include "test.h"\n' + '#include <math.h>\n' + 'void test(double a, double b, double c, double d, double *e) {\n' + ' const double x0 = a*b;\n' + ' const double x1 = c*d;\n' + ' e[0] = x0;\n' + ' e[1] = x0 + x1;\n' + ' e[2] = x0*x1;\n' + '}\n' + ) + assert source == expected + +def test_ccode_unused_array_arg(): + x = MatrixSymbol('x', 2, 1) + # x does not appear in output + name_expr = ("test", 1.0) + generator = CCodeGen() + result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,)) + source = result[0][1] + # note: x should still appear in the signature as (double *x) + expected = ( + '#include "test.h"\n' + '#include <math.h>\n' + 'double test(double *x) {\n' + ' double test_result;\n' + ' test_result = 1.0;\n' + ' return test_result;\n' + '}\n' + ) + assert source == expected + +def test_ccode_unused_array_arg_func(): + # issue 16689 + X = MatrixSymbol('X',3,1) + Y = MatrixSymbol('Y',3,1) + z = symbols('z',integer = True) + name_expr = ('testBug', X[0] + X[1]) + result = codegen(name_expr, language='C', header=False, empty=False, argument_sequence=(X, Y, z)) + source = result[0][1] + expected = ( + '#include "testBug.h"\n' + '#include <math.h>\n' + 'double testBug(double *X, double *Y, int z) {\n' + ' double testBug_result;\n' + ' testBug_result = X[0] + X[1];\n' + ' return testBug_result;\n' + '}\n' + ) + assert source == expected + +def test_empty_f_code(): + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, []) + assert source == "" + + +def test_empty_f_code_with_header(): + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [], header=True) + assert source[:82] == ( + "!******************************************************************************\n!*" + ) + # " Code generated with SymPy 0.7.2-git " + assert source[158:] == ( "*\n" + "!* *\n" + "!* See http://www.sympy.org/ for more information. 
*\n" + "!* *\n" + "!* This file is part of 'project' *\n" + "!******************************************************************************\n" + ) + + +def test_empty_f_header(): + code_gen = FCodeGen() + source = get_string(code_gen.dump_h, []) + assert source == "" + + +def test_simple_f_code(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "test = z*(x + y)\n" + "end function\n" + ) + assert source == expected + + +def test_numbersymbol_f_code(): + routine = make_routine("test", pi**Catalan) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test()\n" + "implicit none\n" + "REAL*8, parameter :: Catalan = %sd0\n" + "REAL*8, parameter :: pi = %sd0\n" + "test = pi**Catalan\n" + "end function\n" + ) % (Catalan.evalf(17), pi.evalf(17)) + assert source == expected + +def test_erf_f_code(): + x = symbols('x') + routine = make_routine("test", erf(x) - erf(-2 * x)) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test(x)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "test = erf(x) + erf(2.0d0*x)\n" + "end function\n" + ) + assert source == expected, source + +def test_f_code_argument_order(): + x, y, z = symbols('x,y,z') + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y]) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test(z, x, y)\n" + "implicit none\n" + "REAL*8, intent(in) :: z\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "test = x + y\n" + "end function\n" + ) + assert source == expected + + +def test_simple_f_header(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = FCodeGen() + source = get_string(code_gen.dump_h, [routine]) + expected = ( + "interface\n" + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "end function\n" + "end interface\n" + ) + assert source == expected + + +def test_simple_f_codegen(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + result = codegen( + ("test", expr), "F95", "file", header=False, empty=False) + expected = [ + ("file.f90", + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "test = z*(x + y)\n" + "end function\n"), + ("file.h", + "interface\n" + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "end function\n" + "end interface\n") + ] + assert result == expected + + +def test_multiple_results_f(): + x, y, z = symbols('x,y,z') + expr1 = (x + y)*z + expr2 = (x - y)*z + routine = make_routine( + "test", + [expr1, expr2] + ) + code_gen = FCodeGen() + raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine])) + + +def test_no_results_f(): + raises(ValueError, lambda: make_routine("test", [])) + + +def test_intrinsic_math_codegen(): + # not included: log10 + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.exponential import log + from 
sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh) + from sympy.functions.elementary.miscellaneous import sqrt + from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan) + x = symbols('x') + name_expr = [ + ("test_abs", Abs(x)), + ("test_acos", acos(x)), + ("test_asin", asin(x)), + ("test_atan", atan(x)), + ("test_cos", cos(x)), + ("test_cosh", cosh(x)), + ("test_log", log(x)), + ("test_ln", log(x)), + ("test_sin", sin(x)), + ("test_sinh", sinh(x)), + ("test_sqrt", sqrt(x)), + ("test_tan", tan(x)), + ("test_tanh", tanh(x)), + ] + result = codegen(name_expr, "F95", "file", header=False, empty=False) + assert result[0][0] == "file.f90" + expected = ( + 'REAL*8 function test_abs(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_abs = abs(x)\n' + 'end function\n' + 'REAL*8 function test_acos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_acos = acos(x)\n' + 'end function\n' + 'REAL*8 function test_asin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_asin = asin(x)\n' + 'end function\n' + 'REAL*8 function test_atan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_atan = atan(x)\n' + 'end function\n' + 'REAL*8 function test_cos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_cos = cos(x)\n' + 'end function\n' + 'REAL*8 function test_cosh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_cosh = cosh(x)\n' + 'end function\n' + 'REAL*8 function test_log(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_log = log(x)\n' + 'end function\n' + 'REAL*8 function test_ln(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_ln = log(x)\n' + 'end function\n' + 'REAL*8 function test_sin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_sin = sin(x)\n' + 'end function\n' + 'REAL*8 function test_sinh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_sinh = sinh(x)\n' + 'end function\n' + 'REAL*8 function test_sqrt(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_sqrt = sqrt(x)\n' + 'end function\n' + 'REAL*8 function test_tan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_tan = tan(x)\n' + 'end function\n' + 'REAL*8 function test_tanh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_tanh = tanh(x)\n' + 'end function\n' + ) + assert result[0][1] == expected + + assert result[1][0] == "file.h" + expected = ( + 'interface\n' + 'REAL*8 function test_abs(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_acos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_asin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_atan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_cos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_cosh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_log(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_ln(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' 
+ 'end interface\n' + 'interface\n' + 'REAL*8 function test_sin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_sinh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_sqrt(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_tan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_tanh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + ) + assert result[1][1] == expected + + +def test_intrinsic_math2_codegen(): + # not included: frexp, ldexp, modf, fmod + from sympy.functions.elementary.trigonometric import atan2 + x, y = symbols('x,y') + name_expr = [ + ("test_atan2", atan2(x, y)), + ("test_pow", x**y), + ] + result = codegen(name_expr, "F95", "file", header=False, empty=False) + assert result[0][0] == "file.f90" + expected = ( + 'REAL*8 function test_atan2(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'test_atan2 = atan2(x, y)\n' + 'end function\n' + 'REAL*8 function test_pow(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'test_pow = x**y\n' + 'end function\n' + ) + assert result[0][1] == expected + + assert result[1][0] == "file.h" + expected = ( + 'interface\n' + 'REAL*8 function test_atan2(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_pow(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'end function\n' + 'end interface\n' + ) + assert result[1][1] == expected + + +def test_complicated_codegen_f95(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + x, y, z = symbols('x,y,z') + name_expr = [ + ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()), + ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))), + ] + result = codegen(name_expr, "F95", "file", header=False, empty=False) + assert result[0][0] == "file.f90" + expected = ( + 'REAL*8 function test1(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n' + ' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n' + ' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n' + ' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n' + ' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n' + ' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n' + ' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n' + ' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n' + ' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n' + ' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n' + ' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n' + ' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n' + ' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n' + ' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n' + ' y)**2*tan(z)**5 + 
7*cos(y)*tan(z)**6 + tan(z)**7\n' + 'end function\n' + 'REAL*8 function test2(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n' + 'end function\n' + ) + assert result[0][1] == expected + assert result[1][0] == "file.h" + expected = ( + 'interface\n' + 'REAL*8 function test1(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test2(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'end function\n' + 'end interface\n' + ) + assert result[1][1] == expected + + +def test_loops(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + + n, m = symbols('n,m', integer=True) + A, x, y = map(IndexedBase, 'Axy') + i = Idx('i', m) + j = Idx('j', n) + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False) + + assert f1 == 'file.f90' + expected = ( + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(out), dimension(1:m) :: y\n' + 'INTEGER*4 :: i\n' + 'INTEGER*4 :: j\n' + 'do i = 1, m\n' + ' y(i) = 0\n' + 'end do\n' + 'do i = 1, m\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s + y(i)\n' + ' end do\n' + 'end do\n' + 'end subroutine\n' + ) + + assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\ + code == expected % {'rhs': 'x(j)*A(i, j)'} + assert f2 == 'file.h' + assert interface == ( + 'interface\n' + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(out), dimension(1:m) :: y\n' + 'end subroutine\n' + 'end interface\n' + ) + + +def test_dummy_loops_f95(): + from sympy.tensor import IndexedBase, Idx + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + expected = ( + 'subroutine test_dummies(m_%(mcount)i, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m_%(mcount)i\n' + 'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n' + 'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n' + 'INTEGER*4 :: i_%(icount)i\n' + 'do i_%(icount)i = 1, m_%(mcount)i\n' + ' y(i_%(icount)i) = x(i_%(icount)i)\n' + 'end do\n' + 'end subroutine\n' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + r = make_routine('test_dummies', Eq(y[i], x[i])) + c = FCodeGen() + code = get_string(c.dump_f95, [r]) + assert code == expected + + +def test_loops_InOut(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + + i, j, n, m = symbols('i,j,n,m', integer=True) + A, x, y = symbols('A,x,y') + A = IndexedBase(A)[Idx(i, m), Idx(j, n)] + x = IndexedBase(x)[Idx(j, n)] + y = IndexedBase(y)[Idx(i, m)] + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False) + + assert f1 == 'file.f90' + expected = ( + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, 
intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(inout), dimension(1:m) :: y\n' + 'INTEGER*4 :: i\n' + 'INTEGER*4 :: j\n' + 'do i = 1, m\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s + y(i)\n' + ' end do\n' + 'end do\n' + 'end subroutine\n' + ) + + assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or + code == expected % {'rhs': 'x(j)*A(i, j)'}) + assert f2 == 'file.h' + assert interface == ( + 'interface\n' + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(inout), dimension(1:m) :: y\n' + 'end subroutine\n' + 'end interface\n' + ) + + +def test_partial_loops_f(): + # check that loop boundaries are determined by Idx, and array strides + # determined by shape of IndexedBase object. + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m, o, p = symbols('n m o p', integer=True) + A = IndexedBase('A', shape=(m, p)) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', (o, m - 5)) # Note: bounds are inclusive + j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1) + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False) + + expected = ( + 'subroutine matrix_vector(A, m, n, o, p, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'INTEGER*4, intent(in) :: o\n' + 'INTEGER*4, intent(in) :: p\n' + 'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n' + 'INTEGER*4 :: i\n' + 'INTEGER*4 :: j\n' + 'do i = %(ilow)s, %(iup)s\n' + ' y(i) = 0\n' + 'end do\n' + 'do i = %(ilow)s, %(iup)s\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s + y(i)\n' + ' end do\n' + 'end do\n' + 'end subroutine\n' + ) % { + 'rhs': '%(rhs)s', + 'iup': str(m - 4), + 'ilow': str(1 + o), + 'iup-ilow': str(m - 4 - o) + } + + assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\ + code == expected % {'rhs': 'x(j)*A(i, j)'} + + +def test_output_arg_f(): + from sympy.core.relational import Equality + from sympy.functions.elementary.trigonometric import (cos, sin) + x, y, z = symbols("x,y,z") + r = make_routine("foo", [Equality(y, sin(x)), cos(x)]) + c = FCodeGen() + result = c.write([r], "test", header=False, empty=False) + assert result[0][0] == "test.f90" + assert result[0][1] == ( + 'REAL*8 function foo(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(out) :: y\n' + 'y = sin(x)\n' + 'foo = cos(x)\n' + 'end function\n' + ) + + +def test_inline_function(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A, x, y = map(IndexedBase, 'Axy') + i = Idx('i', m) + p = FCodeGen() + func = implemented_function('func', Lambda(n, n*(n + 1))) + routine = make_routine('test_inline', Eq(y[i], func(x[i]))) + code = get_string(p.dump_f95, [routine]) + expected = ( + 'subroutine test_inline(m, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'REAL*8, intent(in), dimension(1:m) :: x\n' + 'REAL*8, intent(out), dimension(1:m) :: y\n' + 'INTEGER*4 :: i\n' + 'do i = 1, m\n' + ' y(i) = %s*%s\n' + 'end do\n' + 'end subroutine\n' + ) + args = ('x(i)', '(x(i) + 1)') + assert code == expected % args or\ + code == expected % args[::-1] + + +def 
test_f_code_call_signature_wrap(): + # Issue #7934 + x = symbols('x:20') + expr = 0 + for sym in x: + expr += sym + routine = make_routine("test", expr) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = """\ +REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, & + x19, x2, x3, x4, x5, x6, x7, x8, x9) +implicit none +REAL*8, intent(in) :: x0 +REAL*8, intent(in) :: x1 +REAL*8, intent(in) :: x10 +REAL*8, intent(in) :: x11 +REAL*8, intent(in) :: x12 +REAL*8, intent(in) :: x13 +REAL*8, intent(in) :: x14 +REAL*8, intent(in) :: x15 +REAL*8, intent(in) :: x16 +REAL*8, intent(in) :: x17 +REAL*8, intent(in) :: x18 +REAL*8, intent(in) :: x19 +REAL*8, intent(in) :: x2 +REAL*8, intent(in) :: x3 +REAL*8, intent(in) :: x4 +REAL*8, intent(in) :: x5 +REAL*8, intent(in) :: x6 +REAL*8, intent(in) :: x7 +REAL*8, intent(in) :: x8 +REAL*8, intent(in) :: x9 +test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + & + x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 +end function +""" + assert source == expected + + +def test_check_case(): + x, X = symbols('x,X') + raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix')) + + +def test_check_case_false_positive(): + # The upper case/lower case exception should not be triggered by SymPy + # objects that differ only because of assumptions. (It may be useful to + # have a check for that as well, but here we only want to test against + # false positives with respect to case checking.) + x1 = symbols('x') + x2 = symbols('x', my_assumption=True) + try: + codegen(('test', x1*x2), 'f95', 'prefix') + except CodeGenError as e: + if e.args[0].startswith("Fortran ignores case."): + raise AssertionError("This exception should not be raised!") + + +def test_c_fortran_omit_routine_name(): + x, y = symbols("x,y") + name_expr = [("foo", 2*x)] + result = codegen(name_expr, "F95", header=False, empty=False) + expresult = codegen(name_expr, "F95", "foo", header=False, empty=False) + assert result[0][1] == expresult[0][1] + + name_expr = ("foo", x*y) + result = codegen(name_expr, "F95", header=False, empty=False) + expresult = codegen(name_expr, "F95", "foo", header=False, empty=False) + assert result[0][1] == expresult[0][1] + + name_expr = ("foo", Matrix([[x, y], [x+y, x-y]])) + result = codegen(name_expr, "C89", header=False, empty=False) + expresult = codegen(name_expr, "C89", "foo", header=False, empty=False) + assert result[0][1] == expresult[0][1] + + +def test_fcode_matrix_output(): + x, y, z = symbols('x,y,z') + e1 = x + y + e2 = Matrix([[x, y], [z, 16]]) + name_expr = ("test", (e1, e2)) + result = codegen(name_expr, "f95", "test", header=False, empty=False) + source = result[0][1] + expected = ( + "REAL*8 function test(x, y, z, out_%(hash)s)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n" + "out_%(hash)s(1, 1) = x\n" + "out_%(hash)s(2, 1) = z\n" + "out_%(hash)s(1, 2) = y\n" + "out_%(hash)s(2, 2) = 16\n" + "test = x + y\n" + "end function\n" + ) + # look for the magic number + a = source.splitlines()[5] + b = a.split('_') + out = b[1] + expected = expected % {'hash': out} + assert source == expected + + +def test_fcode_results_named_ordered(): + x, y, z = symbols('x,y,z') + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(A, Matrix([[1, 2, x]])) + expr2 = Equality(C, (x + y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, 
expr3]) + result = codegen(name_expr, "f95", "test", header=False, empty=False, + argument_sequence=(x, z, y, C, A, B)) + source = result[0][1] + expected = ( + "subroutine test(x, z, y, C, A, B)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: z\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(out) :: C\n" + "REAL*8, intent(out) :: B\n" + "REAL*8, intent(out), dimension(1:1, 1:3) :: A\n" + "C = z*(x + y)\n" + "A(1, 1) = 1\n" + "A(1, 2) = 2\n" + "A(1, 3) = x\n" + "B = 2*x\n" + "end subroutine\n" + ) + assert source == expected + + +def test_fcode_matrixsymbol_slice(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 2, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result = codegen(name_expr, "f95", "test", header=False, empty=False) + source = result[0][1] + expected = ( + "subroutine test(A, B, C, D)\n" + "implicit none\n" + "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n" + "REAL*8, intent(out), dimension(1:1, 1:3) :: B\n" + "REAL*8, intent(out), dimension(1:1, 1:3) :: C\n" + "REAL*8, intent(out), dimension(1:2, 1:1) :: D\n" + "B(1, 1) = A(1, 1)\n" + "B(1, 2) = A(1, 2)\n" + "B(1, 3) = A(1, 3)\n" + "C(1, 1) = A(2, 1)\n" + "C(1, 2) = A(2, 2)\n" + "C(1, 3) = A(2, 3)\n" + "D(1, 1) = A(1, 3)\n" + "D(2, 1) = A(2, 3)\n" + "end subroutine\n" + ) + assert source == expected + + +def test_fcode_matrixsymbol_slice_autoname(): + # see issue #8093 + A = MatrixSymbol('A', 2, 3) + name_expr = ("test", A[:, 1]) + result = codegen(name_expr, "f95", "test", header=False, empty=False) + source = result[0][1] + expected = ( + "subroutine test(A, out_%(hash)s)\n" + "implicit none\n" + "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n" + "REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n" + "out_%(hash)s(1, 1) = A(1, 2)\n" + "out_%(hash)s(2, 1) = A(2, 2)\n" + "end subroutine\n" + ) + # look for the magic number + a = source.splitlines()[3] + b = a.split('_') + out = b[1] + expected = expected % {'hash': out} + assert source == expected + + +def test_global_vars(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "F95", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "REAL*8 function f(x)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "f = x*y\n" + "end function\n" + ) + assert source == expected + + expected = ( + '#include "f.h"\n' + '#include \n' + 'double f(double x, double y) {\n' + ' double f_result;\n' + ' f_result = x*y + z;\n' + ' return f_result;\n' + '}\n' + ) + result = codegen(('f', x*y+z), "C", header=False, empty=False, + global_vars=(z, t)) + source = result[0][1] + assert source == expected + +def test_custom_codegen(): + from sympy.printing.c import C99CodePrinter + from sympy.functions.elementary.exponential import exp + + printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}}) + + x, y = symbols('x y') + expr = exp(x + y) + + # replace math.h with a different header + gen = C99CodeGen(printer=printer, + preprocessor_statements=['#include "fastexp.h"']) + + expected = ( + '#include "expr.h"\n' + '#include "fastexp.h"\n' + 'double expr(double x, double y) {\n' + ' double expr_result;\n' + ' expr_result = fastexp(x + y);\n' + ' return expr_result;\n' + '}\n' + ) + + result = codegen(('expr', expr), header=False, empty=False, code_gen=gen) + source = result[0][1] + assert source == expected + + # use both math.h and an external header + gen = 
C99CodeGen(printer=printer) + gen.preprocessor_statements.append('#include "fastexp.h"') + + expected = ( + '#include "expr.h"\n' + '#include \n' + '#include "fastexp.h"\n' + 'double expr(double x, double y) {\n' + ' double expr_result;\n' + ' expr_result = fastexp(x + y);\n' + ' return expr_result;\n' + '}\n' + ) + + result = codegen(('expr', expr), header=False, empty=False, code_gen=gen) + source = result[0][1] + assert source == expected + +def test_c_with_printer(): + # issue 13586 + from sympy.printing.c import C99CodePrinter + class CustomPrinter(C99CodePrinter): + def _print_Pow(self, expr): + return "fastpow({}, {})".format(self._print(expr.base), + self._print(expr.exp)) + + x = symbols('x') + expr = x**3 + expected =[ + ("file.c", + "#include \"file.h\"\n" + "#include \n" + "double test(double x) {\n" + " double test_result;\n" + " test_result = fastpow(x, 3);\n" + " return test_result;\n" + "}\n"), + ("file.h", + "#ifndef PROJECT__FILE__H\n" + "#define PROJECT__FILE__H\n" + "double test(double x);\n" + "#endif\n") + ] + result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter()) + assert result == expected + + +def test_fcode_complex(): + import sympy.utilities.codegen + sympy.utilities.codegen.COMPLEX_ALLOWED = True + x = Symbol('x', real=True) + y = Symbol('y',real=True) + result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False) + source = (result[0][1]) + expected = ( + "REAL*8 function test(x, y)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "test = x + y\n" + "end function\n") + assert source == expected + x = Symbol('x') + y = Symbol('y',real=True) + result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False) + source = (result[0][1]) + expected = ( + "COMPLEX*16 function test(x, y)\n" + "implicit none\n" + "COMPLEX*16, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "test = x + y\n" + "end function\n" + ) + assert source==expected + sympy.utilities.codegen.COMPLEX_ALLOWED = False diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_julia.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_julia.py new file mode 100644 index 0000000000000000000000000000000000000000..12841cb7d476107e3866d91b998bed1f997f3901 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_julia.py @@ -0,0 +1,620 @@ +from io import StringIO + +from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function +from sympy.core.relational import Equality +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices import Matrix, MatrixSymbol +from sympy.utilities.codegen import JuliaCodeGen, codegen, make_routine +from sympy.testing.pytest import XFAIL +import sympy + + +x, y, z = symbols('x,y,z') + + +def test_empty_jl_code(): + code_gen = JuliaCodeGen() + output = StringIO() + code_gen.dump_jl([], output, "file", header=False, empty=False) + source = output.getvalue() + assert source == "" + + +def test_jl_simple_code(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0] == "test.jl" + source = result[1] + expected = ( + "function test(x, y, z)\n" + " out1 = z .* (x + y)\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_simple_code_with_header(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Julia", header=True, empty=False) + assert result[0] == "test.jl" + source = result[1] + expected = 
( + "# Code generated with SymPy " + sympy.__version__ + "\n" + "#\n" + "# See http://www.sympy.org/ for more information.\n" + "#\n" + "# This file is part of 'project'\n" + "function test(x, y, z)\n" + " out1 = z .* (x + y)\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_simple_code_nameout(): + expr = Equality(z, (x + y)) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y)\n" + " z = x + y\n" + " return z\n" + "end\n" + ) + assert source == expected + + +def test_jl_numbersymbol(): + name_expr = ("test", pi**Catalan) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test()\n" + " out1 = pi ^ catalan\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +@XFAIL +def test_jl_numbersymbol_no_inline(): + # FIXME: how to pass inline=False to the JuliaCodePrinter? + name_expr = ("test", [pi**Catalan, EulerGamma]) + result, = codegen(name_expr, "Julia", header=False, + empty=False, inline=False) + source = result[1] + expected = ( + "function test()\n" + " Catalan = 0.915965594177219\n" + " EulerGamma = 0.5772156649015329\n" + " out1 = pi ^ Catalan\n" + " out2 = EulerGamma\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_code_argument_order(): + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y], language="julia") + code_gen = JuliaCodeGen() + output = StringIO() + code_gen.dump_jl([routine], output, "test", header=False, empty=False) + source = output.getvalue() + expected = ( + "function test(z, x, y)\n" + " out1 = x + y\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_multiple_results_m(): + # Here the output order is the input order + expr1 = (x + y)*z + expr2 = (x - y)*z + name_expr = ("test", [expr1, expr2]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " out1 = z .* (x + y)\n" + " out2 = z .* (x - y)\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_results_named_unordered(): + # Here output order is based on name_expr + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " C = z .* (x + y)\n" + " A = z .* (x - y)\n" + " B = 2 * x\n" + " return C, A, B\n" + "end\n" + ) + assert source == expected + + +def test_results_named_ordered(): + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "Julia", header=False, empty=False, + argument_sequence=(x, z, y)) + assert result[0][0] == "test.jl" + source = result[0][1] + expected = ( + "function test(x, z, y)\n" + " C = z .* (x + y)\n" + " A = z .* (x - y)\n" + " B = 2 * x\n" + " return C, A, B\n" + "end\n" + ) + assert source == expected + + +def test_complicated_jl_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = ("testlong", + [ ((sin(x) + cos(y) + tan(z))**3).expand(), + cos(cos(cos(cos(cos(cos(cos(cos(x + y + z)))))))) + ]) + result = codegen(name_expr, "Julia", header=False, 
empty=False) + assert result[0][0] == "testlong.jl" + source = result[0][1] + expected = ( + "function testlong(x, y, z)\n" + " out1 = sin(x) .^ 3 + 3 * sin(x) .^ 2 .* cos(y) + 3 * sin(x) .^ 2 .* tan(z)" + " + 3 * sin(x) .* cos(y) .^ 2 + 6 * sin(x) .* cos(y) .* tan(z) + 3 * sin(x) .* tan(z) .^ 2" + " + cos(y) .^ 3 + 3 * cos(y) .^ 2 .* tan(z) + 3 * cos(y) .* tan(z) .^ 2 + tan(z) .^ 3\n" + " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_output_arg_mixed_unordered(): + # named outputs are alphabetical, unnamed output appear in the given order + from sympy.functions.elementary.trigonometric import (cos, sin) + a = symbols("a") + name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0] == "foo.jl" + source = result[1] + expected = ( + 'function foo(x)\n' + ' out1 = cos(2 * x)\n' + ' y = sin(x)\n' + ' out3 = cos(x)\n' + ' a = sin(2 * x)\n' + ' return out1, y, out3, a\n' + 'end\n' + ) + assert source == expected + + +def test_jl_piecewise_(): + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function pwtest(x)\n" + " out1 = ((x < -1) ? (0) :\n" + " (x <= 1) ? (x .^ 2) :\n" + " (x > 1) ? (2 - x) : (1))\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +@XFAIL +def test_jl_piecewise_no_inline(): + # FIXME: how to pass inline=False to the JuliaCodePrinter? + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True)) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Julia", header=False, empty=False, + inline=False) + source = result[1] + expected = ( + "function pwtest(x)\n" + " if (x < -1)\n" + " out1 = 0\n" + " elseif (x <= 1)\n" + " out1 = x .^ 2\n" + " elseif (x > 1)\n" + " out1 = -x + 2\n" + " else\n" + " out1 = 1\n" + " end\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_multifcns_per_file(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0][0] == "foo.jl" + source = result[0][1] + expected = ( + "function foo(x, y)\n" + " out1 = 2 * x\n" + " out2 = 3 * y\n" + " return out1, out2\n" + "end\n" + "function bar(y)\n" + " out1 = y .^ 2\n" + " out2 = 4 * y\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_multifcns_per_file_w_header(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Julia", header=True, empty=False) + assert result[0][0] == "foo.jl" + source = result[0][1] + expected = ( + "# Code generated with SymPy " + sympy.__version__ + "\n" + "#\n" + "# See http://www.sympy.org/ for more information.\n" + "#\n" + "# This file is part of 'project'\n" + "function foo(x, y)\n" + " out1 = 2 * x\n" + " out2 = 3 * y\n" + " return out1, out2\n" + "end\n" + "function bar(y)\n" + " out1 = y .^ 2\n" + " out2 = 4 * y\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_filename_match_prefix(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result, = codegen(name_expr, "Julia", prefix="baz", header=False, + empty=False) + assert result[0] == "baz.jl" + + +def test_jl_matrix_named(): + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", 
Equality(MatrixSymbol('myout1', 1, 3), e2)) + result = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0][0] == "test.jl" + source = result[0][1] + expected = ( + "function test(x, y, z)\n" + " myout1 = [x 2 * y pi * z]\n" + " return myout1\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrix_named_matsym(): + myout1 = MatrixSymbol('myout1', 1, 3) + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(myout1, e2, evaluate=False)) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " myout1 = [x 2 * y pi * z]\n" + " return myout1\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrix_output_autoname(): + expr = Matrix([[x, x+y, 3]]) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y)\n" + " out1 = [x x + y 3]\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrix_output_autoname_2(): + e1 = (x + y) + e2 = Matrix([[2*x, 2*y, 2*z]]) + e3 = Matrix([[x], [y], [z]]) + e4 = Matrix([[x, y], [z, 16]]) + name_expr = ("test", (e1, e2, e3, e4)) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " out1 = x + y\n" + " out2 = [2 * x 2 * y 2 * z]\n" + " out3 = [x, y, z]\n" + " out4 = [x y;\n" + " z 16]\n" + " return out1, out2, out3, out4\n" + "end\n" + ) + assert source == expected + + +def test_jl_results_matrix_named_ordered(): + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, Matrix([[1, 2, x]])) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Julia", header=False, empty=False, + argument_sequence=(x, z, y)) + source = result[1] + expected = ( + "function test(x, z, y)\n" + " C = z .* (x + y)\n" + " A = [1 2 x]\n" + " B = 2 * x\n" + " return C, A, B\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 2, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[1,:]\n" + " C = A[2,:]\n" + " D = A[:,3]\n" + " return B, C, D\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice2(): + A = MatrixSymbol('A', 3, 4) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 2, 2) + name_expr = ("test", [Equality(B, A[0:2, 0:2]), + Equality(C, A[0:2, 1:3])]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[1:2,1:2]\n" + " C = A[1:2,2:3]\n" + " return B, C\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice3(): + A = MatrixSymbol('A', 8, 7) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 4, 2) + name_expr = ("test", [Equality(B, A[6:, 1::3]), + Equality(C, A[::2, ::3])]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[7:end,2:3:end]\n" + " C = A[1:2:end,1:3:end]\n" + " return B, C\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice_autoname(): + A = MatrixSymbol('A', 2, 
3) + B = MatrixSymbol('B', 1, 3) + name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[1,:]\n" + " out2 = A[2,:]\n" + " out3 = A[:,1]\n" + " out4 = A[:,2]\n" + " return B, out2, out3, out4\n" + "end\n" + ) + assert source == expected + + +def test_jl_loops(): + # Note: an Julia programmer would probably vectorize this across one or + # more dimensions. Also, size(A) would be used rather than passing in m + # and n. Perhaps users would expect us to vectorize automatically here? + # Or is it possible to represent such things using IndexedBase? + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Julia", + header=False, empty=False) + source = result[1] + expected = ( + 'function mat_vec_mult(y, A, m, n, x)\n' + ' for i = 1:m\n' + ' y[i] = 0\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' y[i] = %(rhs)s + y[i]\n' + ' end\n' + ' end\n' + ' return y\n' + 'end\n' + ) + assert (source == expected % {'rhs': 'A[%s,%s] .* x[j]' % (i, j)} or + source == expected % {'rhs': 'x[j] .* A[%s,%s]' % (i, j)}) + + +def test_jl_tensor_loops_multiple_contractions(): + # see comments in previous test about vectorizing + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m, o, p = symbols('n m o p', integer=True) + A = IndexedBase('A') + B = IndexedBase('B') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])), + "Julia", header=False, empty=False) + source = result[1] + expected = ( + 'function tensorthing(y, A, B, m, n, o, p)\n' + ' for i = 1:m\n' + ' y[i] = 0\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' for k = 1:o\n' + ' for l = 1:p\n' + ' y[i] = A[i,j,k,l] .* B[j,k,l] + y[i]\n' + ' end\n' + ' end\n' + ' end\n' + ' end\n' + ' return y\n' + 'end\n' + ) + assert source == expected + + +def test_jl_InOutArgument(): + expr = Equality(x, x**2) + name_expr = ("mysqr", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function mysqr(x)\n" + " x = x .^ 2\n" + " return x\n" + "end\n" + ) + assert source == expected + + +def test_jl_InOutArgument_order(): + # can specify the order as (x, y) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, + empty=False, argument_sequence=(x,y)) + source = result[1] + expected = ( + "function test(x, y)\n" + " x = x .^ 2 + y\n" + " return x\n" + "end\n" + ) + assert source == expected + # make sure it gives (x, y) not (y, x) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y)\n" + " x = x .^ 2 + y\n" + " return x\n" + "end\n" + ) + assert source == expected + + +def test_jl_not_supported(): + f = Function('f') + name_expr = ("test", [f(x).diff(x), S.ComplexInfinity]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x)\n" + " # unsupported: Derivative(f(x), x)\n" + " # unsupported: zoo\n" + " out1 = 
Derivative(f(x), x)\n" + " out2 = zoo\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_global_vars_octave(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "Julia", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "function f(x)\n" + " out1 = x .* y\n" + " return out1\n" + "end\n" + ) + assert source == expected + + result = codegen(('f', x*y+z), "Julia", header=False, empty=False, + argument_sequence=(x, y), global_vars=(z, t)) + source = result[0][1] + expected = ( + "function f(x, y)\n" + " out1 = x .* y + z\n" + " return out1\n" + "end\n" + ) + assert source == expected diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_octave.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_octave.py new file mode 100644 index 0000000000000000000000000000000000000000..77aaef7dd0d81d6855024e49fbb3d6d4c09f3384 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_octave.py @@ -0,0 +1,589 @@ +from io import StringIO + +from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function +from sympy.core.relational import Equality +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices import Matrix, MatrixSymbol +from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine +from sympy.testing.pytest import raises +from sympy.testing.pytest import XFAIL +import sympy + + +x, y, z = symbols('x,y,z') + + +def test_empty_m_code(): + code_gen = OctaveCodeGen() + output = StringIO() + code_gen.dump_m([], output, "file", header=False, empty=False) + source = output.getvalue() + assert source == "" + + +def test_m_simple_code(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0] == "test.m" + source = result[1] + expected = ( + "function out1 = test(x, y, z)\n" + " out1 = z.*(x + y);\n" + "end\n" + ) + assert source == expected + + +def test_m_simple_code_with_header(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Octave", header=True, empty=False) + assert result[0] == "test.m" + source = result[1] + expected = ( + "function out1 = test(x, y, z)\n" + " %TEST Autogenerated by SymPy\n" + " % Code generated with SymPy " + sympy.__version__ + "\n" + " %\n" + " % See http://www.sympy.org/ for more information.\n" + " %\n" + " % This file is part of 'project'\n" + " out1 = z.*(x + y);\n" + "end\n" + ) + assert source == expected + + +def test_m_simple_code_nameout(): + expr = Equality(z, (x + y)) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function z = test(x, y)\n" + " z = x + y;\n" + "end\n" + ) + assert source == expected + + +def test_m_numbersymbol(): + name_expr = ("test", pi**Catalan) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function out1 = test()\n" + " out1 = pi^%s;\n" + "end\n" + ) % Catalan.evalf(17) + assert source == expected + + +@XFAIL +def test_m_numbersymbol_no_inline(): + # FIXME: how to pass inline=False to the OctaveCodePrinter? 
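+ # Sketch of one possible route, not taken from upstream: if OctaveCodeGen accepts a printer argument the way C99CodeGen(printer=...) does in test_codegen.py's test_custom_codegen, one could build OctaveCodePrinter({'inline': False}) and pass it in; whether that actually resolves this XFAIL is untested here.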
+ name_expr = ("test", [pi**Catalan, EulerGamma]) + result, = codegen(name_expr, "Octave", header=False, + empty=False, inline=False) + source = result[1] + expected = ( + "function [out1, out2] = test()\n" + " Catalan = 0.915965594177219; % constant\n" + " EulerGamma = 0.5772156649015329; % constant\n" + " out1 = pi^Catalan;\n" + " out2 = EulerGamma;\n" + "end\n" + ) + assert source == expected + + +def test_m_code_argument_order(): + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y], language="octave") + code_gen = OctaveCodeGen() + output = StringIO() + code_gen.dump_m([routine], output, "test", header=False, empty=False) + source = output.getvalue() + expected = ( + "function out1 = test(z, x, y)\n" + " out1 = x + y;\n" + "end\n" + ) + assert source == expected + + +def test_multiple_results_m(): + # Here the output order is the input order + expr1 = (x + y)*z + expr2 = (x - y)*z + name_expr = ("test", [expr1, expr2]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [out1, out2] = test(x, y, z)\n" + " out1 = z.*(x + y);\n" + " out2 = z.*(x - y);\n" + "end\n" + ) + assert source == expected + + +def test_results_named_unordered(): + # Here output order is based on name_expr + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [C, A, B] = test(x, y, z)\n" + " C = z.*(x + y);\n" + " A = z.*(x - y);\n" + " B = 2*x;\n" + "end\n" + ) + assert source == expected + + +def test_results_named_ordered(): + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "Octave", header=False, empty=False, + argument_sequence=(x, z, y)) + assert result[0][0] == "test.m" + source = result[0][1] + expected = ( + "function [C, A, B] = test(x, z, y)\n" + " C = z.*(x + y);\n" + " A = z.*(x - y);\n" + " B = 2*x;\n" + "end\n" + ) + assert source == expected + + +def test_complicated_m_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = ("testlong", + [ ((sin(x) + cos(y) + tan(z))**3).expand(), + cos(cos(cos(cos(cos(cos(cos(cos(x + y + z)))))))) + ]) + result = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0][0] == "testlong.m" + source = result[0][1] + expected = ( + "function [out1, out2] = testlong(x, y, z)\n" + " out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)" + " + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2" + " + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n" + " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n" + "end\n" + ) + assert source == expected + + +def test_m_output_arg_mixed_unordered(): + # named outputs are alphabetical, unnamed output appear in the given order + from sympy.functions.elementary.trigonometric import (cos, sin) + a = symbols("a") + name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0] == "foo.m" + source = result[1] + expected = ( + 'function [out1, y, out3, a] = foo(x)\n' + ' out1 = cos(2*x);\n' + ' y = sin(x);\n' + ' out3 = cos(x);\n' + ' a = sin(2*x);\n' + 
'end\n' + ) + assert source == expected + + +def test_m_piecewise_(): + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function out1 = pwtest(x)\n" + " out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n" + " (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n" + " (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n" + "end\n" + ) + assert source == expected + + +@XFAIL +def test_m_piecewise_no_inline(): + # FIXME: how to pass inline=False to the OctaveCodePrinter? + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True)) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Octave", header=False, empty=False, + inline=False) + source = result[1] + expected = ( + "function out1 = pwtest(x)\n" + " if (x < -1)\n" + " out1 = 0;\n" + " elseif (x <= 1)\n" + " out1 = x.^2;\n" + " elseif (x > 1)\n" + " out1 = -x + 2;\n" + " else\n" + " out1 = 1;\n" + " end\n" + "end\n" + ) + assert source == expected + + +def test_m_multifcns_per_file(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0][0] == "foo.m" + source = result[0][1] + expected = ( + "function [out1, out2] = foo(x, y)\n" + " out1 = 2*x;\n" + " out2 = 3*y;\n" + "end\n" + "function [out1, out2] = bar(y)\n" + " out1 = y.^2;\n" + " out2 = 4*y;\n" + "end\n" + ) + assert source == expected + + +def test_m_multifcns_per_file_w_header(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Octave", header=True, empty=False) + assert result[0][0] == "foo.m" + source = result[0][1] + expected = ( + "function [out1, out2] = foo(x, y)\n" + " %FOO Autogenerated by SymPy\n" + " % Code generated with SymPy " + sympy.__version__ + "\n" + " %\n" + " % See http://www.sympy.org/ for more information.\n" + " %\n" + " % This file is part of 'project'\n" + " out1 = 2*x;\n" + " out2 = 3*y;\n" + "end\n" + "function [out1, out2] = bar(y)\n" + " out1 = y.^2;\n" + " out2 = 4*y;\n" + "end\n" + ) + assert source == expected + + +def test_m_filename_match_first_fcn(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + raises(ValueError, lambda: codegen(name_expr, + "Octave", prefix="bar", header=False, empty=False)) + + +def test_m_matrix_named(): + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2)) + result = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0][0] == "test.m" + source = result[0][1] + expected = ( + "function myout1 = test(x, y, z)\n" + " myout1 = [x 2*y pi*z];\n" + "end\n" + ) + assert source == expected + + +def test_m_matrix_named_matsym(): + myout1 = MatrixSymbol('myout1', 1, 3) + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(myout1, e2, evaluate=False)) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function myout1 = test(x, y, z)\n" + " myout1 = [x 2*y pi*z];\n" + "end\n" + ) + assert source == expected + + +def test_m_matrix_output_autoname(): + expr = Matrix([[x, x+y, 3]]) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function out1 = test(x, y)\n" + " out1 = [x x + y 3];\n" + "end\n" + ) + assert source == expected + + +def test_m_matrix_output_autoname_2(): + e1 = (x + y) + e2 = Matrix([[2*x, 2*y, 2*z]]) + 
e3 = Matrix([[x], [y], [z]]) + e4 = Matrix([[x, y], [z, 16]]) + name_expr = ("test", (e1, e2, e3, e4)) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [out1, out2, out3, out4] = test(x, y, z)\n" + " out1 = x + y;\n" + " out2 = [2*x 2*y 2*z];\n" + " out3 = [x; y; z];\n" + " out4 = [x y; z 16];\n" + "end\n" + ) + assert source == expected + + +def test_m_results_matrix_named_ordered(): + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, Matrix([[1, 2, x]])) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Octave", header=False, empty=False, + argument_sequence=(x, z, y)) + source = result[1] + expected = ( + "function [C, A, B] = test(x, z, y)\n" + " C = z.*(x + y);\n" + " A = [1 2 x];\n" + " B = 2*x;\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 2, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, C, D] = test(A)\n" + " B = A(1, :);\n" + " C = A(2, :);\n" + " D = A(:, 3);\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice2(): + A = MatrixSymbol('A', 3, 4) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 2, 2) + name_expr = ("test", [Equality(B, A[0:2, 0:2]), + Equality(C, A[0:2, 1:3])]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, C] = test(A)\n" + " B = A(1:2, 1:2);\n" + " C = A(1:2, 2:3);\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice3(): + A = MatrixSymbol('A', 8, 7) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 4, 2) + name_expr = ("test", [Equality(B, A[6:, 1::3]), + Equality(C, A[::2, ::3])]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, C] = test(A)\n" + " B = A(7:end, 2:3:end);\n" + " C = A(1:2:end, 1:3:end);\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice_autoname(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, out2, out3, out4] = test(A)\n" + " B = A(1, :);\n" + " out2 = A(2, :);\n" + " out3 = A(:, 1);\n" + " out4 = A(:, 2);\n" + "end\n" + ) + assert source == expected + + +def test_m_loops(): + # Note: an Octave programmer would probably vectorize this across one or + # more dimensions. Also, size(A) would be used rather than passing in m + # and n. Perhaps users would expect us to vectorize automatically here? + # Or is it possible to represent such things using IndexedBase? 
+ from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave", + header=False, empty=False) + source = result[1] + expected = ( + 'function y = mat_vec_mult(A, m, n, x)\n' + ' for i = 1:m\n' + ' y(i) = 0;\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' y(i) = %(rhs)s + y(i);\n' + ' end\n' + ' end\n' + 'end\n' + ) + assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or + source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)}) + + +def test_m_tensor_loops_multiple_contractions(): + # see comments in previous test about vectorizing + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m, o, p = symbols('n m o p', integer=True) + A = IndexedBase('A') + B = IndexedBase('B') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])), + "Octave", header=False, empty=False) + source = result[1] + expected = ( + 'function y = tensorthing(A, B, m, n, o, p)\n' + ' for i = 1:m\n' + ' y(i) = 0;\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' for k = 1:o\n' + ' for l = 1:p\n' + ' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n' + ' end\n' + ' end\n' + ' end\n' + ' end\n' + 'end\n' + ) + assert source == expected + + +def test_m_InOutArgument(): + expr = Equality(x, x**2) + name_expr = ("mysqr", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function x = mysqr(x)\n" + " x = x.^2;\n" + "end\n" + ) + assert source == expected + + +def test_m_InOutArgument_order(): + # can specify the order as (x, y) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, + empty=False, argument_sequence=(x,y)) + source = result[1] + expected = ( + "function x = test(x, y)\n" + " x = x.^2 + y;\n" + "end\n" + ) + assert source == expected + # make sure it gives (x, y) not (y, x) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function x = test(x, y)\n" + " x = x.^2 + y;\n" + "end\n" + ) + assert source == expected + + +def test_m_not_supported(): + f = Function('f') + name_expr = ("test", [f(x).diff(x), S.ComplexInfinity]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [out1, out2] = test(x)\n" + " % unsupported: Derivative(f(x), x)\n" + " % unsupported: zoo\n" + " out1 = Derivative(f(x), x);\n" + " out2 = zoo;\n" + "end\n" + ) + assert source == expected + + +def test_global_vars_octave(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "Octave", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "function out1 = f(x)\n" + " global y\n" + " out1 = x.*y;\n" + "end\n" + ) + assert source == expected + + result = codegen(('f', x*y+z), "Octave", header=False, empty=False, + argument_sequence=(x, y), global_vars=(z, t)) + source = result[0][1] + expected = ( + "function out1 = f(x, y)\n" + " global t z\n" + " out1 = x.*y + z;\n" + "end\n" + ) + assert source == expected diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_rust.py 
b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_rust.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7f82158ae8fa7dfe34bf909aa695b119fb9526 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_codegen_rust.py @@ -0,0 +1,401 @@ +from io import StringIO + +from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function +from sympy.core.relational import Equality +from sympy.functions.elementary.piecewise import Piecewise +from sympy.utilities.codegen import RustCodeGen, codegen, make_routine +from sympy.testing.pytest import XFAIL +import sympy + + +x, y, z = symbols('x,y,z') + + +def test_empty_rust_code(): + code_gen = RustCodeGen() + output = StringIO() + code_gen.dump_rs([], output, "file", header=False, empty=False) + source = output.getvalue() + assert source == "" + + +def test_simple_rust_code(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0] == "test.rs" + source = result[1] + expected = ( + "fn test(x: f64, y: f64, z: f64) -> f64 {\n" + " let out1 = z*(x + y);\n" + " out1\n" + "}\n" + ) + assert source == expected + + +def test_simple_code_with_header(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Rust", header=True, empty=False) + assert result[0] == "test.rs" + source = result[1] + version_str = "Code generated with SymPy %s" % sympy.__version__ + version_line = version_str.center(76).rstrip() + expected = ( + "/*\n" + " *%(version_line)s\n" + " *\n" + " * See http://www.sympy.org/ for more information.\n" + " *\n" + " * This file is part of 'project'\n" + " */\n" + "fn test(x: f64, y: f64, z: f64) -> f64 {\n" + " let out1 = z*(x + y);\n" + " out1\n" + "}\n" + ) % {'version_line': version_line} + assert source == expected + + +def test_simple_code_nameout(): + expr = Equality(z, (x + y)) + name_expr = ("test", expr) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64) -> f64 {\n" + " let z = x + y;\n" + " z\n" + "}\n" + ) + assert source == expected + + +def test_numbersymbol(): + name_expr = ("test", pi**Catalan) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test() -> f64 {\n" + " const Catalan: f64 = %s;\n" + " let out1 = PI.powf(Catalan);\n" + " out1\n" + "}\n" + ) % Catalan.evalf(17) + assert source == expected + + +@XFAIL +def test_numbersymbol_inline(): + # FIXME: how to pass inline to the RustCodePrinter? 
+ name_expr = ("test", [pi**Catalan, EulerGamma]) + result, = codegen(name_expr, "Rust", header=False, + empty=False, inline=True) + source = result[1] + expected = ( + "fn test() -> (f64, f64) {\n" + " const Catalan: f64 = %s;\n" + " const EulerGamma: f64 = %s;\n" + " let out1 = PI.powf(Catalan);\n" + " let out2 = EulerGamma);\n" + " (out1, out2)\n" + "}\n" + ) % (Catalan.evalf(17), EulerGamma.evalf(17)) + assert source == expected + + +def test_argument_order(): + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust") + code_gen = RustCodeGen() + output = StringIO() + code_gen.dump_rs([routine], output, "test", header=False, empty=False) + source = output.getvalue() + expected = ( + "fn test(z: f64, x: f64, y: f64) -> f64 {\n" + " let out1 = x + y;\n" + " out1\n" + "}\n" + ) + assert source == expected + + +def test_multiple_results_rust(): + # Here the output order is the input order + expr1 = (x + y)*z + expr2 = (x - y)*z + name_expr = ("test", [expr1, expr2]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n" + " let out1 = z*(x + y);\n" + " let out2 = z*(x - y);\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_results_named_unordered(): + # Here output order is based on name_expr + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n" + " let C = z*(x + y);\n" + " let A = z*(x - y);\n" + " let B = 2*x;\n" + " (C, A, B)\n" + "}\n" + ) + assert source == expected + + +def test_results_named_ordered(): + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "Rust", header=False, empty=False, + argument_sequence=(x, z, y)) + assert result[0][0] == "test.rs" + source = result[0][1] + expected = ( + "fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n" + " let C = z*(x + y);\n" + " let A = z*(x - y);\n" + " let B = 2*x;\n" + " (C, A, B)\n" + "}\n" + ) + assert source == expected + + +def test_complicated_rs_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = ("testlong", + [ ((sin(x) + cos(y) + tan(z))**3).expand(), + cos(cos(cos(cos(cos(cos(cos(cos(x + y + z)))))))) + ]) + result = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0][0] == "testlong.rs" + source = result[0][1] + expected = ( + "fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n" + " let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()" + " + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)" + " + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)" + " + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()" + " + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n" + " let out2 = (x + y + z).cos().cos().cos().cos()" + ".cos().cos().cos().cos();\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_output_arg_mixed_unordered(): + # named outputs are alphabetical, unnamed output appear in the given order + from sympy.functions.elementary.trigonometric import (cos, sin) + a = symbols("a") + name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), 
Equality(a, sin(2*x))]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0] == "foo.rs" + source = result[1] + expected = ( + "fn foo(x: f64) -> (f64, f64, f64, f64) {\n" + " let out1 = (2*x).cos();\n" + " let y = x.sin();\n" + " let out3 = x.cos();\n" + " let a = (2*x).sin();\n" + " (out1, y, out3, a)\n" + "}\n" + ) + assert source == expected + + +def test_piecewise_(): + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn pwtest(x: f64) -> f64 {\n" + " let out1 = if (x < -1.0) {\n" + " 0\n" + " } else if (x <= 1.0) {\n" + " x.powi(2)\n" + " } else if (x > 1.0) {\n" + " 2 - x\n" + " } else {\n" + " 1\n" + " };\n" + " out1\n" + "}\n" + ) + assert source == expected + + +@XFAIL +def test_piecewise_inline(): + # FIXME: how to pass inline to the RustCodePrinter? + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True)) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Rust", header=False, empty=False, + inline=True) + source = result[1] + expected = ( + "fn pwtest(x: f64) -> f64 {\n" + " let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }" + " else if (x > 1) { -x + 2 } else { 1 };\n" + " out1\n" + "}\n" + ) + assert source == expected + + +def test_multifcns_per_file(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0][0] == "foo.rs" + source = result[0][1] + expected = ( + "fn foo(x: f64, y: f64) -> (f64, f64) {\n" + " let out1 = 2*x;\n" + " let out2 = 3*y;\n" + " (out1, out2)\n" + "}\n" + "fn bar(y: f64) -> (f64, f64) {\n" + " let out1 = y.powi(2);\n" + " let out2 = 4*y;\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_multifcns_per_file_w_header(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Rust", header=True, empty=False) + assert result[0][0] == "foo.rs" + source = result[0][1] + version_str = "Code generated with SymPy %s" % sympy.__version__ + version_line = version_str.center(76).rstrip() + expected = ( + "/*\n" + " *%(version_line)s\n" + " *\n" + " * See http://www.sympy.org/ for more information.\n" + " *\n" + " * This file is part of 'project'\n" + " */\n" + "fn foo(x: f64, y: f64) -> (f64, f64) {\n" + " let out1 = 2*x;\n" + " let out2 = 3*y;\n" + " (out1, out2)\n" + "}\n" + "fn bar(y: f64) -> (f64, f64) {\n" + " let out1 = y.powi(2);\n" + " let out2 = 4*y;\n" + " (out1, out2)\n" + "}\n" + ) % {'version_line': version_line} + assert source == expected + + +def test_filename_match_prefix(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result, = codegen(name_expr, "Rust", prefix="baz", header=False, + empty=False) + assert result[0] == "baz.rs" + + +def test_InOutArgument(): + expr = Equality(x, x**2) + name_expr = ("mysqr", expr) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn mysqr(x: f64) -> f64 {\n" + " let x = x.powi(2);\n" + " x\n" + "}\n" + ) + assert source == expected + + +def test_InOutArgument_order(): + # can specify the order as (x, y) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Rust", header=False, + empty=False, argument_sequence=(x,y)) + source = result[1] + expected = ( + "fn test(x: f64, y: f64) -> f64 {\n" + " let x = x.powi(2) + y;\n" + " 
x\n" + "}\n" + ) + assert source == expected + # make sure it gives (x, y) not (y, x) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64) -> f64 {\n" + " let x = x.powi(2) + y;\n" + " x\n" + "}\n" + ) + assert source == expected + + +def test_not_supported(): + f = Function('f') + name_expr = ("test", [f(x).diff(x), S.ComplexInfinity]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64) -> (f64, f64) {\n" + " // unsupported: Derivative(f(x), x)\n" + " // unsupported: zoo\n" + " let out1 = Derivative(f(x), x);\n" + " let out2 = zoo;\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_global_vars_rust(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "Rust", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "fn f(x: f64) -> f64 {\n" + " let out1 = x*y;\n" + " out1\n" + "}\n" + ) + assert source == expected + + result = codegen(('f', x*y+z), "Rust", header=False, empty=False, + argument_sequence=(x, y), global_vars=(z, t)) + source = result[0][1] + expected = ( + "fn f(x: f64, y: f64) -> f64 {\n" + " let out1 = x*y + z;\n" + " out1\n" + "}\n" + ) + assert source == expected diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_decorator.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..b1870d4db8f719fdabfeab14120bfb3ce10131a9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_decorator.py @@ -0,0 +1,129 @@ +from functools import wraps + +from sympy.utilities.decorator import threaded, xthreaded, memoize_property, deprecated +from sympy.testing.pytest import warns_deprecated_sympy + +from sympy.core.basic import Basic +from sympy.core.relational import Eq +from sympy.matrices.dense import Matrix + +from sympy.abc import x, y + + +def test_threaded(): + @threaded + def function(expr, *args): + return 2*expr + sum(args) + + assert function(Matrix([[x, y], [1, x]]), 1, 2) == \ + Matrix([[2*x + 3, 2*y + 3], [5, 2*x + 3]]) + + assert function(Eq(x, y), 1, 2) == Eq(2*x + 3, 2*y + 3) + + assert function([x, y], 1, 2) == [2*x + 3, 2*y + 3] + assert function((x, y), 1, 2) == (2*x + 3, 2*y + 3) + + assert function({x, y}, 1, 2) == {2*x + 3, 2*y + 3} + + @threaded + def function(expr, n): + return expr**n + + assert function(x + y, 2) == x**2 + y**2 + assert function(x, 2) == x**2 + + +def test_xthreaded(): + @xthreaded + def function(expr, n): + return expr**n + + assert function(x + y, 2) == (x + y)**2 + + +def test_wraps(): + def my_func(x): + """My function. """ + + my_func.is_my_func = True + + new_my_func = threaded(my_func) + new_my_func = wraps(my_func)(new_my_func) + + assert new_my_func.__name__ == 'my_func' + assert new_my_func.__doc__ == 'My function. ' + assert hasattr(new_my_func, 'is_my_func') + assert new_my_func.is_my_func is True + + +def test_memoize_property(): + class TestMemoize(Basic): + @memoize_property + def prop(self): + return Basic() + + member = TestMemoize() + obj1 = member.prop + obj2 = member.prop + assert obj1 is obj2 + +def test_deprecated(): + @deprecated('deprecated_function is deprecated', + deprecated_since_version='1.10', + # This is the target at the top of the file, which will never + # go away. 
+ active_deprecations_target='active-deprecations') + def deprecated_function(x): + return x + + with warns_deprecated_sympy(): + assert deprecated_function(1) == 1 + + @deprecated('deprecated_class is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class: + pass + + with warns_deprecated_sympy(): + assert isinstance(deprecated_class(), deprecated_class) + + # Ensure the class decorator works even when the class never returns + # itself + @deprecated('deprecated_class_new is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class_new: + def __new__(cls, arg): + return arg + + with warns_deprecated_sympy(): + assert deprecated_class_new(1) == 1 + + @deprecated('deprecated_class_init is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class_init: + def __init__(self, arg): + self.arg = 1 + + with warns_deprecated_sympy(): + assert deprecated_class_init(1).arg == 1 + + @deprecated('deprecated_class_new_init is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class_new_init: + def __new__(cls, arg): + if arg == 0: + return arg + return object.__new__(cls) + + def __init__(self, arg): + self.arg = 1 + + with warns_deprecated_sympy(): + assert deprecated_class_new_init(0) == 0 + + with warns_deprecated_sympy(): + assert deprecated_class_new_init(1).arg == 1 diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_deprecated.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4534ef1abc38ff368011b3ef9d11c497f3675b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_deprecated.py @@ -0,0 +1,13 @@ +from sympy.testing.pytest import warns_deprecated_sympy + +# See https://github.com/sympy/sympy/pull/18095 + +def test_deprecated_utilities(): + with warns_deprecated_sympy(): + import sympy.utilities.pytest # noqa:F401 + with warns_deprecated_sympy(): + import sympy.utilities.runtests # noqa:F401 + with warns_deprecated_sympy(): + import sympy.utilities.randtest # noqa:F401 + with warns_deprecated_sympy(): + import sympy.utilities.tmpfiles # noqa:F401 diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_enumerative.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_enumerative.py new file mode 100644 index 0000000000000000000000000000000000000000..357499d5fd400b14e2bcad067f3015b74b0e9003 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_enumerative.py @@ -0,0 +1,179 @@ +import string +from itertools import zip_longest + +from sympy.utilities.enumerative import ( + list_visitor, + MultisetPartitionTraverser, + multiset_partitions_taocp + ) +from sympy.utilities.iterables import _set_partitions + +# first some functions only useful as test scaffolding - these provide +# straightforward, but slow reference implementations against which to +# compare the real versions, and also a comparison to verify that +# different versions are giving identical results. + +def part_range_filter(partition_iterator, lb, ub): + """ + Filters (on the number of parts) a multiset partition enumeration + + Arguments + ========= + + lb, and ub are a range (in the Python slice sense) on the lpart + variable returned from a multiset partition enumeration. 
Recall + that lpart is 0-based (it points to the topmost part on the part + stack), so if you want to return parts of sizes 2,3,4,5 you would + use lb=1 and ub=5. + """ + for state in partition_iterator: + f, lpart, pstack = state + if lpart >= lb and lpart < ub: + yield state + +def multiset_partitions_baseline(multiplicities, components): + """Enumerates partitions of a multiset + + Parameters + ========== + + multiplicities + list of integer multiplicities of the components of the multiset. + + components + the components (elements) themselves + + Returns + ======= + + Set of partitions. Each partition is tuple of parts, and each + part is a tuple of components (with repeats to indicate + multiplicity) + + Notes + ===== + + Multiset partitions can be created as equivalence classes of set + partitions, and this function does just that. This approach is + slow and memory intensive compared to the more advanced algorithms + available, but the code is simple and easy to understand. Hence + this routine is strictly for testing -- to provide a + straightforward baseline against which to regress the production + versions. (This code is a simplified version of an earlier + production implementation.) + """ + + canon = [] # list of components with repeats + for ct, elem in zip(multiplicities, components): + canon.extend([elem]*ct) + + # accumulate the multiset partitions in a set to eliminate dups + cache = set() + n = len(canon) + for nc, q in _set_partitions(n): + rv = [[] for i in range(nc)] + for i in range(n): + rv[q[i]].append(canon[i]) + canonical = tuple( + sorted([tuple(p) for p in rv])) + cache.add(canonical) + return cache + + +def compare_multiset_w_baseline(multiplicities): + """ + Enumerates the partitions of multiset with AOCP algorithm and + baseline implementation, and compare the results. + + """ + letters = string.ascii_lowercase + bl_partitions = multiset_partitions_baseline(multiplicities, letters) + + # The partitions returned by the different algorithms may have + # their parts in different orders. Also, they generate partitions + # in different orders. Hence the sorting, and set comparison. + + aocp_partitions = set() + for state in multiset_partitions_taocp(multiplicities): + p1 = tuple(sorted( + [tuple(p) for p in list_visitor(state, letters)])) + aocp_partitions.add(p1) + + assert bl_partitions == aocp_partitions + +def compare_multiset_states(s1, s2): + """compare for equality two instances of multiset partition states + + This is useful for comparing different versions of the algorithm + to verify correctness.""" + # Comparison is physical, the only use of semantics is to ignore + # trash off the top of the stack. + f1, lpart1, pstack1 = s1 + f2, lpart2, pstack2 = s2 + + if (lpart1 == lpart2) and (f1[0:lpart1+1] == f2[0:lpart2+1]): + if pstack1[0:f1[lpart1+1]] == pstack2[0:f2[lpart2+1]]: + return True + return False + +def test_multiset_partitions_taocp(): + """Compares the output of multiset_partitions_taocp with a baseline + (set partition based) implementation.""" + + # Test cases should not be too large, since the baseline + # implementation is fairly slow. 
+ multiplicities = [2,2] + compare_multiset_w_baseline(multiplicities) + + multiplicities = [4,3,1] + compare_multiset_w_baseline(multiplicities) + +def test_multiset_partitions_versions(): + """Compares Knuth-based versions of multiset_partitions""" + multiplicities = [5,2,2,1] + m = MultisetPartitionTraverser() + for s1, s2 in zip_longest(m.enum_all(multiplicities), + multiset_partitions_taocp(multiplicities)): + assert compare_multiset_states(s1, s2) + +def subrange_exercise(mult, lb, ub): + """Compare filter-based and more optimized subrange implementations + + Helper for tests, called with both small and larger multisets. + """ + m = MultisetPartitionTraverser() + assert m.count_partitions(mult) == \ + m.count_partitions_slow(mult) + + # Note - multiple traversals from the same + # MultisetPartitionTraverser object cannot execute at the same + # time, hence make several instances here. + ma = MultisetPartitionTraverser() + mc = MultisetPartitionTraverser() + md = MultisetPartitionTraverser() + + # Several paths to compute just the size two partitions + a_it = ma.enum_range(mult, lb, ub) + b_it = part_range_filter(multiset_partitions_taocp(mult), lb, ub) + c_it = part_range_filter(mc.enum_small(mult, ub), lb, sum(mult)) + d_it = part_range_filter(md.enum_large(mult, lb), 0, ub) + + for sa, sb, sc, sd in zip_longest(a_it, b_it, c_it, d_it): + assert compare_multiset_states(sa, sb) + assert compare_multiset_states(sa, sc) + assert compare_multiset_states(sa, sd) + +def test_subrange(): + # Quick, but doesn't hit some of the corner cases + mult = [4,4,2,1] # mississippi + lb = 1 + ub = 2 + subrange_exercise(mult, lb, ub) + + +def test_subrange_large(): + # takes a second or so, depending on cpu, Python version, etc. + mult = [6,3,2,1] + lb = 4 + ub = 7 + subrange_exercise(mult, lb, ub) diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_exceptions.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..d91e55e95d0ae4ac57cdd1989e0b3d39a55cd07d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_exceptions.py @@ -0,0 +1,12 @@ +from sympy.testing.pytest import raises +from sympy.utilities.exceptions import sympy_deprecation_warning + +# Only test exceptions here because the other cases are tested in the +# warns_deprecated_sympy tests +def test_sympy_deprecation_warning(): + raises(TypeError, lambda: sympy_deprecation_warning('test', + deprecated_since_version=1.10, + active_deprecations_target='active-deprecations')) + + raises(ValueError, lambda: sympy_deprecation_warning('test', + deprecated_since_version="1.10", active_deprecations_target='(active-deprecations)=')) diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_iterables.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_iterables.py new file mode 100644 index 0000000000000000000000000000000000000000..1003522bcd556c6f63e04de7da57b43498575fee --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_iterables.py @@ -0,0 +1,945 @@ +from textwrap import dedent +from itertools import islice, product + +from sympy.core.basic import Basic +from sympy.core.numbers import Integer +from sympy.core.sorting import ordered +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.matrices.dense import Matrix +from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation +from sympy.utilities.iterables import ( + 
_partition, _set_partitions, binary_partitions, bracelets, capture, + cartes, common_prefix, common_suffix, connected_components, dict_merge, + filter_symbols, flatten, generate_bell, generate_derangements, + generate_involutions, generate_oriented_forest, group, has_dups, ibin, + iproduct, kbins, minlex, multiset, multiset_combinations, + multiset_partitions, multiset_permutations, necklaces, numbered_symbols, + partitions, permutations, postfixes, + prefixes, reshape, rotate_left, rotate_right, runs, sift, + strongly_connected_components, subsets, take, topological_sort, unflatten, + uniq, variations, ordered_partitions, rotations, is_palindromic, iterable, + NotIterable, multiset_derangements, signed_permutations, + sequence_partitions, sequence_partitions_empty) +from sympy.utilities.enumerative import ( + factoring_visitor, multiset_partitions_taocp ) + +from sympy.core.singleton import S +from sympy.testing.pytest import raises, warns_deprecated_sympy + +w, x, y, z = symbols('w,x,y,z') + + +def test_deprecated_iterables(): + from sympy.utilities.iterables import default_sort_key, ordered + with warns_deprecated_sympy(): + assert list(ordered([y, x])) == [x, y] + with warns_deprecated_sympy(): + assert sorted([y, x], key=default_sort_key) == [x, y] + + +def test_is_palindromic(): + assert is_palindromic('') + assert is_palindromic('x') + assert is_palindromic('xx') + assert is_palindromic('xyx') + assert not is_palindromic('xy') + assert not is_palindromic('xyzx') + assert is_palindromic('xxyzzyx', 1) + assert not is_palindromic('xxyzzyx', 2) + assert is_palindromic('xxyzzyx', 2, -1) + assert is_palindromic('xxyzzyx', 2, 6) + assert is_palindromic('xxyzyx', 1) + assert not is_palindromic('xxyzyx', 2) + assert is_palindromic('xxyzyx', 2, 2 + 3) + + +def test_flatten(): + assert flatten((1, (1,))) == [1, 1] + assert flatten((x, (x,))) == [x, x] + + ls = [[(-2, -1), (1, 2)], [(0, 0)]] + + assert flatten(ls, levels=0) == ls + assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)] + assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0] + assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0] + + raises(ValueError, lambda: flatten(ls, levels=-1)) + + class MyOp(Basic): + pass + + assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z] + assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z] + + assert flatten({1, 11, 2}) == list({1, 11, 2}) + + +def test_iproduct(): + assert list(iproduct()) == [()] + assert list(iproduct([])) == [] + assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)] + assert sorted(iproduct([1, 2], [3, 4, 5])) == [ + (1,3),(1,4),(1,5),(2,3),(2,4),(2,5)] + assert sorted(iproduct([0,1],[0,1],[0,1])) == [ + (0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)] + assert iterable(iproduct(S.Integers)) is True + assert iterable(iproduct(S.Integers, S.Integers)) is True + assert (3,) in iproduct(S.Integers) + assert (4, 5) in iproduct(S.Integers, S.Integers) + assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers) + triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000)) + for n1, n2, n3 in triples: + assert isinstance(n1, Integer) + assert isinstance(n2, Integer) + assert isinstance(n3, Integer) + for t in set(product(*([range(-2, 3)]*3))): + assert t in iproduct(S.Integers, S.Integers, S.Integers) + + +def test_group(): + assert group([]) == [] + assert group([], multiple=False) == [] + + assert group([1]) == [[1]] + assert group([1], multiple=False) == [(1, 1)] + + assert group([1, 1]) == [[1, 1]] + assert group([1, 1], multiple=False) == 
[(1, 2)] + + assert group([1, 1, 1]) == [[1, 1, 1]] + assert group([1, 1, 1], multiple=False) == [(1, 3)] + + assert group([1, 2, 1]) == [[1], [2], [1]] + assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)] + + assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]] + assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2), + (2, 3), (1, 1), (3, 2)] + + +def test_subsets(): + # combinations + assert list(subsets([1, 2, 3], 0)) == [()] + assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)] + assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)] + assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)] + l = list(range(4)) + assert list(subsets(l, 0, repetition=True)) == [()] + assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)] + assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2), + (0, 3), (1, 1), (1, 2), + (1, 3), (2, 2), (2, 3), + (3, 3)] + assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1), + (0, 0, 2), (0, 0, 3), + (0, 1, 1), (0, 1, 2), + (0, 1, 3), (0, 2, 2), + (0, 2, 3), (0, 3, 3), + (1, 1, 1), (1, 1, 2), + (1, 1, 3), (1, 2, 2), + (1, 2, 3), (1, 3, 3), + (2, 2, 2), (2, 2, 3), + (2, 3, 3), (3, 3, 3)] + assert len(list(subsets(l, 4, repetition=True))) == 35 + + assert list(subsets(l[:2], 3, repetition=False)) == [] + assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0), + (0, 0, 1), + (0, 1, 1), + (1, 1, 1)] + assert list(subsets([1, 2], repetition=True)) == \ + [(), (1,), (2,), (1, 1), (1, 2), (2, 2)] + assert list(subsets([1, 2], repetition=False)) == \ + [(), (1,), (2,), (1, 2)] + assert list(subsets([1, 2, 3], 2)) == \ + [(1, 2), (1, 3), (2, 3)] + assert list(subsets([1, 2, 3], 2, repetition=True)) == \ + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + + +def test_variations(): + # permutations + l = list(range(4)) + assert list(variations(l, 0, repetition=False)) == [()] + assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)] + assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)] + assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)] + assert list(variations(l, 0, repetition=True)) == [()] + assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)] + assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2), + (0, 3), (1, 0), (1, 1), + (1, 2), (1, 3), (2, 0), + (2, 1), (2, 2), (2, 3), + (3, 0), (3, 1), (3, 2), + (3, 3)] + assert len(list(variations(l, 3, repetition=True))) == 64 + assert len(list(variations(l, 4, repetition=True))) == 256 + assert list(variations(l[:2], 3, repetition=False)) == [] + assert list(variations(l[:2], 3, repetition=True)) == [ + (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1) + ] + + +def test_cartes(): + assert list(cartes([1, 2], [3, 4, 5])) == \ + [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)] + assert list(cartes()) == [()] + assert list(cartes('a')) == [('a',)] + assert list(cartes('a', repeat=2)) == [('a', 'a')] + assert list(cartes(list(range(2)))) == [(0,), (1,)] + + +def test_filter_symbols(): + s = numbered_symbols() + filtered = filter_symbols(s, symbols("x0 x2 x3")) + assert 
take(filtered, 3) == list(symbols("x1 x4 x5")) + + +def test_numbered_symbols(): + s = numbered_symbols(cls=Dummy) + assert isinstance(next(s), Dummy) + assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \ + symbols('C2') + + +def test_sift(): + assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]} + assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]} + assert sift([S.One], lambda _: _.has(x)) == {False: [1]} + assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == ( + [1, 3], [0, 2]) + assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == ( + [1], [0, 2, 3]) + raises(ValueError, lambda: + sift([0, 1, 2, 3], lambda x: x % 3, binary=True)) + + +def test_take(): + X = numbered_symbols() + + assert take(X, 5) == list(symbols('x0:5')) + assert take(X, 5) == list(symbols('x5:10')) + + assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5] + + +def test_dict_merge(): + assert dict_merge({}, {1: x, y: z}) == {1: x, y: z} + assert dict_merge({1: x, y: z}, {}) == {1: x, y: z} + + assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z} + assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z} + + assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z} + assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z} + + +def test_prefixes(): + assert list(prefixes([])) == [] + assert list(prefixes([1])) == [[1]] + assert list(prefixes([1, 2])) == [[1], [1, 2]] + + assert list(prefixes([1, 2, 3, 4, 5])) == \ + [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]] + + +def test_postfixes(): + assert list(postfixes([])) == [] + assert list(postfixes([1])) == [[1]] + assert list(postfixes([1, 2])) == [[2], [1, 2]] + + assert list(postfixes([1, 2, 3, 4, 5])) == \ + [[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]] + + +def test_topological_sort(): + V = [2, 3, 5, 7, 8, 9, 10, 11] + E = [(7, 11), (7, 8), (5, 11), + (3, 8), (3, 10), (11, 2), + (11, 9), (11, 10), (8, 9)] + + assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10] + assert topological_sort((V, E), key=lambda v: -v) == \ + [7, 5, 11, 3, 10, 8, 9, 2] + + raises(ValueError, lambda: topological_sort((V, E + [(10, 7)]))) + + +def test_strongly_connected_components(): + assert strongly_connected_components(([], [])) == [] + assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]] + + V = [1, 2, 3] + E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)] + assert strongly_connected_components((V, E)) == [[1, 2, 3]] + + V = [1, 2, 3, 4] + E = [(1, 2), (2, 3), (3, 2), (3, 4)] + assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]] + + V = [1, 2, 3, 4] + E = [(1, 2), (2, 1), (3, 4), (4, 3)] + assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]] + + +def test_connected_components(): + assert connected_components(([], [])) == [] + assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]] + + V = [1, 2, 3] + E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)] + assert connected_components((V, E)) == [[1, 2, 3]] + + V = [1, 2, 3, 4] + E = [(1, 2), (2, 3), (3, 2), (3, 4)] + assert connected_components((V, E)) == [[1, 2, 3, 4]] + + V = [1, 2, 3, 4] + E = [(1, 2), (3, 4)] + assert connected_components((V, E)) == [[1, 2], [3, 4]] + + +def test_rotate(): + A = [0, 1, 2, 3, 4] + + assert rotate_left(A, 2) == [2, 3, 4, 0, 1] + assert rotate_right(A, 1) == [4, 0, 1, 2, 3] + A = [] + B = rotate_right(A, 1) + assert B == [] + B.append(1) + assert A == [] + B = rotate_left(A, 1) + assert B == [] + B.append(1) + assert A == [] + + 
+def test_multiset_partitions(): + A = [0, 1, 2, 3, 4] + + assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]] + assert len(list(multiset_partitions(A, 4))) == 10 + assert len(list(multiset_partitions(A, 3))) == 25 + + assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [ + [[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]], + [[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]] + + assert list(multiset_partitions([1, 1, 2, 2], 2)) == [ + [[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]], + [[1, 2], [1, 2]]] + + assert list(multiset_partitions([1, 2, 3, 4], 2)) == [ + [[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]], + [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]], + [[1], [2, 3, 4]]] + + assert list(multiset_partitions([1, 2, 2], 2)) == [ + [[1, 2], [2]], [[1], [2, 2]]] + + assert list(multiset_partitions(3)) == [ + [[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]], + [[0], [1], [2]]] + assert list(multiset_partitions(3, 2)) == [ + [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]] + assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]] + assert list(multiset_partitions([1] * 3)) == [ + [[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]] + a = [3, 2, 1] + assert list(multiset_partitions(a)) == \ + list(multiset_partitions(sorted(a))) + assert list(multiset_partitions(a, 5)) == [] + assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]] + assert list(multiset_partitions(a + [4], 5)) == [] + assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]] + assert list(multiset_partitions(2, 5)) == [] + assert list(multiset_partitions(2, 1)) == [[[0, 1]]] + assert list(multiset_partitions('a')) == [[['a']]] + assert list(multiset_partitions('a', 2)) == [] + assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]] + assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]] + assert list(multiset_partitions('aaa', 1)) == [['aaa']] + assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]] + ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'), + ('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'), + ('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'), + ('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'), + ('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'), + ('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'), + ('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'), + ('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'), + ('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'), + ('m', 'py', 's', 'y'), ('m', 'p', 'syy'), + ('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'), + ('m', 'p', 's', 'y', 'y')] + assert [tuple("".join(part) for part in p) + for p in multiset_partitions('sympy')] == ans + factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3], + [6, 2, 2], [2, 2, 2, 3]] + assert [factoring_visitor(p, [2,3]) for + p in multiset_partitions_taocp([3, 1])] == factorings + + +def test_multiset_combinations(): + ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips', + 'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss'] + assert [''.join(i) for i in + list(multiset_combinations('mississippi', 3))] == ans + M = multiset('mississippi') + assert [''.join(i) for i in + list(multiset_combinations(M, 3))] == ans + assert [''.join(i) for i in multiset_combinations(M, 30)] == [] + assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]] + assert len(list(multiset_combinations('a', 3))) == 0 + assert len(list(multiset_combinations('a', 0))) == 1 + 
assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']] + raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2))) + + +def test_multiset_permutations(): + ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab', + 'byba', 'yabb', 'ybab', 'ybba'] + assert [''.join(i) for i in multiset_permutations('baby')] == ans + assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans + assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]] + assert list(multiset_permutations([0, 2, 1], 2)) == [ + [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]] + assert len(list(multiset_permutations('a', 0))) == 1 + assert len(list(multiset_permutations('a', 3))) == 0 + for nul in ([], {}, ''): + assert list(multiset_permutations(nul)) == [[]] + assert list(multiset_permutations(nul, 0)) == [[]] + # impossible requests give no result + assert list(multiset_permutations(nul, 1)) == [] + assert list(multiset_permutations(nul, -1)) == [] + + def test(): + for i in range(1, 7): + print(i) + for p in multiset_permutations([0, 0, 1, 0, 1], i): + print(p) + assert capture(lambda: test()) == dedent('''\ + 1 + [0] + [1] + 2 + [0, 0] + [0, 1] + [1, 0] + [1, 1] + 3 + [0, 0, 0] + [0, 0, 1] + [0, 1, 0] + [0, 1, 1] + [1, 0, 0] + [1, 0, 1] + [1, 1, 0] + 4 + [0, 0, 0, 1] + [0, 0, 1, 0] + [0, 0, 1, 1] + [0, 1, 0, 0] + [0, 1, 0, 1] + [0, 1, 1, 0] + [1, 0, 0, 0] + [1, 0, 0, 1] + [1, 0, 1, 0] + [1, 1, 0, 0] + 5 + [0, 0, 0, 1, 1] + [0, 0, 1, 0, 1] + [0, 0, 1, 1, 0] + [0, 1, 0, 0, 1] + [0, 1, 0, 1, 0] + [0, 1, 1, 0, 0] + [1, 0, 0, 0, 1] + [1, 0, 0, 1, 0] + [1, 0, 1, 0, 0] + [1, 1, 0, 0, 0] + 6\n''') + raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1}))) + + +def test_partitions(): + ans = [[{}], [(0, {})]] + for i in range(2): + assert list(partitions(0, size=i)) == ans[i] + assert list(partitions(1, 0, size=i)) == ans[i] + assert list(partitions(6, 2, 2, size=i)) == ans[i] + assert list(partitions(6, 2, None, size=i)) != ans[i] + assert list(partitions(6, None, 2, size=i)) != ans[i] + assert list(partitions(6, 2, 0, size=i)) == ans[i] + + assert list(partitions(6, k=2)) == [ + {2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}] + + assert list(partitions(6, k=3)) == [ + {3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2}, + {1: 4, 2: 1}, {1: 6}] + + assert list(partitions(8, k=4, m=3)) == [ + {4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [ + i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i) + and sum(i.values()) <=3] + + assert list(partitions(S(3), m=2)) == [ + {3: 1}, {1: 1, 2: 1}] + + assert list(partitions(4, k=3)) == [ + {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [ + i for i in partitions(4) if all(k <= 3 for k in i)] + + + # Consistency check on output of _partitions and RGS_unrank. + # This provides a sanity test on both routines. Also verifies that + # the total number of partitions is the same in each case. 
+ # (from pkrathmann2) + + for n in range(2, 6): + i = 0 + for m, q in _set_partitions(n): + assert q == RGS_unrank(i, n) + i += 1 + assert i == RGS_enum(n) + + +def test_binary_partitions(): + assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1], + [4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1], + [4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2], + [2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1], + [2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + + assert len([j[:] for j in binary_partitions(16)]) == 36 + + +def test_bell_perm(): + assert [len(set(generate_bell(i))) for i in range(1, 7)] == [ + factorial(i) for i in range(1, 7)] + assert list(generate_bell(3)) == [ + (0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)] + # generate_bell and trotterjohnson are advertised to return the same + # permutations; this is not technically necessary so this test could + # be removed + for n in range(1, 5): + p = Permutation(range(n)) + b = generate_bell(n) + for bi in b: + assert bi == tuple(p.array_form) + p = p.next_trotterjohnson() + raises(ValueError, lambda: list(generate_bell(0))) # XXX is this consistent with other permutation algorithms? + + +def test_involutions(): + lengths = [1, 2, 4, 10, 26, 76] + for n, N in enumerate(lengths): + i = list(generate_involutions(n + 1)) + assert len(i) == N + assert len({Permutation(j)**2 for j in i}) == 1 + + +def test_derangements(): + assert len(list(generate_derangements(list(range(6))))) == 265 + assert ''.join(''.join(i) for i in generate_derangements('abcde')) == ( + 'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd' + 'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab' + 'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb' + 'edbacedbca') + assert list(generate_derangements([0, 1, 2, 3])) == [ + [1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1], + [2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]] + assert list(generate_derangements([0, 1, 2, 2])) == [ + [2, 2, 0, 1], [2, 2, 1, 0]] + assert list(generate_derangements('ba')) == [list('ab')] + # multiset_derangements + D = multiset_derangements + assert list(D('abb')) == [] + assert [''.join(i) for i in D('ab')] == ['ba'] + assert [''.join(i) for i in D('abc')] == ['bca', 'cab'] + assert [''.join(i) for i in D('aabb')] == ['bbaa'] + assert [''.join(i) for i in D('aabbcccc')] == [ + 'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba', + 'ccccbbaa'] + assert [''.join(i) for i in D('aabbccc')] == [ + 'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab', + 'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa', + 'bcccaba', 'bcccaab'] + assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo', + 'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob', + 'oksob', 'osbok', 'obsok'] + assert list(generate_derangements([[3], [2], [2], [1]])) == [ + [[2], [1], [3], [2]], [[2], [3], [1], [2]]] + + +def test_necklaces(): + def count(n, k, f): + return len(list(necklaces(n, k, f))) + m = [] + for i in range(1, 8): + m.append(( + i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1))) + assert Matrix(m) == Matrix([ + [1, 2, 2, 3], + [2, 3, 3, 6], + [3, 4, 4, 10], + [4, 6, 6, 21], + [5, 8, 8, 39], + [6, 14, 13, 92], + [7, 20, 18, 198]]) + + +def test_bracelets(): + bc = list(bracelets(2, 4)) + assert Matrix(bc) == Matrix([ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3] + ]) + bc = 
list(bracelets(4, 2)) + assert Matrix(bc) == Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 1], + [0, 1, 0, 1], + [0, 1, 1, 1], + [1, 1, 1, 1] + ]) + + +def test_generate_oriented_forest(): + assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4], + [0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0], + [0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2], + [0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0], + [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0], + [0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]] + assert len(list(generate_oriented_forest(10))) == 1842 + + +def test_unflatten(): + r = list(range(10)) + assert unflatten(r) == list(zip(r[::2], r[1::2])) + assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])] + raises(ValueError, lambda: unflatten(list(range(10)), 3)) + raises(ValueError, lambda: unflatten(list(range(10)), -2)) + + +def test_common_prefix_suffix(): + assert common_prefix([], [1]) == [] + assert common_prefix(list(range(3))) == [0, 1, 2] + assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2] + assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2] + assert common_prefix([1, 2, 3], [1, 3, 5]) == [1] + + assert common_suffix([], [1]) == [] + assert common_suffix(list(range(3))) == [0, 1, 2] + assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2] + assert common_suffix(list(range(3)), list(range(4))) == [] + assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3] + assert common_suffix([1, 2, 3], [9, 7, 3]) == [3] + + +def test_minlex(): + assert minlex([1, 2, 0]) == (0, 1, 2) + assert minlex((1, 2, 0)) == (0, 1, 2) + assert minlex((1, 0, 2)) == (0, 2, 1) + assert minlex((1, 0, 2), directed=False) == (0, 1, 2) + assert minlex('aba') == 'aab' + assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa') + + +def test_ordered(): + assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]] + assert list(ordered((x, y), hash, default=False)) == \ + list(ordered((y, x), hash, default=False)) + assert list(ordered((x, y))) == [x, y] + + seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]], + (lambda x: len(x), lambda x: sum(x))] + assert list(ordered(seq, keys, default=False, warn=False)) == \ + [[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]] + raises(ValueError, lambda: + list(ordered(seq, keys, default=False, warn=True))) + + +def test_runs(): + assert runs([]) == [] + assert runs([1]) == [[1]] + assert runs([1, 1]) == [[1], [1]] + assert runs([1, 1, 2]) == [[1], [1, 2]] + assert runs([1, 2, 1]) == [[1, 2], [1]] + assert runs([2, 1, 1]) == [[2], [1], [1]] + from operator import lt + assert runs([2, 1, 1], lt) == [[2, 1], [1]] + + +def test_reshape(): + seq = list(range(1, 9)) + assert reshape(seq, [4]) == \ + [[1, 2, 3, 4], [5, 6, 7, 8]] + assert reshape(seq, (4,)) == \ + [(1, 2, 3, 4), (5, 6, 7, 8)] + assert reshape(seq, (2, 2)) == \ + [(1, 2, 3, 4), (5, 6, 7, 8)] + assert reshape(seq, (2, [2])) == \ + [(1, 2, [3, 4]), (5, 6, [7, 8])] + assert reshape(seq, ((2,), [2])) == \ + [((1, 2), [3, 4]), ((5, 6), [7, 8])] + assert reshape(seq, (1, [2], 1)) == \ + [(1, [2, 3], 4), (5, [6, 7], 8)] + assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \ + (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],)) + assert reshape(tuple(seq), ([1], 1, (2,))) == \ + (([1], 2, (3, 4)), ([5], 6, (7, 8))) + assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \ + [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]] + raises(ValueError, lambda: reshape([0, 1], [-1])) + 
raises(ValueError, lambda: reshape([0, 1], [3])) + + +def test_uniq(): + assert list(uniq(p for p in partitions(4))) == \ + [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] + assert list(uniq(x % 2 for x in range(5))) == [0, 1] + assert list(uniq('a')) == ['a'] + assert list(uniq('ababc')) == list('abc') + assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]] + assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \ + [([1], 2, 2), (2, [1], 2), (2, 2, [1])] + assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \ + [2, 3, 4, [2], [1], [3]] + f = [1] + raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)]) + f = [[1]] + raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)]) + + +def test_kbins(): + assert len(list(kbins('1123', 2, ordered=1))) == 24 + assert len(list(kbins('1123', 2, ordered=11))) == 36 + assert len(list(kbins('1123', 2, ordered=10))) == 10 + assert len(list(kbins('1123', 2, ordered=0))) == 5 + assert len(list(kbins('1123', 2, ordered=None))) == 3 + + def test1(): + for orderedval in [None, 0, 1, 10, 11]: + print('ordered =', orderedval) + for p in kbins([0, 0, 1], 2, ordered=orderedval): + print(' ', p) + assert capture(lambda : test1()) == dedent('''\ + ordered = None + [[0], [0, 1]] + [[0, 0], [1]] + ordered = 0 + [[0, 0], [1]] + [[0, 1], [0]] + ordered = 1 + [[0], [0, 1]] + [[0], [1, 0]] + [[1], [0, 0]] + ordered = 10 + [[0, 0], [1]] + [[1], [0, 0]] + [[0, 1], [0]] + [[0], [0, 1]] + ordered = 11 + [[0], [0, 1]] + [[0, 0], [1]] + [[0], [1, 0]] + [[0, 1], [0]] + [[1], [0, 0]] + [[1, 0], [0]]\n''') + + def test2(): + for orderedval in [None, 0, 1, 10, 11]: + print('ordered =', orderedval) + for p in kbins(list(range(3)), 2, ordered=orderedval): + print(' ', p) + assert capture(lambda : test2()) == dedent('''\ + ordered = None + [[0], [1, 2]] + [[0, 1], [2]] + ordered = 0 + [[0, 1], [2]] + [[0, 2], [1]] + [[0], [1, 2]] + ordered = 1 + [[0], [1, 2]] + [[0], [2, 1]] + [[1], [0, 2]] + [[1], [2, 0]] + [[2], [0, 1]] + [[2], [1, 0]] + ordered = 10 + [[0, 1], [2]] + [[2], [0, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[0], [1, 2]] + [[1, 2], [0]] + ordered = 11 + [[0], [1, 2]] + [[0, 1], [2]] + [[0], [2, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[1, 0], [2]] + [[1], [2, 0]] + [[1, 2], [0]] + [[2], [0, 1]] + [[2, 0], [1]] + [[2], [1, 0]] + [[2, 1], [0]]\n''') + + +def test_has_dups(): + assert has_dups(set()) is False + assert has_dups(list(range(3))) is False + assert has_dups([1, 2, 1]) is True + assert has_dups([[1], [1]]) is True + assert has_dups([[1], [2]]) is False + + +def test__partition(): + assert _partition('abcde', [1, 0, 1, 2, 0]) == [ + ['b', 'e'], ['a', 'c'], ['d']] + assert _partition('abcde', [1, 0, 1, 2, 0], 3) == [ + ['b', 'e'], ['a', 'c'], ['d']] + output = (3, [1, 0, 1, 2, 0]) + assert _partition('abcde', *output) == [['b', 'e'], ['a', 'c'], ['d']] + + +def test_ordered_partitions(): + from sympy.functions.combinatorial.numbers import nT + f = ordered_partitions + assert list(f(0, 1)) == [[]] + assert list(f(1, 0)) == [[]] + for i in range(1, 7): + for j in [None] + list(range(1, i)): + assert ( + sum(1 for p in f(i, j, 1)) == + sum(1 for p in f(i, j, 0)) == + nT(i, j)) + + +def test_rotations(): + assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']] + assert list(rotations(range(3))) == [[0, 1, 2], [1, 2, 0], [2, 0, 1]] + assert list(rotations(range(3), dir=-1)) == [[0, 1, 2], [2, 0, 1], [1, 2, 0]] + + +def test_ibin(): + assert ibin(3) == [1, 1] + assert ibin(3, 3) == [0, 1, 1] + assert ibin(3, str=True) == '11' + assert ibin(3, 3, 
str=True) == '011' + assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)] + assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11'] + raises(ValueError, lambda: ibin(-.5)) + raises(ValueError, lambda: ibin(2, 1)) + + +def test_iterable(): + assert iterable(0) is False + assert iterable(1) is False + assert iterable(None) is False + + class Test1(NotIterable): + pass + + assert iterable(Test1()) is False + + class Test2(NotIterable): + _iterable = True + + assert iterable(Test2()) is True + + class Test3: + pass + + assert iterable(Test3()) is False + + class Test4: + _iterable = True + + assert iterable(Test4()) is True + + class Test5: + def __iter__(self): + yield 1 + + assert iterable(Test5()) is True + + class Test6(Test5): + _iterable = False + + assert iterable(Test6()) is False + + +def test_sequence_partitions(): + assert list(sequence_partitions([1], 1)) == [[[1]]] + assert list(sequence_partitions([1, 2], 1)) == [[[1, 2]]] + assert list(sequence_partitions([1, 2], 2)) == [[[1], [2]]] + assert list(sequence_partitions([1, 2, 3], 1)) == [[[1, 2, 3]]] + assert list(sequence_partitions([1, 2, 3], 2)) == \ + [[[1], [2, 3]], [[1, 2], [3]]] + assert list(sequence_partitions([1, 2, 3], 3)) == [[[1], [2], [3]]] + + # Exceptional cases + assert list(sequence_partitions([], 0)) == [] + assert list(sequence_partitions([], 1)) == [] + assert list(sequence_partitions([1, 2], 0)) == [] + assert list(sequence_partitions([1, 2], 3)) == [] + + +def test_sequence_partitions_empty(): + assert list(sequence_partitions_empty([], 1)) == [[[]]] + assert list(sequence_partitions_empty([], 2)) == [[[], []]] + assert list(sequence_partitions_empty([], 3)) == [[[], [], []]] + assert list(sequence_partitions_empty([1], 1)) == [[[1]]] + assert list(sequence_partitions_empty([1], 2)) == [[[], [1]], [[1], []]] + assert list(sequence_partitions_empty([1], 3)) == \ + [[[], [], [1]], [[], [1], []], [[1], [], []]] + assert list(sequence_partitions_empty([1, 2], 1)) == [[[1, 2]]] + assert list(sequence_partitions_empty([1, 2], 2)) == \ + [[[], [1, 2]], [[1], [2]], [[1, 2], []]] + assert list(sequence_partitions_empty([1, 2], 3)) == [ + [[], [], [1, 2]], [[], [1], [2]], [[], [1, 2], []], + [[1], [], [2]], [[1], [2], []], [[1, 2], [], []] + ] + assert list(sequence_partitions_empty([1, 2, 3], 1)) == [[[1, 2, 3]]] + assert list(sequence_partitions_empty([1, 2, 3], 2)) == \ + [[[], [1, 2, 3]], [[1], [2, 3]], [[1, 2], [3]], [[1, 2, 3], []]] + assert list(sequence_partitions_empty([1, 2, 3], 3)) == [ + [[], [], [1, 2, 3]], [[], [1], [2, 3]], + [[], [1, 2], [3]], [[], [1, 2, 3], []], + [[1], [], [2, 3]], [[1], [2], [3]], + [[1], [2, 3], []], [[1, 2], [], [3]], + [[1, 2], [3], []], [[1, 2, 3], [], []] + ] + + # Exceptional cases + assert list(sequence_partitions([], 0)) == [] + assert list(sequence_partitions([1], 0)) == [] + assert list(sequence_partitions([1, 2], 0)) == [] + + +def test_signed_permutations(): + ans = [(0, 1, 1), (0, -1, 1), (0, 1, -1), (0, -1, -1), + (1, 0, 1), (-1, 0, 1), (1, 0, -1), (-1, 0, -1), + (1, 1, 0), (-1, 1, 0), (1, -1, 0), (-1, -1, 0)] + assert list(signed_permutations((0, 1, 1))) == ans + assert list(signed_permutations((1, 0, 1))) == ans + assert list(signed_permutations((1, 1, 0))) == ans diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_lambdify.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..b094c67d39f09c24bcb9abc1e755cb5328e143e7 --- /dev/null +++ 
b/phivenv/Lib/site-packages/sympy/utilities/tests/test_lambdify.py @@ -0,0 +1,2263 @@ +from itertools import product +import math +import inspect +import linecache +import gc + +import mpmath +import cmath + +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, Lambda, diff) +from sympy.core.numbers import (E, Float, I, Rational, all_close, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial) +from sympy.functions.combinatorial.numbers import bernoulli, harmonic +from sympy.functions.elementary.complexes import Abs, sign +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import asinh,acosh,atanh +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (asin, acos, atan, cos, cot, sin, + sinc, tan) +from sympy.functions import sinh,cosh,tanh +from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, jn, yn) +from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized) +from sympy.functions.special.delta_functions import (Heaviside) +from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels, Si, Ci) +from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, false, ITE, Not, Or, true) +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.simplify.cse_main import cse +from sympy.tensor.array import derive_by_array, Array +from sympy.tensor.array.expressions import ArraySymbol +from sympy.tensor.indexed import IndexedBase, Idx +from sympy.utilities.lambdify import lambdify +from sympy.utilities.iterables import numbered_symbols +from sympy.vector import CoordSys3D +from sympy.core.expr import UnevaluatedExpr +from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, log10, hypot, isnan, isinf +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, amin, amax, minimum, maximum +from sympy.codegen.scipy_nodes import cosm1, powm1 +from sympy.functions.elementary.complexes import re, im, arg +from sympy.functions.special.polynomials import \ + chebyshevt, chebyshevu, legendre, hermite, laguerre, gegenbauer, \ + assoc_legendre, assoc_laguerre, jacobi +from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.lambdarepr import LambdaPrinter +from sympy.printing.numpy import NumPyPrinter +from sympy.utilities.lambdify import implemented_function, lambdastr +from sympy.testing.pytest import skip +from sympy.utilities.decorator import conserve_mpmath_dps +from sympy.utilities.exceptions import ignore_warnings +from sympy.external import import_module +from sympy.functions.special.gamma_functions import uppergamma, lowergamma + + +import sympy + + +MutableDenseMatrix = Matrix + +numpy = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) +numexpr = import_module('numexpr') +tensorflow = 
import_module('tensorflow') +cupy = import_module('cupy') +jax = import_module('jax') +numba = import_module('numba') + +if tensorflow: + # Hide Tensorflow warnings + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +w, x, y, z = symbols('w,x,y,z') + +#================== Test different arguments ======================= + + +def test_no_args(): + f = lambdify([], 1) + raises(TypeError, lambda: f(-1)) + assert f() == 1 + + +def test_single_arg(): + f = lambdify(x, 2*x) + assert f(1) == 2 + + +def test_list_args(): + f = lambdify([x, y], x + y) + assert f(1, 2) == 3 + + +def test_nested_args(): + f1 = lambdify([[w, x]], [w, x]) + assert f1([91, 2]) == [91, 2] + raises(TypeError, lambda: f1(1, 2)) + + f2 = lambdify([(w, x), (y, z)], [w, x, y, z]) + assert f2((18, 12), (73, 4)) == [18, 12, 73, 4] + raises(TypeError, lambda: f2(3, 4)) + + f3 = lambdify([w, [[[x]], y], z], [w, x, y, z]) + assert f3(10, [[[52]], 31], 44) == [10, 52, 31, 44] + + +def test_str_args(): + f = lambdify('x,y,z', 'z,y,x') + assert f(3, 2, 1) == (1, 2, 3) + assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0) + # make sure correct number of args required + raises(TypeError, lambda: f(0)) + + +def test_own_namespace_1(): + myfunc = lambda x: 1 + f = lambdify(x, sin(x), {"sin": myfunc}) + assert f(0.1) == 1 + assert f(100) == 1 + + +def test_own_namespace_2(): + def myfunc(x): + return 1 + f = lambdify(x, sin(x), {'sin': myfunc}) + assert f(0.1) == 1 + assert f(100) == 1 + + +def test_own_module(): + f = lambdify(x, sin(x), math) + assert f(0) == 0.0 + + p, q, r = symbols("p q r", real=True) + ae = abs(exp(p+UnevaluatedExpr(q+r))) + f = lambdify([p, q, r], [ae, ae], modules=math) + results = f(1.0, 1e18, -1e18) + refvals = [math.exp(1.0)]*2 + for res, ref in zip(results, refvals): + assert abs((res-ref)/ref) < 1e-15 + + +def test_bad_args(): + # no vargs given + raises(TypeError, lambda: lambdify(1)) + # same with vector exprs + raises(TypeError, lambda: lambdify([1, 2])) + + +def test_atoms(): + # Non-Symbol atoms should not be pulled out from the expression namespace + f = lambdify(x, pi + x, {"pi": 3.14}) + assert f(0) == 3.14 + f = lambdify(x, I + x, {"I": 1j}) + assert f(1) == 1 + 1j + +#================== Test different modules ========================= + +# high precision output of sin(0.2*pi) is used to detect if precision is lost unwanted + + +@conserve_mpmath_dps +def test_sympy_lambda(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin(x), "sympy") + assert f(x) == sin(x) + prec = 1e-15 + assert -prec < f(Rational(1, 5)).evalf() - Float(str(sin02)) < prec + # arctan is in numpy module and should not be available + # The arctan below gives NameError. What is this supposed to test? 
+ # raises(NameError, lambda: lambdify(x, arctan(x), "sympy")) + + +@conserve_mpmath_dps +def test_math_lambda(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin(x), "math") + prec = 1e-15 + assert -prec < f(0.2) - sin02 < prec + raises(TypeError, lambda: f(x)) + # if this succeeds, it can't be a Python math function + + +@conserve_mpmath_dps +def test_mpmath_lambda(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin(x), "mpmath") + prec = 1e-49 # mpmath precision is around 50 decimal places + assert -prec < f(mpmath.mpf("0.2")) - sin02 < prec + raises(TypeError, lambda: f(x)) + # if this succeeds, it can't be a mpmath function + + ref2 = (mpmath.mpf("1e-30") + - mpmath.mpf("1e-45")/2 + + 5*mpmath.mpf("1e-60")/6 + - 3*mpmath.mpf("1e-75")/4 + + 33*mpmath.mpf("1e-90")/40 + ) + f2a = lambdify((x, y), x**y - 1, "mpmath") + f2b = lambdify((x, y), powm1(x, y), "mpmath") + f2c = lambdify((x,), expm1(x*log1p(x)), "mpmath") + ans2a = f2a(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15")) + ans2b = f2b(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15")) + ans2c = f2c(mpmath.mpf("1e-15")) + assert abs(ans2a - ref2) < 1e-51 + assert abs(ans2b - ref2) < 1e-67 + assert abs(ans2c - ref2) < 1e-80 + + +@conserve_mpmath_dps +def test_number_precision(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin02, "mpmath") + prec = 1e-49 # mpmath precision is around 50 decimal places + assert -prec < f(0) - sin02 < prec + +@conserve_mpmath_dps +def test_mpmath_precision(): + mpmath.mp.dps = 100 + assert str(lambdify((), pi.evalf(100), 'mpmath')()) == str(pi.evalf(100)) + +#================== Test Translations ============================== +# We can only check if all translated functions are valid. It has to be checked +# by hand if they are complete. 
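+# Illustrative sketch of what these tables contain (the specific entry here is
+# an assumption for illustration, not something the tests assert): a mapping
+# such as 'ln' -> 'log' counts as "valid" in the sense checked below because
+# both names resolve, i.e. 'ln' in sympy.__dict__ and 'log' in math.__dict__
+# hold; whether every SymPy function that should be translated actually has an
+# entry is the completeness question that still has to be verified by hand.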
+ + +def test_math_transl(): + from sympy.utilities.lambdify import MATH_TRANSLATIONS + for sym, mat in MATH_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert mat in math.__dict__ + + +def test_mpmath_transl(): + from sympy.utilities.lambdify import MPMATH_TRANSLATIONS + for sym, mat in MPMATH_TRANSLATIONS.items(): + assert sym in sympy.__dict__ or sym == 'Matrix' + assert mat in mpmath.__dict__ + + +def test_numpy_transl(): + if not numpy: + skip("numpy not installed.") + + from sympy.utilities.lambdify import NUMPY_TRANSLATIONS + for sym, nump in NUMPY_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert nump in numpy.__dict__ + + +def test_scipy_transl(): + if not scipy: + skip("scipy not installed.") + + from sympy.utilities.lambdify import SCIPY_TRANSLATIONS + for sym, scip in SCIPY_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert scip in scipy.__dict__ or scip in scipy.special.__dict__ + + +def test_numpy_translation_abs(): + if not numpy: + skip("numpy not installed.") + + f = lambdify(x, Abs(x), "numpy") + assert f(-1) == 1 + assert f(1) == 1 + + +def test_numexpr_printer(): + if not numexpr: + skip("numexpr not installed.") + + # if translation/printing is done incorrectly then evaluating + # a lambdified numexpr expression will throw an exception + from sympy.printing.lambdarepr import NumExprPrinter + + blacklist = ('where', 'complex', 'contains') + arg_tuple = (x, y, z) # some functions take more than one argument + for sym in NumExprPrinter._numexpr_functions.keys(): + if sym in blacklist: + continue + ssym = S(sym) + if hasattr(ssym, '_nargs'): + nargs = ssym._nargs[0] + else: + nargs = 1 + args = arg_tuple[:nargs] + f = lambdify(args, ssym(*args), modules='numexpr') + assert f(*(1, )*nargs) is not None + + +def test_cmath_sqrt(): + f = lambdify(x, sqrt(x), "cmath") + assert f(0) == 0 + assert f(1) == 1 + assert f(4) == 2 + assert abs(f(2) - 1.414) < 0.001 + assert f(-1) == 1j + assert f(-4) == 2j + + +def test_cmath_log(): + f = lambdify(x, log(x), "cmath") + assert abs(f(1) - 0) < 1e-15 + assert abs(f(cmath.e) - 1) < 1e-15 + assert abs(f(-1) - cmath.log(-1)) < 1e-15 + + +def test_cmath_sinh(): + f = lambdify(x, sinh(x), "cmath") + assert abs(f(0) - cmath.sinh(0)) < 1e-15 + assert abs(f(pi) - cmath.sinh(pi)) < 1e-15 + assert abs(f(-pi) - cmath.sinh(-pi)) < 1e-15 + assert abs(f(1j) - cmath.sinh(1j)) < 1e-15 + + +def test_cmath_cosh(): + f = lambdify(x, cosh(x), "cmath") + assert abs(f(0) - cmath.cosh(0)) < 1e-15 + assert abs(f(pi) - cmath.cosh(pi)) < 1e-15 + assert abs(f(-pi) - cmath.cosh(-pi)) < 1e-15 + assert abs(f(1j) - cmath.cosh(1j)) < 1e-15 + + +def test_cmath_tanh(): + f = lambdify(x, tanh(x), "cmath") + assert abs(f(0) - cmath.tanh(0)) < 1e-15 + assert abs(f(pi) - cmath.tanh(pi)) < 1e-15 + assert abs(f(-pi) - cmath.tanh(-pi)) < 1e-15 + assert abs(f(1j) - cmath.tanh(1j)) < 1e-15 + + +def test_cmath_sin(): + f = lambdify(x, sin(x), "cmath") + assert abs(f(0) - cmath.sin(0)) < 1e-15 + assert abs(f(pi) - cmath.sin(pi)) < 1e-15 + assert abs(f(-pi) - cmath.sin(-pi)) < 1e-15 + assert abs(f(1j) - cmath.sin(1j)) < 1e-15 + + +def test_cmath_cos(): + f = lambdify(x, cos(x), "cmath") + assert abs(f(0) - cmath.cos(0)) < 1e-15 + assert abs(f(pi) - cmath.cos(pi)) < 1e-15 + assert abs(f(-pi) - cmath.cos(-pi)) < 1e-15 + assert abs(f(1j) - cmath.cos(1j)) < 1e-15 + + +def test_cmath_tan(): + f = lambdify(x, tan(x), "cmath") + assert abs(f(0) - cmath.tan(0)) < 1e-15 + assert abs(f(1j) - cmath.tan(1j)) < 1e-15 + + +def test_cmath_asin(): + f = lambdify(x, 
asin(x), "cmath") + assert abs(f(0) - cmath.asin(0)) < 1e-15 + assert abs(f(1) - cmath.asin(1)) < 1e-15 + assert abs(f(-1) - cmath.asin(-1)) < 1e-15 + assert abs(f(2) - cmath.asin(2)) < 1e-15 + assert abs(f(1j) - cmath.asin(1j)) < 1e-15 + + +def test_cmath_acos(): + f = lambdify(x, acos(x), "cmath") + assert abs(f(1) - cmath.acos(1)) < 1e-15 + assert abs(f(-1) - cmath.acos(-1)) < 1e-15 + assert abs(f(2) - cmath.acos(2)) < 1e-15 + assert abs(f(1j) - cmath.acos(1j)) < 1e-15 + + +def test_cmath_atan(): + f = lambdify(x, atan(x), "cmath") + assert abs(f(0) - cmath.atan(0)) < 1e-15 + assert abs(f(1) - cmath.atan(1)) < 1e-15 + assert abs(f(-1) - cmath.atan(-1)) < 1e-15 + assert abs(f(2) - cmath.atan(2)) < 1e-15 + assert abs(f(2j) - cmath.atan(2j)) < 1e-15 + + +def test_cmath_asinh(): + f = lambdify(x, asinh(x), "cmath") + assert abs(f(0) - cmath.asinh(0)) < 1e-15 + assert abs(f(1) - cmath.asinh(1)) < 1e-15 + assert abs(f(-1) - cmath.asinh(-1)) < 1e-15 + assert abs(f(2) - cmath.asinh(2)) < 1e-15 + assert abs(f(2j) - cmath.asinh(2j)) < 1e-15 + + +def test_cmath_acosh(): + f = lambdify(x, acosh(x), "cmath") + assert abs(f(1) - cmath.acosh(1)) < 1e-15 + assert abs(f(2) - cmath.acosh(2)) < 1e-15 + assert abs(f(-1) - cmath.acosh(-1)) < 1e-15 + assert abs(f(2j) - cmath.acosh(2j)) < 1e-15 + + +def test_cmath_atanh(): + f = lambdify(x, atanh(x), "cmath") + assert abs(f(0) - cmath.atanh(0)) < 1e-15 + assert abs(f(0.5) - cmath.atanh(0.5)) < 1e-15 + assert abs(f(-0.5) - cmath.atanh(-0.5)) < 1e-15 + assert abs(f(2) - cmath.atanh(2)) < 1e-15 + assert abs(f(-2) - cmath.atanh(-2)) < 1e-15 + assert abs(f(2j) - cmath.atanh(2j)) < 1e-15 + + +def test_cmath_complex_identities(): + # Define symbol + z = symbols('z') + + # Trigonometric identity using re(z) and im(z) + expr = cos(z) - cos(re(z)) * cosh(im(z)) + I * sin(re(z)) * sinh(im(z)) + func = lambdify([z], expr, modules=["cmath", "math"]) + hpi = math.pi / 2 + assert abs(func(hpi + 1j * hpi)) < 4e-16 + + # Euler's Formula: e^(i*z) = cos(z) + i*sin(z) + func = lambdify([z], exp(I * z) - (cos(z) + I * sin(z)), modules=["cmath", "math"]) + assert abs(func(hpi)) < 4e-16 + + # Exponential Identity: e^z = e^(Re(z)) * (cos(Im(z)) + i*sin(Im(z))) + func_exp = lambdify([z], exp(z) - exp(re(z)) * (cos(im(z)) + I * sin(im(z))), + modules=["cmath", "math"]) + assert abs(func_exp(hpi + 1j * hpi)) < 4e-16 + + # Complex Cosine Identity: cos(z) = cos(Re(z)) * cosh(Im(z)) - i*sin(Re(z)) * sinh(Im(z)) + func_cos = lambdify([z], cos(z) - (cos(re(z)) * cosh(im(z)) - I * sin(re(z)) * sinh(im(z))), + modules=["cmath", "math"]) + assert abs(func_cos(hpi + 1j * hpi)) < 4e-16 + + # Complex Sine Identity: sin(z) = sin(Re(z)) * cosh(Im(z)) + i*cos(Re(z)) * sinh(Im(z)) + func_sin = lambdify([z], sin(z) - (sin(re(z)) * cosh(im(z)) + I * cos(re(z)) * sinh(im(z))), + modules=["cmath", "math"]) + assert abs(func_sin(hpi + 1j * hpi)) < 4e-16 + + # Complex Hyperbolic Cosine Identity: cosh(z) = cosh(Re(z)) * cos(Im(z)) + i*sinh(Re(z)) * sin(Im(z)) + func_cosh_1 = lambdify([z], cosh(z) - (cosh(re(z)) * cos(im(z)) + I * sinh(re(z)) * sin(im(z))), + modules=["cmath", "math"]) + assert abs(func_cosh_1(hpi + 1j * hpi)) < 4e-16 + + # Complex Hyperbolic Sine Identity: sinh(z) = sinh(Re(z)) * cos(Im(z)) + i*cosh(Re(z)) * sin(Im(z)) + func_sinh = lambdify([z], sinh(z) - (sinh(re(z)) * cos(im(z)) + I * cosh(re(z)) * sin(im(z))), + modules=["cmath", "math"]) + assert abs(func_sinh(hpi + 1j * hpi)) < 4e-16 + + # cosh(z) = (e^z + e^(-z)) / 2 + func_cosh_2 = lambdify([z], cosh(z) - (exp(z) + exp(-z)) / 2, 
modules=["cmath", "math"]) + assert abs(func_cosh_2(hpi)) < 4e-16 + + # Additional expressions testing log and exp with real and imaginary parts + expr1 = log(re(z)) + log(im(z)) - log(re(z) * im(z)) + expr2 = exp(re(z)) * exp(im(z) * I) - exp(z) + expr3 = log(exp(re(z))) - re(z) + expr4 = exp(log(re(z))) - re(z) + expr5 = log(exp(re(z) + im(z))) - (re(z) + im(z)) + expr6 = exp(log(re(z) + im(z))) - (re(z) + im(z)) + func1 = lambdify([z], expr1, modules=["cmath", "math"]) + func2 = lambdify([z], expr2, modules=["cmath", "math"]) + func3 = lambdify([z], expr3, modules=["cmath", "math"]) + func4 = lambdify([z], expr4, modules=["cmath", "math"]) + func5 = lambdify([z], expr5, modules=["cmath", "math"]) + func6 = lambdify([z], expr6, modules=["cmath", "math"]) + test_value = 3 + 4j + assert abs(func1(test_value)) < 4e-16 + assert abs(func2(test_value)) < 4e-16 + assert abs(func3(test_value)) < 4e-16 + assert abs(func4(test_value)) < 4e-16 + assert abs(func5(test_value)) < 4e-16 + assert abs(func6(test_value)) < 4e-16 + + +def test_issue_9334(): + if not numexpr: + skip("numexpr not installed.") + if not numpy: + skip("numpy not installed.") + expr = S('b*a - sqrt(a**2)') + a, b = sorted(expr.free_symbols, key=lambda s: s.name) + func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False) + foo, bar = numpy.random.random((2, 4)) + func_numexpr(foo, bar) + + +def test_issue_12984(): + if not numexpr: + skip("numexpr not installed.") + func_numexpr = lambdify((x,y,z), Piecewise((y, x >= 0), (z, x > -1)), numexpr) + with ignore_warnings(RuntimeWarning): + assert func_numexpr(1, 24, 42) == 24 + assert str(func_numexpr(-1, 24, 42)) == 'nan' + + +def test_empty_modules(): + x, y = symbols('x y') + expr = -(x % y) + + no_modules = lambdify([x, y], expr) + empty_modules = lambdify([x, y], expr, modules=[]) + assert no_modules(3, 7) == empty_modules(3, 7) + assert no_modules(3, 7) == -3 + + +def test_exponentiation(): + f = lambdify(x, x**2) + assert f(-1) == 1 + assert f(0) == 0 + assert f(1) == 1 + assert f(-2) == 4 + assert f(2) == 4 + assert f(2.5) == 6.25 + + +def test_sqrt(): + f = lambdify(x, sqrt(x)) + assert f(0) == 0.0 + assert f(1) == 1.0 + assert f(4) == 2.0 + assert abs(f(2) - 1.414) < 0.001 + assert f(6.25) == 2.5 + + +def test_trig(): + f = lambdify([x], [cos(x), sin(x)], 'math') + d = f(pi) + prec = 1e-11 + assert -prec < d[0] + 1 < prec + assert -prec < d[1] < prec + d = f(3.14159) + prec = 1e-5 + assert -prec < d[0] + 1 < prec + assert -prec < d[1] < prec + + +def test_integral(): + if numpy and not scipy: + skip("scipy not installed.") + f = Lambda(x, exp(-x**2)) + l = lambdify(y, Integral(f(x), (x, y, oo))) + d = l(-oo) + assert 1.77245385 < d < 1.772453851 + + +def test_double_integral(): + if numpy and not scipy: + skip("scipy not installed.") + # example from http://mpmath.org/doc/current/calculus/integration.html + i = Integral(1/(1 - x**2*y**2), (x, 0, 1), (y, 0, z)) + l = lambdify([z], i) + d = l(1) + assert 1.23370055 < d < 1.233700551 + +def test_spherical_bessel(): + if numpy and not scipy: + skip("scipy not installed.") + test_point = 4.2 #randomly selected + x = symbols("x") + jtest = jn(2, x) + assert abs(lambdify(x,jtest)(test_point) - + jtest.subs(x,test_point).evalf()) < 1e-8 + ytest = yn(2, x) + assert abs(lambdify(x,ytest)(test_point) - + ytest.subs(x,test_point).evalf()) < 1e-8 + + +#================== Test vectors =================================== + + +def test_vector_simple(): + f = lambdify((x, y, z), (z, y, x)) + assert f(3, 2, 1) == (1, 2, 3) + 
assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0) + # make sure correct number of args required + raises(TypeError, lambda: f(0)) + + +def test_vector_discontinuous(): + f = lambdify(x, (-1/x, 1/x)) + raises(ZeroDivisionError, lambda: f(0)) + assert f(1) == (-1.0, 1.0) + assert f(2) == (-0.5, 0.5) + assert f(-2) == (0.5, -0.5) + + +def test_trig_symbolic(): + f = lambdify([x], [cos(x), sin(x)], 'math') + d = f(pi) + assert abs(d[0] + 1) < 0.0001 + assert abs(d[1] - 0) < 0.0001 + + +def test_trig_float(): + f = lambdify([x], [cos(x), sin(x)]) + d = f(3.14159) + assert abs(d[0] + 1) < 0.0001 + assert abs(d[1] - 0) < 0.0001 + + +def test_docs(): + f = lambdify(x, x**2) + assert f(2) == 4 + f = lambdify([x, y, z], [z, y, x]) + assert f(1, 2, 3) == [3, 2, 1] + f = lambdify(x, sqrt(x)) + assert f(4) == 2.0 + f = lambdify((x, y), sin(x*y)**2) + assert f(0, 5) == 0 + + +def test_math(): + f = lambdify((x, y), sin(x), modules="math") + assert f(0, 5) == 0 + + +def test_sin(): + f = lambdify(x, sin(x)**2) + assert isinstance(f(2), float) + f = lambdify(x, sin(x)**2, modules="math") + assert isinstance(f(2), float) + + +def test_matrix(): + A = Matrix([[x, x*y], [sin(z) + 4, x**z]]) + sol = Matrix([[1, 2], [sin(3) + 4, 1]]) + f = lambdify((x, y, z), A, modules="sympy") + assert f(1, 2, 3) == sol + f = lambdify((x, y, z), (A, [A]), modules="sympy") + assert f(1, 2, 3) == (sol, [sol]) + J = Matrix((x, x + y)).jacobian((x, y)) + v = Matrix((x, y)) + sol = Matrix([[1, 0], [1, 1]]) + assert lambdify(v, J, modules='sympy')(1, 2) == sol + assert lambdify(v.T, J, modules='sympy')(1, 2) == sol + + +def test_numpy_matrix(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[x, x*y], [sin(z) + 4, x**z]]) + sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]]) + #Lambdify array first, to ensure return to array as default + f = lambdify((x, y, z), A, ['numpy']) + numpy.testing.assert_allclose(f(1, 2, 3), sol_arr) + #Check that the types are arrays and matrices + assert isinstance(f(1, 2, 3), numpy.ndarray) + + # gh-15071 + class dot(Function): + pass + x_dot_mtx = dot(x, Matrix([[2], [1], [0]])) + f_dot1 = lambdify(x, x_dot_mtx) + inp = numpy.zeros((17, 3)) + assert numpy.all(f_dot1(inp) == 0) + + strict_kw = {"allow_unknown_functions": False, "inline": True, "fully_qualified_modules": False} + p2 = NumPyPrinter(dict(user_functions={'dot': 'dot'}, **strict_kw)) + f_dot2 = lambdify(x, x_dot_mtx, printer=p2) + assert numpy.all(f_dot2(inp) == 0) + + p3 = NumPyPrinter(strict_kw) + # The line below should probably fail upon construction (before calling with "(inp)"): + raises(Exception, lambda: lambdify(x, x_dot_mtx, printer=p3)(inp)) + + +def test_numpy_transpose(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[1, x], [0, 1]]) + f = lambdify((x), A.T, modules="numpy") + numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]])) + + +def test_numpy_dotproduct(): + if not numpy: + skip("numpy not installed") + A = Matrix([x, y, z]) + f1 = lambdify([x, y, z], DotProduct(A, A), modules='numpy') + f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy') + f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='numpy') + f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy') + + assert f1(1, 2, 3) == \ + f2(1, 2, 3) == \ + f3(1, 2, 3) == \ + f4(1, 2, 3) == \ + numpy.array([14]) + + +def test_numpy_inverse(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[1, x], [0, 1]]) + f = lambdify((x), A**-1, modules="numpy") + numpy.testing.assert_array_equal(f(2), 
numpy.array([[1, -2], [0, 1]])) + + +def test_numpy_old_matrix(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[x, x*y], [sin(z) + 4, x**z]]) + sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]]) + f = lambdify((x, y, z), A, [{'ImmutableDenseMatrix': numpy.matrix}, 'numpy']) + with ignore_warnings(PendingDeprecationWarning): + numpy.testing.assert_allclose(f(1, 2, 3), sol_arr) + assert isinstance(f(1, 2, 3), numpy.matrix) + + +def test_scipy_sparse_matrix(): + if not scipy: + skip("scipy not installed.") + A = SparseMatrix([[x, 0], [0, y]]) + f = lambdify((x, y), A, modules="scipy") + B = f(1, 2) + assert isinstance(B, scipy.sparse.coo_matrix) + + +def test_python_div_zero_issue_11306(): + if not numpy: + skip("numpy not installed.") + p = Piecewise((1 / x, y < -1), (x, y < 1), (1 / x, True)) + f = lambdify([x, y], p, modules='numpy') + with numpy.errstate(divide='ignore'): + assert float(f(numpy.array(0), numpy.array(0.5))) == 0 + assert float(f(numpy.array(0), numpy.array(1))) == float('inf') + + +def test_issue9474(): + mods = [None, 'math'] + if numpy: + mods.append('numpy') + if mpmath: + mods.append('mpmath') + for mod in mods: + f = lambdify(x, S.One/x, modules=mod) + assert f(2) == 0.5 + f = lambdify(x, floor(S.One/x), modules=mod) + assert f(2) == 0 + + for absfunc, modules in product([Abs, abs], mods): + f = lambdify(x, absfunc(x), modules=modules) + assert f(-1) == 1 + assert f(1) == 1 + assert f(3+4j) == 5 + + +def test_issue_9871(): + if not numexpr: + skip("numexpr not installed.") + if not numpy: + skip("numpy not installed.") + + r = sqrt(x**2 + y**2) + expr = diff(1/r, x) + + xn = yn = numpy.linspace(1, 10, 16) + # expr(xn, xn) = -xn/(sqrt(2)*xn)^3 + fv_exact = -numpy.sqrt(2.)**-3 * xn**-2 + + fv_numpy = lambdify((x, y), expr, modules='numpy')(xn, yn) + fv_numexpr = lambdify((x, y), expr, modules='numexpr')(xn, yn) + numpy.testing.assert_allclose(fv_numpy, fv_exact, rtol=1e-10) + numpy.testing.assert_allclose(fv_numexpr, fv_exact, rtol=1e-10) + + +def test_numpy_piecewise(): + if not numpy: + skip("numpy not installed.") + pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True)) + f = lambdify(x, pieces, modules="numpy") + numpy.testing.assert_array_equal(f(numpy.arange(10)), + numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81])) + # If we evaluate somewhere all conditions are False, we should get back NaN + nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0))) + numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])), + numpy.array([1, numpy.nan, 1])) + + +def test_numpy_logical_ops(): + if not numpy: + skip("numpy not installed.") + and_func = lambdify((x, y), And(x, y), modules="numpy") + and_func_3 = lambdify((x, y, z), And(x, y, z), modules="numpy") + or_func = lambdify((x, y), Or(x, y), modules="numpy") + or_func_3 = lambdify((x, y, z), Or(x, y, z), modules="numpy") + not_func = lambdify((x), Not(x), modules="numpy") + arr1 = numpy.array([True, True]) + arr2 = numpy.array([False, True]) + arr3 = numpy.array([True, False]) + numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True])) + numpy.testing.assert_array_equal(and_func_3(arr1, arr2, arr3), numpy.array([False, False])) + numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True])) + numpy.testing.assert_array_equal(or_func_3(arr1, arr2, arr3), numpy.array([True, True])) + numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False])) + + +def test_numpy_matmul(): + if not numpy: + skip("numpy not installed.") + xmat = 
Matrix([[x, y], [z, 1+z]]) + ymat = Matrix([[x**2], [Abs(x)]]) + mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy") + numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]])) + numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]])) + # Multiple matrices chained together in multiplication + f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy") + numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25], + [159, 251]])) + + +def test_numpy_numexpr(): + if not numpy: + skip("numpy not installed.") + if not numexpr: + skip("numexpr not installed.") + a, b, c = numpy.random.randn(3, 128, 128) + # ensure that numpy and numexpr return same value for complicated expression + expr = sin(x) + cos(y) + tan(z)**2 + Abs(z-y)*acos(sin(y*z)) + \ + Abs(y-z)*acosh(2+exp(y-x))- sqrt(x**2+I*y**2) + npfunc = lambdify((x, y, z), expr, modules='numpy') + nefunc = lambdify((x, y, z), expr, modules='numexpr') + assert numpy.allclose(npfunc(a, b, c), nefunc(a, b, c)) + + +def test_numexpr_userfunctions(): + if not numpy: + skip("numpy not installed.") + if not numexpr: + skip("numexpr not installed.") + a, b = numpy.random.randn(2, 10) + uf = type('uf', (Function, ), + {'eval' : classmethod(lambda x, y : y**2+1)}) + func = lambdify(x, 1-uf(x), modules='numexpr') + assert numpy.allclose(func(a), -(a**2)) + + uf = implemented_function(Function('uf'), lambda x, y : 2*x*y+1) + func = lambdify((x, y), uf(x, y), modules='numexpr') + assert numpy.allclose(func(a, b), 2*a*b+1) + + +def test_tensorflow_basic_math(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(sin(x), Abs(1/(x+2))) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + a = tensorflow.constant(0, dtype=tensorflow.float32) + assert func(a).eval(session=s) == 0.5 + + +def test_tensorflow_placeholders(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(sin(x), Abs(1/(x+2))) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + a = tensorflow.compat.v1.placeholder(dtype=tensorflow.float32) + assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5 + + +def test_tensorflow_variables(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(sin(x), Abs(1/(x+2))) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + a = tensorflow.Variable(0, dtype=tensorflow.float32) + s.run(a.initializer) + assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5 + + +def test_tensorflow_logical_operations(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Not(And(Or(x, y), y)) + func = lambdify([x, y], expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(False, True).eval(session=s) == False + + +def test_tensorflow_piecewise(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Piecewise((0, Eq(x,0)), (-1, x < 0), (1, x > 0)) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(-1).eval(session=s) == -1 + assert func(0).eval(session=s) == 0 + assert func(1).eval(session=s) == 1 + + +def test_tensorflow_multi_max(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(x, -x, x**2) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(-2).eval(session=s) == 4 + + +def test_tensorflow_multi_min(): + if not 
tensorflow: + skip("tensorflow not installed.") + expr = Min(x, -x, x**2) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(-2).eval(session=s) == -2 + + +def test_tensorflow_relational(): + if not tensorflow: + skip("tensorflow not installed.") + expr = x >= 0 + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(1).eval(session=s) == True + + +def test_tensorflow_complexes(): + if not tensorflow: + skip("tensorflow not installed") + + func1 = lambdify(x, re(x), modules="tensorflow") + func2 = lambdify(x, im(x), modules="tensorflow") + func3 = lambdify(x, Abs(x), modules="tensorflow") + func4 = lambdify(x, arg(x), modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + # For versions before + # https://github.com/tensorflow/tensorflow/issues/30029 + # resolved, using Python numeric types may not work + a = tensorflow.constant(1+2j) + assert func1(a).eval(session=s) == 1 + assert func2(a).eval(session=s) == 2 + + tensorflow_result = func3(a).eval(session=s) + sympy_result = Abs(1 + 2j).evalf() + assert abs(tensorflow_result-sympy_result) < 10**-6 + + tensorflow_result = func4(a).eval(session=s) + sympy_result = arg(1 + 2j).evalf() + assert abs(tensorflow_result-sympy_result) < 10**-6 + + +def test_tensorflow_array_arg(): + # Test for issue 14655 (tensorflow part) + if not tensorflow: + skip("tensorflow not installed.") + + f = lambdify([[x, y]], x*x + y, 'tensorflow') + + with tensorflow.compat.v1.Session() as s: + fcall = f(tensorflow.constant([2.0, 1.0])) + assert fcall.eval(session=s) == 5.0 + + +#================== Test symbolic ================================== + + +def test_sym_single_arg(): + f = lambdify(x, x * y) + assert f(z) == z * y + + +def test_sym_list_args(): + f = lambdify([x, y], x + y + z) + assert f(1, 2) == 3 + z + + +def test_sym_integral(): + f = Lambda(x, exp(-x**2)) + l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules="sympy") + assert l(y) == Integral(exp(-y**2), (y, -oo, oo)) + assert l(y).doit() == sqrt(pi) + + +def test_namespace_order(): + # lambdify had a bug, such that module dictionaries or cached module + # dictionaries would pull earlier namespaces into themselves. + # Because the module dictionaries form the namespace of the + # generated lambda, this meant that the behavior of a previously + # generated lambda function could change as a result of later calls + # to lambdify. 
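+    # Concretely, for the check below: if the cached "sympy" module dictionary
+    # had absorbed n2's 'f' while building if2, the lambda generated earlier
+    # for if1 (whose namespace is that shared dictionary) would start
+    # returning 'second f' instead of 'first f'.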
+ n1 = {'f': lambda x: 'first f'} + n2 = {'f': lambda x: 'second f', + 'g': lambda x: 'function g'} + f = sympy.Function('f') + g = sympy.Function('g') + if1 = lambdify(x, f(x), modules=(n1, "sympy")) + assert if1(1) == 'first f' + if2 = lambdify(x, g(x), modules=(n2, "sympy")) + # previously gave 'second f' + assert if1(1) == 'first f' + + assert if2(1) == 'function g' + + +def test_imps(): + # Here we check if the default returned functions are anonymous - in + # the sense that we can have more than one function with the same name + f = implemented_function('f', lambda x: 2*x) + g = implemented_function('f', lambda x: math.sqrt(x)) + l1 = lambdify(x, f(x)) + l2 = lambdify(x, g(x)) + assert str(f(x)) == str(g(x)) + assert l1(3) == 6 + assert l2(3) == math.sqrt(3) + # check that we can pass in a Function as input + func = sympy.Function('myfunc') + assert not hasattr(func, '_imp_') + my_f = implemented_function(func, lambda x: 2*x) + assert hasattr(my_f, '_imp_') + # Error for functions with same name and different implementation + f2 = implemented_function("f", lambda x: x + 101) + raises(ValueError, lambda: lambdify(x, f(f2(x)))) + + +def test_imps_errors(): + # Test errors that implemented functions can return, and still be able to + # form expressions. + # See: https://github.com/sympy/sympy/issues/10810 + # + # XXX: Removed AttributeError here. This test was added due to issue 10810 + # but that issue was about ValueError. It doesn't seem reasonable to + # "support" catching AttributeError in the same context... + for val, error_class in product((0, 0., 2, 2.0), (TypeError, ValueError)): + + def myfunc(a): + if a == 0: + raise error_class + return 1 + + f = implemented_function('f', myfunc) + expr = f(val) + assert expr == f(val) + + +def test_imps_wrong_args(): + raises(ValueError, lambda: implemented_function(sin, lambda x: x)) + + +def test_lambdify_imps(): + # Test lambdify with implemented functions + # first test basic (sympy) lambdify + f = sympy.cos + assert lambdify(x, f(x))(0) == 1 + assert lambdify(x, 1 + f(x))(0) == 2 + assert lambdify((x, y), y + f(x))(0, 1) == 2 + # make an implemented function and test + f = implemented_function("f", lambda x: x + 100) + assert lambdify(x, f(x))(0) == 100 + assert lambdify(x, 1 + f(x))(0) == 101 + assert lambdify((x, y), y + f(x))(0, 1) == 101 + # Can also handle tuples, lists, dicts as expressions + lam = lambdify(x, (f(x), x)) + assert lam(3) == (103, 3) + lam = lambdify(x, [f(x), x]) + assert lam(3) == [103, 3] + lam = lambdify(x, [f(x), (f(x), x)]) + assert lam(3) == [103, (103, 3)] + lam = lambdify(x, {f(x): x}) + assert lam(3) == {103: 3} + lam = lambdify(x, {f(x): x}) + assert lam(3) == {103: 3} + lam = lambdify(x, {x: f(x)}) + assert lam(3) == {3: 103} + # Check that imp preferred to other namespaces by default + d = {'f': lambda x: x + 99} + lam = lambdify(x, f(x), d) + assert lam(3) == 103 + # Unless flag passed + lam = lambdify(x, f(x), d, use_imps=False) + assert lam(3) == 102 + + +def test_dummification(): + t = symbols('t') + F = Function('F') + G = Function('G') + #"\alpha" is not a valid Python variable name + #lambdify should sub in a dummy for it, and return + #without a syntax error + alpha = symbols(r'\alpha') + some_expr = 2 * F(t)**2 / G(t) + lam = lambdify((F(t), G(t)), some_expr) + assert lam(3, 9) == 2 + lam = lambdify(sin(t), 2 * sin(t)**2) + assert lam(F(t)) == 2 * F(t)**2 + #Test that \alpha was properly dummified + lam = lambdify((alpha, t), 2*alpha + t) + assert lam(2, 1) == 5 + raises(SyntaxError, lambda: 
lambdify(F(t) * G(t), F(t) * G(t) + 5)) + raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5)) + raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5)) + + +def test_lambdify__arguments_with_invalid_python_identifiers(): + # see sympy/sympy#26690 + N = CoordSys3D('N') + xn, yn, zn = N.base_scalars() + expr = xn + yn + f = lambdify([xn, yn], expr) + res = f(0.2, 0.3) + ref = 0.2 + 0.3 + assert abs(res-ref) < 1e-15 + + +def test_curly_matrix_symbol(): + # Issue #15009 + curlyv = sympy.MatrixSymbol("{v}", 2, 1) + lam = lambdify(curlyv, curlyv) + assert lam(1)==1 + lam = lambdify(curlyv, curlyv, dummify=True) + assert lam(1)==1 + + +def test_python_keywords(): + # Test for issue 7452. The automatic dummification should ensure use of + # Python reserved keywords as symbol names will create valid lambda + # functions. This is an additional regression test. + python_if = symbols('if') + expr = python_if / 2 + f = lambdify(python_if, expr) + assert f(4.0) == 2.0 + + +def test_lambdify_docstring(): + func = lambdify((w, x, y, z), w + x + y + z) + ref = ( + "Created with lambdify. Signature:\n\n" + "func(w, x, y, z)\n\n" + "Expression:\n\n" + "w + x + y + z" + ).splitlines() + assert func.__doc__.splitlines()[:len(ref)] == ref + syms = symbols('a1:26') + func = lambdify(syms, sum(syms)) + ref = ( + "Created with lambdify. Signature:\n\n" + "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n" + " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n" + "Expression:\n\n" + "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..." + ).splitlines() + assert func.__doc__.splitlines()[:len(ref)] == ref + + +def test_lambdify_linecache(): + func = lambdify(x, x + 1) + source = 'def _lambdifygenerated(x):\n return x + 1\n' + assert inspect.getsource(func) == source + filename = inspect.getsourcefile(func) + assert filename.startswith('= (1, 10) + + if have_scipy_1_10plus: + cm2 = lambdify((x, y), powm1(x, y), modules='scipy') + assert abs(cm2(1.2, 1e-9) - 1.82321557e-10) < 1e-17 + + +def test_scipy_bernoulli(): + if not scipy: + skip("scipy not installed") + + bern = lambdify((x,), bernoulli(x), modules='scipy') + assert bern(1) == 0.5 + + +def test_scipy_harmonic(): + if not scipy: + skip("scipy not installed") + + hn = lambdify((x,), harmonic(x), modules='scipy') + assert hn(2) == 1.5 + hnm = lambdify((x, y), harmonic(x, y), modules='scipy') + assert hnm(2, 2) == 1.25 + + +def test_cupy_array_arg(): + if not cupy: + skip("CuPy not installed") + + f = lambdify([[x, y]], x*x + y, 'cupy') + result = f(cupy.array([2.0, 1.0])) + assert result == 5 + assert "cupy" in str(type(result)) + + +def test_cupy_array_arg_using_numpy(): + # numpy functions can be run on cupy arrays + # unclear if we can "officially" support this, + # depends on numpy __array_function__ support + if not cupy: + skip("CuPy not installed") + + f = lambdify([[x, y]], x*x + y, 'numpy') + result = f(cupy.array([2.0, 1.0])) + assert result == 5 + assert "cupy" in str(type(result)) + +def test_cupy_dotproduct(): + if not cupy: + skip("CuPy not installed") + + A = Matrix([x, y, z]) + f1 = lambdify([x, y, z], DotProduct(A, A), modules='cupy') + f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy') + f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='cupy') + f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy') + + assert f1(1, 2, 3) == \ + f2(1, 2, 3) == \ + f3(1, 2, 3) == \ + f4(1, 2, 3) == \ + cupy.array([14]) + + +def test_jax_array_arg(): + if not jax: + skip("JAX 
not installed") + + f = lambdify([[x, y]], x*x + y, 'jax') + result = f(jax.numpy.array([2.0, 1.0])) + assert result == 5 + assert "jax" in str(type(result)) + + +def test_jax_array_arg_using_numpy(): + if not jax: + skip("JAX not installed") + + f = lambdify([[x, y]], x*x + y, 'numpy') + result = f(jax.numpy.array([2.0, 1.0])) + assert result == 5 + assert "jax" in str(type(result)) + + +def test_jax_dotproduct(): + if not jax: + skip("JAX not installed") + + A = Matrix([x, y, z]) + f1 = lambdify([x, y, z], DotProduct(A, A), modules='jax') + f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax') + f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='jax') + f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax') + + assert f1(1, 2, 3) == \ + f2(1, 2, 3) == \ + f3(1, 2, 3) == \ + f4(1, 2, 3) == \ + jax.numpy.array([14]) + + +def test_lambdify_cse(): + def no_op_cse(exprs): + return (), exprs + + def dummy_cse(exprs): + from sympy.simplify.cse_main import cse + return cse(exprs, symbols=numbered_symbols(cls=Dummy)) + + def minmem(exprs): + from sympy.simplify.cse_main import cse_release_variables, cse + return cse(exprs, postprocess=cse_release_variables) + + class Case: + def __init__(self, *, args, exprs, num_args, requires_numpy=False): + self.args = args + self.exprs = exprs + self.num_args = num_args + subs_dict = dict(zip(self.args, self.num_args)) + self.ref = [e.subs(subs_dict).evalf() for e in exprs] + self.requires_numpy = requires_numpy + + def lambdify(self, *, cse): + return lambdify(self.args, self.exprs, cse=cse) + + def assertAllClose(self, result, *, abstol=1e-15, reltol=1e-15): + if self.requires_numpy: + assert all(numpy.allclose(result[i], numpy.asarray(r, dtype=float), + rtol=reltol, atol=abstol) + for i, r in enumerate(self.ref)) + return + + for i, r in enumerate(self.ref): + abs_err = abs(result[i] - r) + if r == 0: + assert abs_err < abstol + else: + assert abs_err/abs(r) < reltol + + cases = [ + Case( + args=(x, y, z), + exprs=[ + x + y + z, + x + y - z, + 2*x + 2*y - z, + (x+y)**2 + (y+z)**2, + ], + num_args=(2., 3., 4.) + ), + Case( + args=(x, y, z), + exprs=[ + x + sympy.Heaviside(x), + y + sympy.Heaviside(x), + z + sympy.Heaviside(x, 1), + z/sympy.Heaviside(x, 1) + ], + num_args=(0., 3., 4.) 
+ ), + Case( + args=(x, y, z), + exprs=[ + x + sinc(y), + y + sinc(y), + z - sinc(y) + ], + num_args=(0.1, 0.2, 0.3) + ), + Case( + args=(x, y, z), + exprs=[ + Matrix([[x, x*y], [sin(z) + 4, x**z]]), + x*y+sin(z)-x**z, + Matrix([x*x, sin(z), x**z]) + ], + num_args=(1.,2.,3.), + requires_numpy=True + ), + Case( + args=(x, y), + exprs=[(x + y - 1)**2, x, x + y, + (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)], + num_args=(1,2) + ) + ] + for case in cases: + if not numpy and case.requires_numpy: + continue + for _cse in [False, True, minmem, no_op_cse, dummy_cse]: + f = case.lambdify(cse=_cse) + result = f(*case.num_args) + case.assertAllClose(result) + +def test_issue_25288(): + syms = numbered_symbols(cls=Dummy) + ok = lambdify(x, [x**2, sin(x**2)], cse=lambda e: cse(e, symbols=syms))(2) + assert ok + + +def test_deprecated_set(): + with warns_deprecated_sympy(): + lambdify({x, y}, x + y) + +def test_issue_13881(): + if not numpy: + skip("numpy not installed.") + + X = MatrixSymbol('X', 3, 1) + + f = lambdify(X, X.T*X, 'numpy') + assert f(numpy.array([1, 2, 3])) == 14 + assert f(numpy.array([3, 2, 1])) == 14 + + f = lambdify(X, X*X.T, 'numpy') + assert f(numpy.array([1, 2, 3])) == 14 + assert f(numpy.array([3, 2, 1])) == 14 + + f = lambdify(X, (X*X.T)*X, 'numpy') + arr1 = numpy.array([[1], [2], [3]]) + arr2 = numpy.array([[14],[28],[42]]) + + assert numpy.array_equal(f(arr1), arr2) + + +def test_23536_lambdify_cse_dummy(): + + f = Function('x')(y) + g = Function('w')(y) + expr = z + (f**4 + g**5)*(f**3 + (g*f)**3) + expr = expr.expand() + eval_expr = lambdify(((f, g), z), expr, cse=True) + ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError + assert ans == 300.0 # not a list and value is 300 + + +class LambdifyDocstringTestCase: + SIGNATURE = None + EXPR = None + SRC = None + + def __init__(self, docstring_limit, expected_redacted): + self.docstring_limit = docstring_limit + self.expected_redacted = expected_redacted + + @property + def expected_expr(self): + expr_redacted_msg = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)" + return self.EXPR if not self.expected_redacted else expr_redacted_msg + + @property + def expected_src(self): + src_redacted_msg = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)" + return self.SRC if not self.expected_redacted else src_redacted_msg + + @property + def expected_docstring(self): + expected_docstring = ( + f'Created with lambdify. 
Signature:\n\n' + f'func({self.SIGNATURE})\n\n' + f'Expression:\n\n' + f'{self.expected_expr}\n\n' + f'Source code:\n\n' + f'{self.expected_src}\n\n' + f'Imported modules:\n\n' + ) + return expected_docstring + + def __len__(self): + return len(self.expected_docstring) + + def __repr__(self): + return ( + f'{self.__class__.__name__}(' + f'docstring_limit={self.docstring_limit}, ' + f'expected_redacted={self.expected_redacted})' + ) + + +def test_lambdify_docstring_size_limit_simple_symbol(): + + class SimpleSymbolTestCase(LambdifyDocstringTestCase): + SIGNATURE = 'x' + EXPR = 'x' + SRC = ( + 'def _lambdifygenerated(x):\n' + ' return x\n' + ) + + x = symbols('x') + + test_cases = ( + SimpleSymbolTestCase(docstring_limit=None, expected_redacted=False), + SimpleSymbolTestCase(docstring_limit=100, expected_redacted=False), + SimpleSymbolTestCase(docstring_limit=1, expected_redacted=False), + SimpleSymbolTestCase(docstring_limit=0, expected_redacted=True), + SimpleSymbolTestCase(docstring_limit=-1, expected_redacted=True), + ) + for test_case in test_cases: + lambdified_expr = lambdify( + [x], + x, + 'sympy', + docstring_limit=test_case.docstring_limit, + ) + assert lambdified_expr.__doc__ == test_case.expected_docstring + + +def test_lambdify_docstring_size_limit_nested_expr(): + + class ExprListTestCase(LambdifyDocstringTestCase): + SIGNATURE = 'x, y, z' + EXPR = ( + '[x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z ' + '+ 3*x*z**2 +...' + ) + SRC = ( + 'def _lambdifygenerated(x, y, z):\n' + ' return [x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 ' + '+ 6*x*y*z + 3*x*z**2 + y**3 + 3*y**2*z + 3*y*z**2 + z**3]\n' + ) + + x, y, z = symbols('x, y, z') + expr = [x, [y], z, ((x + y + z)**3).expand()] + + test_cases = ( + ExprListTestCase(docstring_limit=None, expected_redacted=False), + ExprListTestCase(docstring_limit=200, expected_redacted=False), + ExprListTestCase(docstring_limit=50, expected_redacted=True), + ExprListTestCase(docstring_limit=0, expected_redacted=True), + ExprListTestCase(docstring_limit=-1, expected_redacted=True), + ) + for test_case in test_cases: + lambdified_expr = lambdify( + [x, y, z], + expr, + 'sympy', + docstring_limit=test_case.docstring_limit, + ) + assert lambdified_expr.__doc__ == test_case.expected_docstring + + +def test_lambdify_docstring_size_limit_matrix(): + + class MatrixTestCase(LambdifyDocstringTestCase): + SIGNATURE = 'x, y, z' + EXPR = ( + 'Matrix([[0, x], [x + y + z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 ' + '+ 6*x*y*z...' 
+ ) + SRC = ( + 'def _lambdifygenerated(x, y, z):\n' + ' return ImmutableDenseMatrix([[0, x], [x + y + z, x**3 ' + '+ 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z + 3*x*z**2 + y**3 ' + '+ 3*y**2*z + 3*y*z**2 + z**3]])\n' + ) + + x, y, z = symbols('x, y, z') + expr = Matrix([[S.Zero, x], [x + y + z, ((x + y + z)**3).expand()]]) + + test_cases = ( + MatrixTestCase(docstring_limit=None, expected_redacted=False), + MatrixTestCase(docstring_limit=200, expected_redacted=False), + MatrixTestCase(docstring_limit=50, expected_redacted=True), + MatrixTestCase(docstring_limit=0, expected_redacted=True), + MatrixTestCase(docstring_limit=-1, expected_redacted=True), + ) + for test_case in test_cases: + lambdified_expr = lambdify( + [x, y, z], + expr, + 'sympy', + docstring_limit=test_case.docstring_limit, + ) + assert lambdified_expr.__doc__ == test_case.expected_docstring + + +def test_lambdify_empty_tuple(): + a = symbols("a") + expr = ((), (a,)) + f = lambdify(a, expr) + result = f(1) + assert result == ((), (1,)), "Lambdify did not handle the empty tuple correctly." + + +def test_assoc_legendre_numerical_evaluation(): + + tol = 1e-10 + + sympy_result_integer = assoc_legendre(1, 1/2, 0.1).evalf() + sympy_result_complex = assoc_legendre(2, 1, 3).evalf() + mpmath_result_integer = -0.474572528387641 + mpmath_result_complex = -25.45584412271571*I + + assert all_close(sympy_result_integer, mpmath_result_integer, tol) + assert all_close(sympy_result_complex, mpmath_result_complex, tol) + + +def test_Piecewise(): + + modules = [math] + if numpy: + modules.append('numpy') + + for mod in modules: + # test isinf + f = lambdify(x, Piecewise((7.0, isinf(x)), (3.0, True)), mod) + assert f(+float('inf')) == +7.0 + assert f(-float('inf')) == +7.0 + assert f(42.) == 3.0 + + f2 = lambdify(x, Piecewise((7.0*sign(x), isinf(x)), (3.0, True)), mod) + assert f2(+float('inf')) == +7.0 + assert f2(-float('inf')) == -7.0 + assert f2(42.) == 3.0 + + # test isnan (gh-26784) + g = lambdify(x, Piecewise((7.0, isnan(x)), (3.0, True)), mod) + assert g(float('nan')) == 7.0 + assert g(42.) 
== 3.0 + + +def test_array_symbol(): + if not numpy: + skip("numpy not installed.") + a = ArraySymbol('a', (3,)) + f = lambdify((a), a) + assert numpy.all(f(numpy.array([1,2,3])) == numpy.array([1,2,3])) diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_matchpy_connector.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_matchpy_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..3648bd49f9e56ca20fbf428ed46c01429dbe8b15 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_matchpy_connector.py @@ -0,0 +1,164 @@ +import pickle + +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer + +matchpy = import_module("matchpy") + +x, y, z = symbols("x y z") + + +def _get_first_match(expr, pattern): + from matchpy import ManyToOneMatcher, Pattern + + matcher = ManyToOneMatcher() + matcher.add(Pattern(pattern)) + return next(iter(matcher.match(expr))) + + +def test_matchpy_connector(): + if matchpy is None: + skip("matchpy not installed") + + from multiset import Multiset + from matchpy import Pattern, Substitution + + w_ = WildDot("w_") + w__ = WildPlus("w__") + w___ = WildStar("w___") + + expr = x + y + pattern = x + w_ + p, subst = _get_first_match(expr, pattern) + assert p == Pattern(pattern) + assert subst == Substitution({'w_': y}) + + expr = x + y + z + pattern = x + w__ + p, subst = _get_first_match(expr, pattern) + assert p == Pattern(pattern) + assert subst == Substitution({'w__': Multiset([y, z])}) + + expr = x + y + z + pattern = x + y + z + w___ + p, subst = _get_first_match(expr, pattern) + assert p == Pattern(pattern) + assert subst == Substitution({'w___': Multiset()}) + + +def test_matchpy_optional(): + if matchpy is None: + skip("matchpy not installed") + + from matchpy import Pattern, Substitution + from matchpy import ManyToOneReplacer, ReplacementRule + + p = WildDot("p", optional=1) + q = WildDot("q", optional=0) + + pattern = p*x + q + + expr1 = 2*x + pa, subst = _get_first_match(expr1, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': 2, 'q': 0}) + + expr2 = x + 3 + pa, subst = _get_first_match(expr2, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': 1, 'q': 3}) + + expr3 = x + pa, subst = _get_first_match(expr3, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': 1, 'q': 0}) + + expr4 = x*y + z + pa, subst = _get_first_match(expr4, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': y, 'q': z}) + + replacer = ManyToOneReplacer() + replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q))) + assert replacer.replace(expr1) == sin(2)*cos(0) + assert replacer.replace(expr2) == sin(1)*cos(3) + assert replacer.replace(expr3) == sin(1)*cos(0) + assert replacer.replace(expr4) == sin(y)*cos(z) + + +def test_replacer(): + if matchpy is None: + skip("matchpy not installed") + + for info in [True, False]: + for lambdify in [True, False]: + _perform_test_replacer(info, lambdify) + + +def _perform_test_replacer(info, lambdify): + + x1_ = WildDot("x1_") + x2_ = WildDot("x2_") + + a_ = WildDot("a_", optional=S.One) + b_ = WildDot("b_", optional=S.One) 
+ c_ = WildDot("c_", optional=S.Zero) + + replacer = Replacer(common_constraints=[ + matchpy.CustomConstraint(lambda a_: not a_.has(x)), + matchpy.CustomConstraint(lambda b_: not b_.has(x)), + matchpy.CustomConstraint(lambda c_: not c_.has(x)), + ], lambdify=lambdify, info=info) + + # Rewrite the equation into implicit form, unless it's already solved: + replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)], info=1) + + # Simple equation solver for real numbers: + replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_), info=2) + disc = b_**2 - 4*a_*c_ + replacer.add( + Eq(a_*x**2 + b_*x + c_, 0), + Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)), + conditions_nonfalse=[disc >= 0], + info=3 + ) + replacer.add( + Eq(a_*x**2 + c_, 0), + Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)), + conditions_nonfalse=[-c_*a_ > 0], + info=4 + ) + + g = lambda expr, infos: (expr, infos) if info else expr + + assert replacer.replace(Eq(3*x, y)) == g(Eq(x, y/3), [1, 2]) + assert replacer.replace(Eq(x**2 + 1, 0)) == g(Eq(x**2 + 1, 0), []) + assert replacer.replace(Eq(x**2, 4)) == g((Eq(x, 2) | Eq(x, -2)), [1, 4]) + assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == g(Eq(x, -2*y), [3]) + + +def test_matchpy_object_pickle(): + if matchpy is None: + return + + a1 = WildDot("a") + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 + + a1 = WildDot("a", S(1)) + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 + + a1 = WildPlus("a", S(1)) + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 + + a1 = WildStar("a", S(1)) + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_mathml.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_mathml.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a7598f175be34f8bb34c0bd9c003d1c0238c7b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_mathml.py @@ -0,0 +1,33 @@ +import os +from textwrap import dedent +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.utilities.mathml import apply_xsl + + + +lxml = import_module('lxml') + +path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_xxe.py")) + + +def test_xxe(): + assert os.path.isfile(path) + if not lxml: + skip("lxml not installed.") + + mml = dedent( + rf""" + + ]> + + John + &ent; + + """ + ) + xsl = 'mathml/data/simple_mmlctop.xsl' + + res = apply_xsl(mml, xsl) + assert res == \ + '\n\nJohn\n\n\n' diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_misc.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3e059419303c33cbd7b770679b5efc1b03486d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_misc.py @@ -0,0 +1,148 @@ +from textwrap import dedent +import sys +from subprocess import Popen, PIPE +import os + +from sympy.core.singleton import S +from sympy.testing.pytest import (raises, warns_deprecated_sympy, + skip_under_pyodide) +from sympy.utilities.misc import (translate, replace, ordinal, rawlines, + strlines, as_int, find_executable) + + +def test_translate(): + abc = 'abc' + assert translate(abc, None, 'a') == 'bc' + assert translate(abc, None, '') == 'abc' + assert translate(abc, {'a': 'x'}, 'c') == 'xb' + assert translate(abc, {'a': 'bc'}, 'c') == 'bcb' + assert translate(abc, {'ab': 'x'}, 'c') == 'x' + assert translate(abc, {'ab': ''}, 'c') == '' 
+ assert translate(abc, {'bc': 'x'}, 'c') == 'ab' + assert translate(abc, {'abc': 'x', 'a': 'y'}) == 'x' + u = chr(4096) + assert translate(abc, 'a', 'x', u) == 'xbc' + assert (u in translate(abc, 'a', u, u)) is True + + +def test_replace(): + assert replace('abc', ('a', 'b')) == 'bbc' + assert replace('abc', {'a': 'Aa'}) == 'Aabc' + assert replace('abc', ('a', 'b'), ('c', 'C')) == 'bbC' + + +def test_ordinal(): + assert ordinal(-1) == '-1st' + assert ordinal(0) == '0th' + assert ordinal(1) == '1st' + assert ordinal(2) == '2nd' + assert ordinal(3) == '3rd' + assert all(ordinal(i).endswith('th') for i in range(4, 21)) + assert ordinal(100) == '100th' + assert ordinal(101) == '101st' + assert ordinal(102) == '102nd' + assert ordinal(103) == '103rd' + assert ordinal(104) == '104th' + assert ordinal(200) == '200th' + assert all(ordinal(i) == str(i) + 'th' for i in range(-220, -203)) + + +def test_rawlines(): + assert rawlines('a a\na') == "dedent('''\\\n a a\n a''')" + assert rawlines('a a') == "'a a'" + assert rawlines(strlines('\\le"ft')) == ( + '(\n' + " '(\\n'\n" + ' \'r\\\'\\\\le"ft\\\'\\n\'\n' + " ')'\n" + ')') + + +def test_strlines(): + q = 'this quote (") is in the middle' + # the following assert rhs was prepared with + # print(rawlines(strlines(q, 10))) + assert strlines(q, 10) == dedent('''\ + ( + 'this quo' + 'te (") i' + 's in the' + ' middle' + )''') + assert q == ( + 'this quo' + 'te (") i' + 's in the' + ' middle' + ) + q = "this quote (') is in the middle" + assert strlines(q, 20) == dedent('''\ + ( + "this quote (') is " + "in the middle" + )''') + assert strlines('\\left') == ( + '(\n' + "r'\\left'\n" + ')') + assert strlines('\\left', short=True) == r"r'\left'" + assert strlines('\\le"ft') == ( + '(\n' + 'r\'\\le"ft\'\n' + ')') + q = 'this\nother line' + assert strlines(q) == rawlines(q) + + +def test_translate_args(): + try: + translate(None, None, None, 'not_none') + except ValueError: + pass # Exception raised successfully + else: + assert False + + assert translate('s', None, None, None) == 's' + + try: + translate('s', 'a', 'bc') + except ValueError: + pass # Exception raised successfully + else: + assert False + + +@skip_under_pyodide("Cannot create subprocess under pyodide.") +def test_debug_output(): + env = os.environ.copy() + env['SYMPY_DEBUG'] = 'True' + cmd = 'from sympy import *; x = Symbol("x"); print(integrate((1-cos(x))/x, x))' + cmdline = [sys.executable, '-c', cmd] + proc = Popen(cmdline, env=env, stdout=PIPE, stderr=PIPE) + out, err = proc.communicate() + out = out.decode('ascii') # utf-8? + err = err.decode('ascii') + expected = 'substituted: -x*(1 - cos(x)), u: 1/x, u_var: _u' + assert expected in err, err + + +def test_as_int(): + raises(ValueError, lambda : as_int(True)) + raises(ValueError, lambda : as_int(1.1)) + raises(ValueError, lambda : as_int([])) + raises(ValueError, lambda : as_int(S.NaN)) + raises(ValueError, lambda : as_int(S.Infinity)) + raises(ValueError, lambda : as_int(S.NegativeInfinity)) + raises(ValueError, lambda : as_int(S.ComplexInfinity)) + # for the following, limited precision makes int(arg) == arg + # but the int value is not necessarily what a user might have + # expected; Q.prime is more nuanced in its response for + # expressions which might be complex representations of an + # integer. This is not -- by design -- as_ints role. 
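+ # e.g. int(1e23) == 1e23 compares equal, but int(1e23) != 10**23, which is what a user writing 1e23 most likely meant.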
+ raises(ValueError, lambda : as_int(1e23)) + raises(ValueError, lambda : as_int(S('1.'+'0'*20+'1'))) + assert as_int(True, strict=False) == 1 + +def test_deprecated_find_executable(): + with warns_deprecated_sympy(): + find_executable('python') diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_pickling.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_pickling.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea2d062286cbcf802eb0f42f2d7d130123599af --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_pickling.py @@ -0,0 +1,723 @@ +import inspect +import copy +import pickle + +from sympy.physics.units import meter + +from sympy.testing.pytest import XFAIL, raises, ignore_warnings + +from sympy.core.basic import Atom, Basic +from sympy.core.singleton import SingletonRegistry +from sympy.core.symbol import Str, Dummy, Symbol, Wild +from sympy.core.numbers import (E, I, pi, oo, zoo, nan, Integer, + Rational, Float, AlgebraicNumber) +from sympy.core.relational import (Equality, GreaterThan, LessThan, Relational, + StrictGreaterThan, StrictLessThan, Unequality) +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.function import Derivative, Function, FunctionClass, Lambda, \ + WildFunction +from sympy.sets.sets import Interval +from sympy.core.multidimensional import vectorize + +from sympy.external.gmpy import gmpy as _gmpy +from sympy.utilities.exceptions import SymPyDeprecationWarning + +from sympy.core.singleton import S +from sympy.core.symbol import symbols + +from sympy.external import import_module +cloudpickle = import_module('cloudpickle') + + +not_equal_attrs = { + '_assumptions', # This is a local cache that isn't automatically filled on creation + '_mhash', # Cached after __hash__ is called but set to None after creation +} + + +deprecated_attrs = { + 'is_EmptySet', # Deprecated from SymPy 1.5. This can be removed when is_EmptySet is removed. + 'expr_free_symbols', # Deprecated from SymPy 1.9. This can be removed when exr_free_symbols is removed. +} + +dont_check_attrs = { + '_sage_', # Fails because Sage is not installed +} + + +def check(a, exclude=[], check_attr=True, deprecated=()): + """ Check that pickling and copying round-trips. + """ + # Pickling with protocols 0 and 1 is disabled for Basic instances: + if isinstance(a, Basic): + for protocol in [0, 1]: + raises(NotImplementedError, lambda: pickle.dumps(a, protocol)) + + protocols = [2, copy.copy, copy.deepcopy, 3, 4] + if cloudpickle: + protocols.extend([cloudpickle]) + + for protocol in protocols: + if protocol in exclude: + continue + + if callable(protocol): + if isinstance(a, type): + # Classes can't be copied, but that's okay. 
+ continue + b = protocol(a) + elif inspect.ismodule(protocol): + b = protocol.loads(protocol.dumps(a)) + else: + b = pickle.loads(pickle.dumps(a, protocol)) + + d1 = dir(a) + d2 = dir(b) + assert set(d1) == set(d2) + + if not check_attr: + continue + + def c(a, b, d): + for i in d: + if i in dont_check_attrs: + continue + elif i in not_equal_attrs: + if hasattr(a, i): + assert hasattr(b, i), i + elif i in deprecated_attrs or i in deprecated: + with ignore_warnings(SymPyDeprecationWarning): + assert getattr(a, i) == getattr(b, i), i + elif not hasattr(a, i): + continue + else: + attr = getattr(a, i) + if not hasattr(attr, "__call__"): + assert hasattr(b, i), i + assert getattr(b, i) == attr, "%s != %s, protocol: %s" % (getattr(b, i), attr, protocol) + + c(a, b, d1) + c(b, a, d2) + + + +#================== core ========================= + + +def test_core_basic(): + for c in (Atom, Atom(), Basic, Basic(), SingletonRegistry, S): + check(c) + +def test_core_Str(): + check(Str('x')) + +def test_core_symbol(): + # make the Symbol a unique name that doesn't class with any other + # testing variable in this file since after this test the symbol + # having the same name will be cached as noncommutative + for c in (Dummy, Dummy("x", commutative=False), Symbol, + Symbol("_issue_3130", commutative=False), Wild, Wild("x")): + check(c) + + +def test_core_numbers(): + for c in (Integer(2), Rational(2, 3), Float("1.2")): + check(c) + for c in (AlgebraicNumber, AlgebraicNumber(sqrt(3))): + check(c, check_attr=False) + + +def test_core_float_copy(): + # See gh-7457 + y = Symbol("x") + 1.0 + check(y) # does not raise TypeError ("argument is not an mpz") + + +def test_core_relational(): + x = Symbol("x") + y = Symbol("y") + for c in (Equality, Equality(x, y), GreaterThan, GreaterThan(x, y), + LessThan, LessThan(x, y), Relational, Relational(x, y), + StrictGreaterThan, StrictGreaterThan(x, y), StrictLessThan, + StrictLessThan(x, y), Unequality, Unequality(x, y)): + check(c) + + +def test_core_add(): + x = Symbol("x") + for c in (Add, Add(x, 4)): + check(c) + + +def test_core_mul(): + x = Symbol("x") + for c in (Mul, Mul(x, 4)): + check(c) + + +def test_core_power(): + x = Symbol("x") + for c in (Pow, Pow(x, 4)): + check(c) + + +def test_core_function(): + x = Symbol("x") + for f in (Derivative, Derivative(x), Function, FunctionClass, Lambda, + WildFunction): + check(f) + + +def test_core_undefinedfunctions(): + f = Function("f") + check(f) + + +def test_core_appliedundef(): + x = Symbol("_long_unique_name_1") + f = Function("_long_unique_name_2") + check(f(x)) + + +def test_core_interval(): + for c in (Interval, Interval(0, 2)): + check(c) + + +def test_core_multidimensional(): + for c in (vectorize, vectorize(0)): + check(c) + + +def test_Singletons(): + protocols = [0, 1, 2, 3, 4] + copiers = [copy.copy, copy.deepcopy] + copiers += [lambda x: pickle.loads(pickle.dumps(x, proto)) + for proto in protocols] + if cloudpickle: + copiers += [lambda x: cloudpickle.loads(cloudpickle.dumps(x))] + + for obj in (Integer(-1), Integer(0), Integer(1), Rational(1, 2), pi, E, I, + oo, -oo, zoo, nan, S.GoldenRatio, S.TribonacciConstant, + S.EulerGamma, S.Catalan, S.EmptySet, S.IdentityFunction): + for func in copiers: + assert func(obj) is obj + +#================== combinatorics =================== +from sympy.combinatorics.free_groups import FreeGroup + +def test_free_group(): + check(FreeGroup("x, y, z"), check_attr=False) + +#================== functions =================== +from sympy.functions import (Piecewise, 
lowergamma, acosh, chebyshevu, + chebyshevt, ln, chebyshevt_root, legendre, Heaviside, bernoulli, coth, + tanh, assoc_legendre, sign, arg, asin, DiracDelta, re, rf, Abs, + uppergamma, binomial, sinh, cos, cot, acos, acot, gamma, bell, + hermite, harmonic, LambertW, zeta, log, factorial, asinh, acoth, cosh, + dirichlet_eta, Eijk, loggamma, erf, ceiling, im, fibonacci, + tribonacci, conjugate, tan, chebyshevu_root, floor, atanh, sqrt, sin, + atan, ff, lucas, atan2, polygamma, exp) + + +def test_functions(): + one_var = (acosh, ln, Heaviside, factorial, bernoulli, coth, tanh, + sign, arg, asin, DiracDelta, re, Abs, sinh, cos, cot, acos, acot, + gamma, bell, harmonic, LambertW, zeta, log, factorial, asinh, + acoth, cosh, dirichlet_eta, loggamma, erf, ceiling, im, fibonacci, + tribonacci, conjugate, tan, floor, atanh, sin, atan, lucas, exp) + two_var = (rf, ff, lowergamma, chebyshevu, chebyshevt, binomial, + atan2, polygamma, hermite, legendre, uppergamma) + x, y, z = symbols("x,y,z") + others = (chebyshevt_root, chebyshevu_root, Eijk(x, y, z), + Piecewise( (0, x < -1), (x**2, x <= 1), (x**3, True)), + assoc_legendre) + for cls in one_var: + check(cls) + c = cls(x) + check(c) + for cls in two_var: + check(cls) + c = cls(x, y) + check(c) + for cls in others: + check(cls) + +#================== geometry ==================== +from sympy.geometry.entity import GeometryEntity +from sympy.geometry.point import Point +from sympy.geometry.ellipse import Circle, Ellipse +from sympy.geometry.line import Line, LinearEntity, Ray, Segment +from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle + + +def test_geometry(): + p1 = Point(1, 2) + p2 = Point(2, 3) + p3 = Point(0, 0) + p4 = Point(0, 1) + for c in ( + GeometryEntity, GeometryEntity(), Point, p1, Circle, Circle(p1, 2), + Ellipse, Ellipse(p1, 3, 4), Line, Line(p1, p2), LinearEntity, + LinearEntity(p1, p2), Ray, Ray(p1, p2), Segment, Segment(p1, p2), + Polygon, Polygon(p1, p2, p3, p4), RegularPolygon, + RegularPolygon(p1, 4, 5), Triangle, Triangle(p1, p2, p3)): + check(c, check_attr=False) + +#================== integrals ==================== +from sympy.integrals.integrals import Integral + + +def test_integrals(): + x = Symbol("x") + for c in (Integral, Integral(x)): + check(c) + +#==================== logic ===================== +from sympy.core.logic import Logic + + +def test_logic(): + for c in (Logic, Logic(1)): + check(c) + +#================== matrices ==================== +from sympy.matrices import Matrix, SparseMatrix + + +def test_matrices(): + for c in (Matrix, Matrix([1, 2, 3]), SparseMatrix, SparseMatrix([[1, 2], [3, 4]])): + check(c, deprecated=['_smat', '_mat']) + +#================== ntheory ===================== +from sympy.ntheory.generate import Sieve + + +def test_ntheory(): + for c in (Sieve, Sieve()): + check(c) + +#================== physics ===================== +from sympy.physics.paulialgebra import Pauli +from sympy.physics.units import Unit + + +def test_physics(): + for c in (Unit, meter, Pauli, Pauli(1)): + check(c) + +#================== plotting ==================== +# XXX: These tests are not complete, so XFAIL them + + +@XFAIL +def test_plotting(): + from sympy.plotting.pygletplot.color_scheme import ColorGradient, ColorScheme + from sympy.plotting.pygletplot.managed_window import ManagedWindow + from sympy.plotting.plot import Plot, ScreenShot + from sympy.plotting.pygletplot.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate + from sympy.plotting.pygletplot.plot_camera import 
PlotCamera + from sympy.plotting.pygletplot.plot_controller import PlotController + from sympy.plotting.pygletplot.plot_curve import PlotCurve + from sympy.plotting.pygletplot.plot_interval import PlotInterval + from sympy.plotting.pygletplot.plot_mode import PlotMode + from sympy.plotting.pygletplot.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \ + ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical + from sympy.plotting.pygletplot.plot_object import PlotObject + from sympy.plotting.pygletplot.plot_surface import PlotSurface + from sympy.plotting.pygletplot.plot_window import PlotWindow + for c in ( + ColorGradient, ColorGradient(0.2, 0.4), ColorScheme, ManagedWindow, + ManagedWindow, Plot, ScreenShot, PlotAxes, PlotAxesBase, + PlotAxesFrame, PlotAxesOrdinate, PlotCamera, PlotController, + PlotCurve, PlotInterval, PlotMode, Cartesian2D, Cartesian3D, + Cylindrical, ParametricCurve2D, ParametricCurve3D, + ParametricSurface, Polar, Spherical, PlotObject, PlotSurface, + PlotWindow): + check(c) + + +@XFAIL +def test_plotting2(): + #from sympy.plotting.color_scheme import ColorGradient + from sympy.plotting.pygletplot.color_scheme import ColorScheme + #from sympy.plotting.managed_window import ManagedWindow + from sympy.plotting.plot import Plot + #from sympy.plotting.plot import ScreenShot + from sympy.plotting.pygletplot.plot_axes import PlotAxes + #from sympy.plotting.plot_axes import PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate + #from sympy.plotting.plot_camera import PlotCamera + #from sympy.plotting.plot_controller import PlotController + #from sympy.plotting.plot_curve import PlotCurve + #from sympy.plotting.plot_interval import PlotInterval + #from sympy.plotting.plot_mode import PlotMode + #from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \ + # ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical + #from sympy.plotting.plot_object import PlotObject + #from sympy.plotting.plot_surface import PlotSurface + # from sympy.plotting.plot_window import PlotWindow + check(ColorScheme("rainbow")) + check(Plot(1, visible=False)) + check(PlotAxes()) + +#================== polys ======================= +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.orderings import lex +from sympy.polys.polytools import Poly + +def test_pickling_polys_polytools(): + from sympy.polys.polytools import PurePoly + # from sympy.polys.polytools import GroebnerBasis + x = Symbol('x') + + for c in (Poly, Poly(x, x)): + check(c) + + for c in (PurePoly, PurePoly(x)): + check(c) + + # TODO: fix pickling of Options class (see GroebnerBasis._options) + # for c in (GroebnerBasis, GroebnerBasis([x**2 - 1], x, order=lex)): + # check(c) + +def test_pickling_polys_polyclasses(): + from sympy.polys.polyclasses import DMP, DMF, ANP + + for c in (DMP, DMP([[ZZ(1)], [ZZ(2)], [ZZ(3)]], ZZ)): + check(c, deprecated=['rep']) + for c in (DMF, DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(3)]), ZZ)): + check(c) + for c in (ANP, ANP([QQ(1), QQ(2)], [QQ(1), QQ(2), QQ(3)], QQ)): + check(c) + +@XFAIL +def test_pickling_polys_rings(): + # NOTE: can't use protocols < 2 because we have to execute __new__ to + # make sure caching of rings works properly. 
+ + from sympy.polys.rings import PolyRing + + ring = PolyRing("x,y,z", ZZ, lex) + + for c in (PolyRing, ring): + check(c, exclude=[0, 1]) + + for c in (ring.dtype, ring.one): + check(c, exclude=[0, 1], check_attr=False) # TODO: Py3k + +def test_pickling_polys_fields(): + pass + # NOTE: can't use protocols < 2 because we have to execute __new__ to + # make sure caching of fields works properly. + + # from sympy.polys.fields import FracField + + # field = FracField("x,y,z", ZZ, lex) + + # TODO: AssertionError: assert id(obj) not in self.memo + # for c in (FracField, field): + # check(c, exclude=[0, 1]) + + # TODO: AssertionError: assert id(obj) not in self.memo + # for c in (field.dtype, field.one): + # check(c, exclude=[0, 1]) + +def test_pickling_polys_elements(): + from sympy.polys.domains.pythonrational import PythonRational + #from sympy.polys.domains.pythonfinitefield import PythonFiniteField + #from sympy.polys.domains.mpelements import MPContext + + for c in (PythonRational, PythonRational(1, 7)): + check(c) + + #gf = PythonFiniteField(17) + + # TODO: fix pickling of ModularInteger + # for c in (gf.dtype, gf(5)): + # check(c) + + #mp = MPContext() + + # TODO: fix pickling of RealElement + # for c in (mp.mpf, mp.mpf(1.0)): + # check(c) + + # TODO: fix pickling of ComplexElement + # for c in (mp.mpc, mp.mpc(1.0, -1.5)): + # check(c) + +def test_pickling_polys_domains(): + # from sympy.polys.domains.pythonfinitefield import PythonFiniteField + from sympy.polys.domains.pythonintegerring import PythonIntegerRing + from sympy.polys.domains.pythonrationalfield import PythonRationalField + + # TODO: fix pickling of ModularInteger + # for c in (PythonFiniteField, PythonFiniteField(17)): + # check(c) + + for c in (PythonIntegerRing, PythonIntegerRing()): + check(c, check_attr=False) + + for c in (PythonRationalField, PythonRationalField()): + check(c, check_attr=False) + + if _gmpy is not None: + # from sympy.polys.domains.gmpyfinitefield import GMPYFiniteField + from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing + from sympy.polys.domains.gmpyrationalfield import GMPYRationalField + + # TODO: fix pickling of ModularInteger + # for c in (GMPYFiniteField, GMPYFiniteField(17)): + # check(c) + + for c in (GMPYIntegerRing, GMPYIntegerRing()): + check(c, check_attr=False) + + for c in (GMPYRationalField, GMPYRationalField()): + check(c, check_attr=False) + + #from sympy.polys.domains.realfield import RealField + #from sympy.polys.domains.complexfield import ComplexField + from sympy.polys.domains.algebraicfield import AlgebraicField + #from sympy.polys.domains.polynomialring import PolynomialRing + #from sympy.polys.domains.fractionfield import FractionField + from sympy.polys.domains.expressiondomain import ExpressionDomain + + # TODO: fix pickling of RealElement + # for c in (RealField, RealField(100)): + # check(c) + + # TODO: fix pickling of ComplexElement + # for c in (ComplexField, ComplexField(100)): + # check(c) + + for c in (AlgebraicField, AlgebraicField(QQ, sqrt(3))): + check(c, check_attr=False) + + # TODO: AssertionError + # for c in (PolynomialRing, PolynomialRing(ZZ, "x,y,z")): + # check(c) + + # TODO: AttributeError: 'PolyElement' object has no attribute 'ring' + # for c in (FractionField, FractionField(ZZ, "x,y,z")): + # check(c) + + for c in (ExpressionDomain, ExpressionDomain()): + check(c, check_attr=False) + + +def test_pickling_polys_orderings(): + from sympy.polys.orderings import (LexOrder, GradedLexOrder, + ReversedGradedLexOrder, InverseOrder) + # from 
sympy.polys.orderings import ProductOrder + + for c in (LexOrder, LexOrder()): + check(c) + + for c in (GradedLexOrder, GradedLexOrder()): + check(c) + + for c in (ReversedGradedLexOrder, ReversedGradedLexOrder()): + check(c) + + # TODO: Argh, Python is so naive. No lambdas nor inner function support in + # pickling module. Maybe someone could figure out what to do with this. + # + # for c in (ProductOrder, ProductOrder((LexOrder(), lambda m: m[:2]), + # (GradedLexOrder(), lambda m: m[2:]))): + # check(c) + + for c in (InverseOrder, InverseOrder(LexOrder())): + check(c) + +def test_pickling_polys_monomials(): + from sympy.polys.monomials import MonomialOps, Monomial + x, y, z = symbols("x,y,z") + + for c in (MonomialOps, MonomialOps(3)): + check(c) + + for c in (Monomial, Monomial((1, 2, 3), (x, y, z))): + check(c) + +def test_pickling_polys_errors(): + from sympy.polys.polyerrors import (HeuristicGCDFailed, + HomomorphismFailed, IsomorphismFailed, ExtraneousFactors, + EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible, + NotReversible, NotAlgebraic, DomainError, PolynomialError, + UnificationFailed, GeneratorsError, GeneratorsNeeded, + UnivariatePolynomialError, MultivariatePolynomialError, OptionError, + FlagError) + # from sympy.polys.polyerrors import (ExactQuotientFailed, + # OperationNotSupported, ComputationFailed, PolificationFailed) + + # x = Symbol('x') + + # TODO: TypeError: __init__() takes at least 3 arguments (1 given) + # for c in (ExactQuotientFailed, ExactQuotientFailed(x, 3*x, ZZ)): + # check(c) + + # TODO: TypeError: can't pickle instancemethod objects + # for c in (OperationNotSupported, OperationNotSupported(Poly(x), Poly.gcd)): + # check(c) + + for c in (HeuristicGCDFailed, HeuristicGCDFailed()): + check(c) + + for c in (HomomorphismFailed, HomomorphismFailed()): + check(c) + + for c in (IsomorphismFailed, IsomorphismFailed()): + check(c) + + for c in (ExtraneousFactors, ExtraneousFactors()): + check(c) + + for c in (EvaluationFailed, EvaluationFailed()): + check(c) + + for c in (RefinementFailed, RefinementFailed()): + check(c) + + for c in (CoercionFailed, CoercionFailed()): + check(c) + + for c in (NotInvertible, NotInvertible()): + check(c) + + for c in (NotReversible, NotReversible()): + check(c) + + for c in (NotAlgebraic, NotAlgebraic()): + check(c) + + for c in (DomainError, DomainError()): + check(c) + + for c in (PolynomialError, PolynomialError()): + check(c) + + for c in (UnificationFailed, UnificationFailed()): + check(c) + + for c in (GeneratorsError, GeneratorsError()): + check(c) + + for c in (GeneratorsNeeded, GeneratorsNeeded()): + check(c) + + # TODO: PicklingError: Can't pickle at 0x38578c0>: it's not found as __main__. 
+ # for c in (ComputationFailed, ComputationFailed(lambda t: t, 3, None)): + # check(c) + + for c in (UnivariatePolynomialError, UnivariatePolynomialError()): + check(c) + + for c in (MultivariatePolynomialError, MultivariatePolynomialError()): + check(c) + + # TODO: TypeError: __init__() takes at least 3 arguments (1 given) + # for c in (PolificationFailed, PolificationFailed({}, x, x, False)): + # check(c) + + for c in (OptionError, OptionError()): + check(c) + + for c in (FlagError, FlagError()): + check(c) + +#def test_pickling_polys_options(): + #from sympy.polys.polyoptions import Options + + # TODO: fix pickling of `symbols' flag + # for c in (Options, Options((), dict(domain='ZZ', polys=False))): + # check(c) + +# TODO: def test_pickling_polys_rootisolation(): +# RealInterval +# ComplexInterval + +def test_pickling_polys_rootoftools(): + from sympy.polys.rootoftools import CRootOf, RootSum + + x = Symbol('x') + f = x**3 + x + 3 + + for c in (CRootOf, CRootOf(f, 0)): + check(c) + + for c in (RootSum, RootSum(f, exp)): + check(c) + +#================== printing ==================== +from sympy.printing.latex import LatexPrinter +from sympy.printing.mathml import MathMLContentPrinter, MathMLPresentationPrinter +from sympy.printing.pretty.pretty import PrettyPrinter +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.printer import Printer +from sympy.printing.python import PythonPrinter + + +def test_printing(): + for c in (LatexPrinter, LatexPrinter(), MathMLContentPrinter, + MathMLPresentationPrinter, PrettyPrinter, prettyForm, stringPict, + stringPict("a"), Printer, Printer(), PythonPrinter, + PythonPrinter()): + check(c) + + +@XFAIL +def test_printing1(): + check(MathMLContentPrinter()) + + +@XFAIL +def test_printing2(): + check(MathMLPresentationPrinter()) + + +@XFAIL +def test_printing3(): + check(PrettyPrinter()) + +#================== series ====================== +from sympy.series.limits import Limit +from sympy.series.order import Order + + +def test_series(): + e = Symbol("e") + x = Symbol("x") + for c in (Limit, Limit(e, x, 1), Order, Order(e)): + check(c) + +#================== concrete ================== +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum + + +def test_concrete(): + x = Symbol("x") + for c in (Product, Product(x, (x, 2, 4)), Sum, Sum(x, (x, 2, 4))): + check(c) + +def test_deprecation_warning(): + w = SymPyDeprecationWarning("message", deprecated_since_version='1.0', active_deprecations_target="active-deprecations") + check(w) + +def test_issue_18438(): + assert pickle.loads(pickle.dumps(S.Half)) == S.Half + + +#================= old pickles ================= +def test_unpickle_from_older_versions(): + data = ( + b'\x80\x04\x95^\x00\x00\x00\x00\x00\x00\x00\x8c\x10sympy.core.power' + b'\x94\x8c\x03Pow\x94\x93\x94\x8c\x12sympy.core.numbers\x94\x8c' + b'\x07Integer\x94\x93\x94K\x02\x85\x94R\x94}\x94bh\x03\x8c\x04Half' + b'\x94\x93\x94)R\x94}\x94b\x86\x94R\x94}\x94b.' 
+ ) + assert pickle.loads(data) == sqrt(2) diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_source.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_source.py new file mode 100644 index 0000000000000000000000000000000000000000..468185bc579fc325aee21024dfa15ebf14287b5f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_source.py @@ -0,0 +1,11 @@ +from sympy.utilities.source import get_mod_func, get_class + + +def test_get_mod_func(): + assert get_mod_func( + 'sympy.core.basic.Basic') == ('sympy.core.basic', 'Basic') + + +def test_get_class(): + _basic = get_class('sympy.core.basic.Basic') + assert _basic.__name__ == 'Basic' diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_timeutils.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_timeutils.py new file mode 100644 index 0000000000000000000000000000000000000000..14edfd089c7315ee9f39a4298af0289f8919da6b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_timeutils.py @@ -0,0 +1,10 @@ +"""Tests for simple tools for timing functions' execution. """ + +from sympy.utilities.timeutils import timed + +def test_timed(): + result = timed(lambda: 1 + 1, limit=100000) + assert result[0] == 100000 and result[3] == "ns", str(result) + + result = timed("1 + 1", limit=100000) + assert result[0] == 100000 and result[3] == "ns" diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_wester.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_wester.py new file mode 100644 index 0000000000000000000000000000000000000000..c5699a4eb0824e507967ecd4ff8f7a5f32cc9a54 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_wester.py @@ -0,0 +1,3104 @@ +""" Tests from Michael Wester's 1999 paper "Review of CAS mathematical +capabilities". + +http://www.math.unm.edu/~wester/cas/book/Wester.pdf +See also http://math.unm.edu/~wester/cas_review.html for detailed output of +each tested system. 
+""" + +from sympy.assumptions.ask import Q, ask +from sympy.assumptions.refine import refine +from sympy.concrete.products import product +from sympy.core import EulerGamma +from sympy.core.evalf import N +from sympy.core.function import (Derivative, Function, Lambda, Subs, + diff, expand, expand_func) +from sympy.core.mul import Mul +from sympy.core.intfunc import igcd +from sympy.core.numbers import (AlgebraicNumber, E, I, Rational, + nan, oo, pi, zoo) +from sympy.core.relational import Eq, Lt +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol, symbols +from sympy.functions.combinatorial.factorials import (rf, binomial, + factorial, factorial2) +from sympy.functions.combinatorial.numbers import bernoulli, fibonacci, totient, partition +from sympy.functions.elementary.complexes import (conjugate, im, re, + sign) +from sympy.functions.elementary.exponential import LambertW, exp, log +from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh, + tanh) +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import Max, Min, sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, acot, asin, + atan, cos, cot, csc, sec, sin, tan) +from sympy.functions.special.bessel import besselj +from sympy.functions.special.delta_functions import DiracDelta +from sympy.functions.special.elliptic_integrals import (elliptic_e, + elliptic_f) +from sympy.functions.special.gamma_functions import gamma, polygamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.polynomials import (assoc_legendre, + chebyshevt) +from sympy.functions.special.zeta_functions import polylog +from sympy.geometry.util import idiff +from sympy.logic.boolalg import And +from sympy.matrices.dense import hessian, wronskian +from sympy.matrices.expressions.matmul import MatMul +from sympy.ntheory.continued_fraction import ( + continued_fraction_convergents as cf_c, + continued_fraction_iterator as cf_i, continued_fraction_periodic as + cf_p, continued_fraction_reduce as cf_r) +from sympy.ntheory.factor_ import factorint +from sympy.ntheory.generate import primerange +from sympy.polys.domains.integerring import ZZ +from sympy.polys.orthopolys import legendre_poly +from sympy.polys.partfrac import apart +from sympy.polys.polytools import Poly, factor, gcd, resultant +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.series.residues import residue +from sympy.series.series import series +from sympy.sets.fancysets import ImageSet +from sympy.sets.sets import FiniteSet, Intersection, Interval, Union +from sympy.simplify.combsimp import combsimp +from sympy.simplify.hyperexpand import hyperexpand +from sympy.simplify.powsimp import powdenest, powsimp +from sympy.simplify.radsimp import radsimp +from sympy.simplify.simplify import logcombine, simplify +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.trigsimp import trigsimp +from sympy.solvers.solvers import solve + +import mpmath +from sympy.functions.combinatorial.numbers import stirling +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.error_functions import Ci, Si, erf +from sympy.functions.special.zeta_functions import zeta +from sympy.testing.pytest import (XFAIL, slow, SKIP, tooslow, raises) +from sympy.utilities.iterables import partitions +from mpmath import mpi, mpc +from sympy.matrices import 
Matrix, GramSchmidt, eye +from sympy.matrices.expressions.blockmatrix import BlockMatrix, block_collapse +from sympy.matrices.expressions import MatrixSymbol, ZeroMatrix +from sympy.physics.quantum import Commutator +from sympy.polys.rings import PolyRing +from sympy.polys.fields import FracField +from sympy.polys.solvers import solve_lin_sys +from sympy.concrete import Sum +from sympy.concrete.products import Product +from sympy.integrals import integrate +from sympy.integrals.transforms import laplace_transform,\ + inverse_laplace_transform, LaplaceTransform, fourier_transform,\ + mellin_transform, laplace_correspondence, laplace_initial_conds +from sympy.solvers.recurr import rsolve +from sympy.solvers.solveset import solveset, solveset_real, linsolve +from sympy.solvers.ode import dsolve +from sympy.core.relational import Equality +from itertools import islice, takewhile +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.calculus.util import minimum + + +EmptySet = S.EmptySet +R = Rational +x, y, z = symbols('x y z') +i, j, k, l, m, n = symbols('i j k l m n', integer=True) +f = Function('f') +g = Function('g') + +# A. Boolean Logic and Quantifier Elimination +# Not implemented. + +# B. Set Theory + + +def test_B1(): + assert (FiniteSet(i, j, j, k, k, k) | FiniteSet(l, k, j) | + FiniteSet(j, m, j)) == FiniteSet(i, j, k, l, m) + + +def test_B2(): + assert (FiniteSet(i, j, j, k, k, k) & FiniteSet(l, k, j) & + FiniteSet(j, m, j)) == Intersection({j, m}, {i, j, k}, {j, k, l}) + # Previous output below. Not sure why that should be the expected output. + # There should probably be a way to rewrite Intersections that way but I + # don't see why an Intersection should evaluate like that: + # + # == Union({j}, Intersection({m}, Union({j, k}, Intersection({i}, {l})))) + + +def test_B3(): + assert (FiniteSet(i, j, k, l, m) - FiniteSet(j) == + FiniteSet(i, k, l, m)) + + +def test_B4(): + assert (FiniteSet(*(FiniteSet(i, j)*FiniteSet(k, l))) == + FiniteSet((i, k), (i, l), (j, k), (j, l))) + + +# C. Numbers + + +def test_C1(): + assert (factorial(50) == + 30414093201713378043612608166064768844377641568960512000000000000) + + +def test_C2(): + assert (factorint(factorial(50)) == {2: 47, 3: 22, 5: 12, 7: 8, + 11: 4, 13: 3, 17: 2, 19: 2, 23: 2, 29: 1, 31: 1, 37: 1, + 41: 1, 43: 1, 47: 1}) + + +def test_C3(): + assert (factorial2(10), factorial2(9)) == (3840, 945) + + +# Base conversions; not really implemented by SymPy +# Whatever. Take credit! +def test_C4(): + assert 0xABC == 2748 + + +def test_C5(): + assert 123 == int('234', 7) + + +def test_C6(): + assert int('677', 8) == int('1BF', 16) == 447 + + +def test_C7(): + assert log(32768, 8) == 5 + + +def test_C8(): + # Modular multiplicative inverse. Would be nice if divmod could do this. 
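+ # Python's built-in pow has been able to do this since 3.8: pow(5, -1, 7) == 3.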
+ assert ZZ.invert(5, 7) == 3 + assert ZZ.invert(5, 6) == 5 + + +def test_C9(): + assert igcd(igcd(1776, 1554), 5698) == 74 + + +def test_C10(): + x = 0 + for n in range(2, 11): + x += R(1, n) + assert x == R(4861, 2520) + + +def test_C11(): + assert R(1, 7) == S('0.[142857]') + + +def test_C12(): + assert R(7, 11) * R(22, 7) == 2 + + +def test_C13(): + test = R(10, 7) * (1 + R(29, 1000)) ** R(1, 3) + good = 3 ** R(1, 3) + assert test == good + + +def test_C14(): + assert sqrtdenest(sqrt(2*sqrt(3) + 4)) == 1 + sqrt(3) + + +def test_C15(): + test = sqrtdenest(sqrt(14 + 3*sqrt(3 + 2*sqrt(5 - 12*sqrt(3 - 2*sqrt(2)))))) + good = sqrt(2) + 3 + assert test == good + + +def test_C16(): + test = sqrtdenest(sqrt(10 + 2*sqrt(6) + 2*sqrt(10) + 2*sqrt(15))) + good = sqrt(2) + sqrt(3) + sqrt(5) + assert test == good + + +def test_C17(): + test = radsimp((sqrt(3) + sqrt(2)) / (sqrt(3) - sqrt(2))) + good = 5 + 2*sqrt(6) + assert test == good + + +def test_C18(): + assert simplify((sqrt(-2 + sqrt(-5)) * sqrt(-2 - sqrt(-5))).expand(complex=True)) == 3 + + +@XFAIL +def test_C19(): + assert radsimp(simplify((90 + 34*sqrt(7)) ** R(1, 3))) == 3 + sqrt(7) + + +def test_C20(): + inside = (135 + 78*sqrt(3)) + test = AlgebraicNumber((inside**R(2, 3) + 3) * sqrt(3) / inside**R(1, 3)) + assert simplify(test) == AlgebraicNumber(12) + + +def test_C21(): + assert simplify(AlgebraicNumber((41 + 29*sqrt(2)) ** R(1, 5))) == \ + AlgebraicNumber(1 + sqrt(2)) + + +@XFAIL +def test_C22(): + test = simplify(((6 - 4*sqrt(2))*log(3 - 2*sqrt(2)) + (3 - 2*sqrt(2))*log(17 + - 12*sqrt(2)) + 32 - 24*sqrt(2)) / (48*sqrt(2) - 72)) + good = sqrt(2)/3 - log(sqrt(2) - 1)/3 + assert test == good + + +def test_C23(): + assert 2 * oo - 3 is oo + + +@XFAIL +def test_C24(): + raise NotImplementedError("2**aleph_null == aleph_1") + +# D. Numerical Analysis + + +def test_D1(): + assert 0.0 / sqrt(2) == 0 + + +def test_D2(): + assert str(exp(-1000000).evalf()) == '3.29683147808856e-434295' + + +def test_D3(): + assert exp(pi*sqrt(163)).evalf(50).num.ae(262537412640768744) + + +def test_D4(): + assert floor(R(-5, 3)) == -2 + assert ceiling(R(-5, 3)) == -1 + + +@XFAIL +def test_D5(): + raise NotImplementedError("cubic_spline([1, 2, 4, 5], [1, 4, 2, 3], x)(3) == 27/8") + + +@XFAIL +def test_D6(): + raise NotImplementedError("translate sum(a[i]*x**i, (i,1,n)) to FORTRAN") + + +@XFAIL +def test_D7(): + raise NotImplementedError("translate sum(a[i]*x**i, (i,1,n)) to C") + + +@XFAIL +def test_D8(): + # One way is to cheat by converting the sum to a string, + # and replacing the '[' and ']' with ''. + # E.g., horner(S(str(_).replace('[','').replace(']',''))) + raise NotImplementedError("apply Horner's rule to sum(a[i]*x**i, (i,1,5))") + + +@XFAIL +def test_D9(): + raise NotImplementedError("translate D8 to FORTRAN") + + +@XFAIL +def test_D10(): + raise NotImplementedError("translate D8 to C") + + +@XFAIL +def test_D11(): + #Is there a way to use count_ops? + raise NotImplementedError("flops(sum(product(f[i][k], (i,1,k)), (k,1,n)))") + + +@XFAIL +def test_D12(): + assert (mpi(-4, 2) * x + mpi(1, 3)) ** 2 == mpi(-8, 16)*x**2 + mpi(-24, 12)*x + mpi(1, 9) + + +@XFAIL +def test_D13(): + raise NotImplementedError("discretize a PDE: diff(f(x,t),t) == diff(diff(f(x,t),x),x)") + +# E. Statistics +# See scipy; all of this is numerical. + +# F. Combinatorial Theory. 
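+# rf used below is the rising factorial (Pochhammer symbol): rf(x, k) == x*(x + 1)*...*(x + k - 1).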
+ + +def test_F1(): + assert rf(x, 3) == x*(1 + x)*(2 + x) + + +def test_F2(): + assert expand_func(binomial(n, 3)) == n*(n - 1)*(n - 2)/6 + + +@XFAIL +def test_F3(): + assert combsimp(2**n * factorial(n) * factorial2(2*n - 1)) == factorial(2*n) + + +@XFAIL +def test_F4(): + assert combsimp(2**n * factorial(n) * product(2*k - 1, (k, 1, n))) == factorial(2*n) + + +@XFAIL +def test_F5(): + assert gamma(n + R(1, 2)) / sqrt(pi) / factorial(n) == factorial(2*n)/2**(2*n)/factorial(n)**2 + + +def test_F6(): + partTest = [p.copy() for p in partitions(4)] + partDesired = [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2:1}, {1: 4}] + assert partTest == partDesired + + +def test_F7(): + assert partition(4) == 5 + + +def test_F8(): + assert stirling(5, 2, signed=True) == -50 # if signed, then kind=1 + + +def test_F9(): + assert totient(1776) == 576 + +# G. Number Theory + + +def test_G1(): + assert list(primerange(999983, 1000004)) == [999983, 1000003] + + +@XFAIL +def test_G2(): + raise NotImplementedError("find the primitive root of 191 == 19") + + +@XFAIL +def test_G3(): + raise NotImplementedError("(a+b)**p mod p == a**p + b**p mod p; p prime") + +# ... G14 Modular equations are not implemented. + +def test_G15(): + assert Rational(sqrt(3).evalf()).limit_denominator(15) == R(26, 15) + assert list(takewhile(lambda x: x.q <= 15, cf_c(cf_i(sqrt(3)))))[-1] == \ + R(26, 15) + + +def test_G16(): + assert list(islice(cf_i(pi),10)) == [3, 7, 15, 1, 292, 1, 1, 1, 2, 1] + + +def test_G17(): + assert cf_p(0, 1, 23) == [4, [1, 3, 1, 8]] + + +def test_G18(): + assert cf_p(1, 2, 5) == [[1]] + assert cf_r([[1]]).expand() == S.Half + sqrt(5)/2 + + +@XFAIL +def test_G19(): + s = symbols('s', integer=True, positive=True) + it = cf_i((exp(1/s) - 1)/(exp(1/s) + 1)) + assert list(islice(it, 5)) == [0, 2*s, 6*s, 10*s, 14*s] + + +def test_G20(): + s = symbols('s', integer=True, positive=True) + # Wester erroneously has this as -s + sqrt(s**2 + 1) + assert cf_r([[2*s]]) == s + sqrt(s**2 + 1) + + +@XFAIL +def test_G20b(): + s = symbols('s', integer=True, positive=True) + assert cf_p(s, 1, s**2 + 1) == [[2*s]] + + +# H. 
Algebra + + +def test_H1(): + assert simplify(2*2**n) == simplify(2**(n + 1)) + assert powdenest(2*2**n) == simplify(2**(n + 1)) + + +def test_H2(): + assert powsimp(4 * 2**n) == 2**(n + 2) + + +def test_H3(): + assert (-1)**(n*(n + 1)) == 1 + + +def test_H4(): + expr = factor(6*x - 10) + assert type(expr) is Mul + assert expr.args[0] == 2 + assert expr.args[1] == 3*x - 5 + +p1 = 64*x**34 - 21*x**47 - 126*x**8 - 46*x**5 - 16*x**60 - 81 +p2 = 72*x**60 - 25*x**25 - 19*x**23 - 22*x**39 - 83*x**52 + 54*x**10 + 81 +q = 34*x**19 - 25*x**16 + 70*x**7 + 20*x**3 - 91*x - 86 + + +def test_H5(): + assert gcd(p1, p2, x) == 1 + + +def test_H6(): + assert gcd(expand(p1 * q), expand(p2 * q)) == q + + +def test_H7(): + p1 = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5 + p2 = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z + assert gcd(p1, p2, x, y, z) == 1 + + +def test_H8(): + p1 = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5 + p2 = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z + q = 11*x**12*y**7*z**13 - 23*x**2*y**8*z**10 + 47*x**17*y**5*z**8 + assert gcd(p1 * q, p2 * q, x, y, z) == q + + +def test_H9(): + x = Symbol('x', zero=False) + p1 = 2*x**(n + 4) - x**(n + 2) + p2 = 4*x**(n + 1) + 3*x**n + assert gcd(p1, p2) == x**n + + +def test_H10(): + p1 = 3*x**4 + 3*x**3 + x**2 - x - 2 + p2 = x**3 - 3*x**2 + x + 5 + assert resultant(p1, p2, x) == 0 + + +def test_H11(): + assert resultant(p1 * q, p2 * q, x) == 0 + + +def test_H12(): + num = x**2 - 4 + den = x**2 + 4*x + 4 + assert simplify(num/den) == (x - 2)/(x + 2) + + +@XFAIL +def test_H13(): + assert simplify((exp(x) - 1) / (exp(x/2) + 1)) == exp(x/2) - 1 + + +def test_H14(): + p = (x + 1) ** 20 + ep = expand(p) + assert ep == (1 + 20*x + 190*x**2 + 1140*x**3 + 4845*x**4 + 15504*x**5 + + 38760*x**6 + 77520*x**7 + 125970*x**8 + 167960*x**9 + 184756*x**10 + + 167960*x**11 + 125970*x**12 + 77520*x**13 + 38760*x**14 + 15504*x**15 + + 4845*x**16 + 1140*x**17 + 190*x**18 + 20*x**19 + x**20) + dep = diff(ep, x) + assert dep == (20 + 380*x + 3420*x**2 + 19380*x**3 + 77520*x**4 + + 232560*x**5 + 542640*x**6 + 1007760*x**7 + 1511640*x**8 + 1847560*x**9 + + 1847560*x**10 + 1511640*x**11 + 1007760*x**12 + 542640*x**13 + + 232560*x**14 + 77520*x**15 + 19380*x**16 + 3420*x**17 + 380*x**18 + + 20*x**19) + assert factor(dep) == 20*(1 + x)**19 + + +def test_H15(): + assert simplify(Mul(*[x - r for r in solveset(x**3 + x**2 - 7)])) == x**3 + x**2 - 7 + + +def test_H16(): + assert factor(x**100 - 1) == ((x - 1)*(x + 1)*(x**2 + 1)*(x**4 - x**3 + + x**2 - x + 1)*(x**4 + x**3 + x**2 + x + 1)*(x**8 - x**6 + x**4 + - x**2 + 1)*(x**20 - x**15 + x**10 - x**5 + 1)*(x**20 + x**15 + x**10 + + x**5 + 1)*(x**40 - x**30 + x**20 - x**10 + 1)) + + +def test_H17(): + assert simplify(factor(expand(p1 * p2)) - p1*p2) == 0 + + +@XFAIL +def test_H18(): + # Factor over complex rationals. + test = factor(4*x**4 + 8*x**3 + 77*x**2 + 18*x + 153) + good = (2*x + 3*I)*(2*x - 3*I)*(x + 1 - 4*I)*(x + 1 + 4*I) + assert test == good + + +def test_H19(): + a = symbols('a') + # The idea is to let a**2 == 2, then solve 1/(a-1). Answer is a+1") + assert Poly(a - 1).invert(Poly(a**2 - 2)) == a + 1 + + +@XFAIL +def test_H20(): + raise NotImplementedError("let a**2==2; (x**3 + (a-2)*x**2 - " + + "(2*a+3)*x - 3*a) / (x**2-2) = (x**2 - 2*x - 3) / (x-a)") + + +@XFAIL +def test_H21(): + raise NotImplementedError("evaluate (b+c)**4 assuming b**3==2, c**2==3. 
\ + Answer is 2*b + 8*c + 18*b**2 + 12*b*c + 9") + + +def test_H22(): + assert factor(x**4 - 3*x**2 + 1, modulus=5) == (x - 2)**2 * (x + 2)**2 + + +def test_H23(): + f = x**11 + x + 1 + g = (x**2 + x + 1) * (x**9 - x**8 + x**6 - x**5 + x**3 - x**2 + 1) + assert factor(f, modulus=65537) == g + + +def test_H24(): + phi = AlgebraicNumber(S.GoldenRatio.expand(func=True), alias='phi') + assert factor(x**4 - 3*x**2 + 1, extension=phi) == \ + (x - phi)*(x + 1 - phi)*(x - 1 + phi)*(x + phi) + + +def test_H25(): + e = (x - 2*y**2 + 3*z**3) ** 20 + assert factor(expand(e)) == e + + +def test_H26(): + g = expand((sin(x) - 2*cos(y)**2 + 3*tan(z)**3)**20) + assert factor(g, expand=False) == (-sin(x) + 2*cos(y)**2 - 3*tan(z)**3)**20 + + +def test_H27(): + f = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5 + g = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z + h = -2*z*y**7 \ + *(6*x**9*y**9*z**3 + 10*x**7*z**6 + 17*y*x**5*z**12 + 40*y**7) \ + *(3*x**22 + 47*x**17*y**5*z**8 - 6*x**15*y**9*z**2 - 24*x*y**19*z**8 - 5) + assert factor(expand(f*g)) == h + + +@XFAIL +def test_H28(): + raise NotImplementedError("expand ((1 - c**2)**5 * (1 - s**2)**5 * " + + "(c**2 + s**2)**10) with c**2 + s**2 = 1. Answer is c**10*s**10.") + + +@XFAIL +def test_H29(): + assert factor(4*x**2 - 21*x*y + 20*y**2, modulus=3) == (x + y)*(x - y) + + +def test_H30(): + test = factor(x**3 + y**3, extension=sqrt(-3)) + answer = (x + y)*(x + y*(-R(1, 2) - sqrt(3)/2*I))*(x + y*(-R(1, 2) + sqrt(3)/2*I)) + assert answer == test + + +def test_H31(): + f = (x**2 + 2*x + 3)/(x**3 + 4*x**2 + 5*x + 2) + g = 2 / (x + 1)**2 - 2 / (x + 1) + 3 / (x + 2) + assert apart(f) == g + + +@XFAIL +def test_H32(): # issue 6558 + raise NotImplementedError("[A*B*C - (A*B*C)**(-1)]*A*C*B (product \ + of a non-commuting product and its inverse)") + + +def test_H33(): + A, B, C = symbols('A, B, C', commutative=False) + assert (Commutator(A, Commutator(B, C)) + + Commutator(B, Commutator(C, A)) + + Commutator(C, Commutator(A, B))).doit().expand() == 0 + + +# I. 
Trigonometry + +def test_I1(): + assert tan(pi*R(7, 10)) == -sqrt(1 + 2/sqrt(5)) + + +@XFAIL +def test_I2(): + assert sqrt((1 + cos(6))/2) == -cos(3) + + +def test_I3(): + assert cos(n*pi) + sin((4*n - 1)*pi/2) == (-1)**n - 1 + + +def test_I4(): + assert refine(cos(pi*cos(n*pi)) + sin(pi/2*cos(n*pi)), Q.integer(n)) == (-1)**n - 1 + + +@XFAIL +def test_I5(): + assert sin((n**5/5 + n**4/2 + n**3/3 - n/30) * pi) == 0 + + +@XFAIL +def test_I6(): + raise NotImplementedError("assuming -3*pi pi**E) + + +@XFAIL +def test_N2(): + x = symbols('x', real=True) + assert ask(x**4 - x + 1 > 0) is True + assert ask(x**4 - x + 1 > 1) is False + + +@XFAIL +def test_N3(): + x = symbols('x', real=True) + assert ask(And(Lt(-1, x), Lt(x, 1)), abs(x) < 1 ) + +@XFAIL +def test_N4(): + x, y = symbols('x y', real=True) + assert ask(2*x**2 > 2*y**2, (x > y) & (y > 0)) is True + + +@XFAIL +def test_N5(): + x, y, k = symbols('x y k', real=True) + assert ask(k*x**2 > k*y**2, (x > y) & (y > 0) & (k > 0)) is True + + +@slow +@XFAIL +def test_N6(): + x, y, k, n = symbols('x y k n', real=True) + assert ask(k*x**n > k*y**n, (x > y) & (y > 0) & (k > 0) & (n > 0)) is True + + +@XFAIL +def test_N7(): + x, y = symbols('x y', real=True) + assert ask(y > 0, (x > 1) & (y >= x - 1)) is True + + +@XFAIL +@slow +def test_N8(): + x, y, z = symbols('x y z', real=True) + assert ask(Eq(x, y) & Eq(y, z), + (x >= y) & (y >= z) & (z >= x)) + + +def test_N9(): + x = Symbol('x') + assert solveset(abs(x - 1) > 2, domain=S.Reals) == Union(Interval(-oo, -1, False, True), + Interval(3, oo, True)) + + +def test_N10(): + x = Symbol('x') + p = (x - 1)*(x - 2)*(x - 3)*(x - 4)*(x - 5) + assert solveset(expand(p) < 0, domain=S.Reals) == Union(Interval(-oo, 1, True, True), + Interval(2, 3, True, True), + Interval(4, 5, True, True)) + + +def test_N11(): + x = Symbol('x') + assert solveset(6/(x - 3) <= 3, domain=S.Reals) == Union(Interval(-oo, 3, True, True), Interval(5, oo)) + + +def test_N12(): + x = Symbol('x') + assert solveset(sqrt(x) < 2, domain=S.Reals) == Interval(0, 4, False, True) + + +def test_N13(): + x = Symbol('x') + assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals + + +@XFAIL +def test_N14(): + x = Symbol('x') + # Gives 'Union(Interval(Integer(0), Mul(Rational(1, 2), pi), false, true), + # Interval(Mul(Rational(1, 2), pi), Mul(Integer(2), pi), true, false))' + # which is not the correct answer, but the provided also seems wrong. 
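+ # (The exact solution set is all reals except x = pi/2 + 2*pi*n for integer n, which neither of those results captures.)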
+ assert solveset(sin(x) < 1, x, domain=S.Reals) == Union(Interval(-oo, pi/2, True, True), + Interval(pi/2, oo, True, True)) + + +def test_N15(): + r, t = symbols('r t') + # raises NotImplementedError: only univariate inequalities are supported + solveset(abs(2*r*(cos(t) - 1) + 1) <= 1, r, S.Reals) + + +def test_N16(): + r, t = symbols('r t') + solveset((r**2)*((cos(t) - 4)**2)*sin(t)**2 < 9, r, S.Reals) + + +@XFAIL +def test_N17(): + # currently only univariate inequalities are supported + assert solveset((x + y > 0, x - y < 0), (x, y)) == (abs(x) < y) + + +def test_O1(): + M = Matrix((1 + I, -2, 3*I)) + assert sqrt(expand(M.dot(M.H))) == sqrt(15) + + +def test_O2(): + assert Matrix((2, 2, -3)).cross(Matrix((1, 3, 1))) == Matrix([[11], + [-5], + [4]]) + +# The vector module has no way of representing vectors symbolically (without +# respect to a basis) +@XFAIL +def test_O3(): + # assert (va ^ vb) | (vc ^ vd) == -(va | vc)*(vb | vd) + (va | vd)*(vb | vc) + raise NotImplementedError("""The vector module has no way of representing + vectors symbolically (without respect to a basis)""") + +def test_O4(): + from sympy.vector import CoordSys3D, Del + N = CoordSys3D("N") + delop = Del() + i, j, k = N.base_vectors() + x, y, z = N.base_scalars() + F = i*(x*y*z) + j*((x*y*z)**2) + k*((y**2)*(z**3)) + assert delop.cross(F).doit() == (-2*x**2*y**2*z + 2*y*z**3)*i + x*y*j + (2*x*y**2*z**2 - x*z)*k + +@XFAIL +def test_O5(): + #assert grad|(f^g)-g|(grad^f)+f|(grad^g) == 0 + raise NotImplementedError("""The vector module has no way of representing + vectors symbolically (without respect to a basis)""") + +#testO8-O9 MISSING!! + + +def test_O10(): + L = [Matrix([2, 3, 5]), Matrix([3, 6, 2]), Matrix([8, 3, 6])] + assert GramSchmidt(L) == [Matrix([ + [2], + [3], + [5]]), + Matrix([ + [R(23, 19)], + [R(63, 19)], + [R(-47, 19)]]), + Matrix([ + [R(1692, 353)], + [R(-1551, 706)], + [R(-423, 706)]])] + + +def test_P1(): + assert Matrix(3, 3, lambda i, j: j - i).diagonal(-1) == Matrix( + 1, 2, [-1, -1]) + + +def test_P2(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + M.row_del(1) + M.col_del(2) + assert M == Matrix([[1, 2], + [7, 8]]) + + +def test_P3(): + A = Matrix([ + [11, 12, 13, 14], + [21, 22, 23, 24], + [31, 32, 33, 34], + [41, 42, 43, 44]]) + + A11 = A[0:3, 1:4] + A12 = A[(0, 1, 3), (2, 0, 3)] + A21 = A + A221 = -A[0:2, 2:4] + A222 = -A[(3, 0), (2, 1)] + A22 = BlockMatrix([[A221, A222]]).T + rows = [[-A11, A12], [A21, A22]] + raises(ValueError, lambda: BlockMatrix(rows)) + B = Matrix(rows) + assert B == Matrix([ + [-12, -13, -14, 13, 11, 14], + [-22, -23, -24, 23, 21, 24], + [-32, -33, -34, 43, 41, 44], + [11, 12, 13, 14, -13, -23], + [21, 22, 23, 24, -14, -24], + [31, 32, 33, 34, -43, -13], + [41, 42, 43, 44, -42, -12]]) + + +@XFAIL +def test_P4(): + raise NotImplementedError("Block matrix diagonalization not supported") + + +def test_P5(): + M = Matrix([[7, 11], + [3, 8]]) + assert M % 2 == Matrix([[1, 1], + [1, 0]]) + + +def test_P6(): + M = Matrix([[cos(x), sin(x)], + [-sin(x), cos(x)]]) + assert M.diff(x, 2) == Matrix([[-cos(x), -sin(x)], + [sin(x), -cos(x)]]) + + +def test_P7(): + M = Matrix([[x, y]])*( + z*Matrix([[1, 3, 5], + [2, 4, 6]]) + Matrix([[7, -9, 11], + [-8, 10, -12]])) + assert M == Matrix([[x*(z + 7) + y*(2*z - 8), x*(3*z - 9) + y*(4*z + 10), + x*(5*z + 11) + y*(6*z - 12)]]) + + +def test_P8(): + M = Matrix([[1, -2*I], + [-3*I, 4]]) + assert M.norm(ord=S.Infinity) == 7 + + +def test_P9(): + a, b, c = symbols('a b c', nonzero=True) + M = Matrix([[a/(b*c), 1/c, 1/b], + [1/c, b/(a*c), 
1/a], + [1/b, 1/a, c/(a*b)]]) + assert factor(M.norm('fro')) == (a**2 + b**2 + c**2)/(abs(a)*abs(b)*abs(c)) + + +@XFAIL +def test_P10(): + M = Matrix([[1, 2 + 3*I], + [f(4 - 5*I), 6]]) + # conjugate(f(4 - 5*i)) is not simplified to f(4+5*I) + assert M.H == Matrix([[1, f(4 + 5*I)], + [2 + 3*I, 6]]) + + +@XFAIL +def test_P11(): + # raises NotImplementedError("Matrix([[x,y],[1,x*y]]).inv() + # not simplifying to extract common factor") + assert Matrix([[x, y], + [1, x*y]]).inv() == (1/(x**2 - 1))*Matrix([[x, -1], + [-1/y, x/y]]) + + +def test_P11_workaround(): + # This test was changed to inverse method ADJ because it depended on the + # specific form of inverse returned from the 'GE' method which has changed. + M = Matrix([[x, y], [1, x*y]]).inv('ADJ') + c = gcd(tuple(M)) + assert MatMul(c, M/c, evaluate=False) == MatMul(c, Matrix([ + [x*y, -y], + [ -1, x]]), evaluate=False) + + +def test_P12(): + A11 = MatrixSymbol('A11', n, n) + A12 = MatrixSymbol('A12', n, n) + A22 = MatrixSymbol('A22', n, n) + B = BlockMatrix([[A11, A12], + [ZeroMatrix(n, n), A22]]) + assert block_collapse(B.I) == BlockMatrix([[A11.I, (-1)*A11.I*A12*A22.I], + [ZeroMatrix(n, n), A22.I]]) + + +def test_P13(): + M = Matrix([[1, x - 2, x - 3], + [x - 1, x**2 - 3*x + 6, x**2 - 3*x - 2], + [x - 2, x**2 - 8, 2*(x**2) - 12*x + 14]]) + L, U, _ = M.LUdecomposition() + assert simplify(L) == Matrix([[1, 0, 0], + [x - 1, 1, 0], + [x - 2, x - 3, 1]]) + assert simplify(U) == Matrix([[1, x - 2, x - 3], + [0, 4, x - 5], + [0, 0, x - 7]]) + + +def test_P14(): + M = Matrix([[1, 2, 3, 1, 3], + [3, 2, 1, 1, 7], + [0, 2, 4, 1, 1], + [1, 1, 1, 1, 4]]) + R, _ = M.rref() + assert R == Matrix([[1, 0, -1, 0, 2], + [0, 1, 2, 0, -1], + [0, 0, 0, 1, 3], + [0, 0, 0, 0, 0]]) + + +def test_P15(): + M = Matrix([[-1, 3, 7, -5], + [4, -2, 1, 3], + [2, 4, 15, -7]]) + assert M.rank() == 2 + + +def test_P16(): + M = Matrix([[2*sqrt(2), 8], + [6*sqrt(6), 24*sqrt(3)]]) + assert M.rank() == 1 + + +def test_P17(): + t = symbols('t', real=True) + M=Matrix([ + [sin(2*t), cos(2*t)], + [2*(1 - (cos(t)**2))*cos(t), (1 - 2*(sin(t)**2))*sin(t)]]) + assert M.rank() == 1 + + +def test_P18(): + M = Matrix([[1, 0, -2, 0], + [-2, 1, 0, 3], + [-1, 2, -6, 6]]) + assert M.nullspace() == [Matrix([[2], + [4], + [1], + [0]]), + Matrix([[0], + [-3], + [0], + [1]])] + + +def test_P19(): + w = symbols('w') + M = Matrix([[1, 1, 1, 1], + [w, x, y, z], + [w**2, x**2, y**2, z**2], + [w**3, x**3, y**3, z**3]]) + assert M.det() == (w**3*x**2*y - w**3*x**2*z - w**3*x*y**2 + w**3*x*z**2 + + w**3*y**2*z - w**3*y*z**2 - w**2*x**3*y + w**2*x**3*z + + w**2*x*y**3 - w**2*x*z**3 - w**2*y**3*z + w**2*y*z**3 + + w*x**3*y**2 - w*x**3*z**2 - w*x**2*y**3 + w*x**2*z**3 + + w*y**3*z**2 - w*y**2*z**3 - x**3*y**2*z + x**3*y*z**2 + + x**2*y**3*z - x**2*y*z**3 - x*y**3*z**2 + x*y**2*z**3 + ) + + +@XFAIL +def test_P20(): + raise NotImplementedError("Matrix minimal polynomial not supported") + + +def test_P21(): + M = Matrix([[5, -3, -7], + [-2, 1, 2], + [2, -3, -4]]) + assert M.charpoly(x).as_expr() == x**3 - 2*x**2 - 5*x + 6 + + +def test_P22(): + d = 100 + M = (2 - x)*eye(d) + assert M.eigenvals() == {-x + 2: d} + + +def test_P23(): + M = Matrix([ + [2, 1, 0, 0, 0], + [1, 2, 1, 0, 0], + [0, 1, 2, 1, 0], + [0, 0, 1, 2, 1], + [0, 0, 0, 1, 2]]) + assert M.eigenvals() == { + S('1'): 1, + S('2'): 1, + S('3'): 1, + S('sqrt(3) + 2'): 1, + S('-sqrt(3) + 2'): 1} + + +def test_P24(): + M = Matrix([[611, 196, -192, 407, -8, -52, -49, 29], + [196, 899, 113, -192, -71, -43, -8, -44], + [-192, 113, 899, 196, 61, 49, 8, 
52], + [ 407, -192, 196, 611, 8, 44, 59, -23], + [ -8, -71, 61, 8, 411, -599, 208, 208], + [ -52, -43, 49, 44, -599, 411, 208, 208], + [ -49, -8, 8, 59, 208, 208, 99, -911], + [ 29, -44, 52, -23, 208, 208, -911, 99]]) + assert M.eigenvals() == { + S('0'): 1, + S('10*sqrt(10405)'): 1, + S('100*sqrt(26) + 510'): 1, + S('1000'): 2, + S('-100*sqrt(26) + 510'): 1, + S('-10*sqrt(10405)'): 1, + S('1020'): 1} + + +def test_P25(): + MF = N(Matrix([[ 611, 196, -192, 407, -8, -52, -49, 29], + [ 196, 899, 113, -192, -71, -43, -8, -44], + [-192, 113, 899, 196, 61, 49, 8, 52], + [ 407, -192, 196, 611, 8, 44, 59, -23], + [ -8, -71, 61, 8, 411, -599, 208, 208], + [ -52, -43, 49, 44, -599, 411, 208, 208], + [ -49, -8, 8, 59, 208, 208, 99, -911], + [ 29, -44, 52, -23, 208, 208, -911, 99]])) + + ev_1 = sorted(MF.eigenvals(multiple=True)) + ev_2 = sorted( + [-1020.0490184299969, 0.0, 0.09804864072151699, 1000.0, 1000.0, + 1019.9019513592784, 1020.0, 1020.0490184299969]) + + for x, y in zip(ev_1, ev_2): + assert abs(x - y) < 1e-12 + + +def test_P26(): + a0, a1, a2, a3, a4 = symbols('a0 a1 a2 a3 a4') + M = Matrix([[-a4, -a3, -a2, -a1, -a0, 0, 0, 0, 0], + [ 1, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 1, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 1, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 1, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, -1, -1, 0, 0], + [ 0, 0, 0, 0, 0, 1, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 1, -1, -1], + [ 0, 0, 0, 0, 0, 0, 0, 1, 0]]) + assert M.eigenvals(error_when_incomplete=False) == { + S('-1/2 - sqrt(3)*I/2'): 2, + S('-1/2 + sqrt(3)*I/2'): 2} + + +def test_P27(): + a = symbols('a') + M = Matrix([[a, 0, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 0, a, 0, 0], + [0, 0, 0, a, 0], + [0, -2, 0, 0, 2]]) + + assert M.eigenvects() == [ + (a, 3, [ + Matrix([1, 0, 0, 0, 0]), + Matrix([0, 0, 1, 0, 0]), + Matrix([0, 0, 0, 1, 0]) + ]), + (1 - I, 1, [ + Matrix([0, (1 + I)/2, 0, 0, 1]) + ]), + (1 + I, 1, [ + Matrix([0, (1 - I)/2, 0, 0, 1]) + ]), + ] + + +@XFAIL +def test_P28(): + raise NotImplementedError("Generalized eigenvectors not supported \ +https://github.com/sympy/sympy/issues/5293") + + +@XFAIL +def test_P29(): + raise NotImplementedError("Generalized eigenvectors not supported \ +https://github.com/sympy/sympy/issues/5293") + + +def test_P30(): + M = Matrix([[1, 0, 0, 1, -1], + [0, 1, -2, 3, -3], + [0, 0, -1, 2, -2], + [1, -1, 1, 0, 1], + [1, -1, 1, -1, 2]]) + _, J = M.jordan_form() + assert J == Matrix([[-1, 0, 0, 0, 0], + [0, 1, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 1], + [0, 0, 0, 0, 1]]) + + +@XFAIL +def test_P31(): + raise NotImplementedError("Smith normal form not implemented") + + +def test_P32(): + M = Matrix([[1, -2], + [2, 1]]) + assert exp(M).rewrite(cos).simplify() == Matrix([[E*cos(2), -E*sin(2)], + [E*sin(2), E*cos(2)]]) + + +def test_P33(): + w, t = symbols('w t') + M = Matrix([[0, 1, 0, 0], + [0, 0, 0, 2*w], + [0, 0, 0, 1], + [0, -2*w, 3*w**2, 0]]) + assert exp(M*t).rewrite(cos).expand() == Matrix([ + [1, -3*t + 4*sin(t*w)/w, 6*t*w - 6*sin(t*w), -2*cos(t*w)/w + 2/w], + [0, 4*cos(t*w) - 3, -6*w*cos(t*w) + 6*w, 2*sin(t*w)], + [0, 2*cos(t*w)/w - 2/w, -3*cos(t*w) + 4, sin(t*w)/w], + [0, -2*sin(t*w), 3*w*sin(t*w), cos(t*w)]]) + + +@XFAIL +def test_P34(): + a, b, c = symbols('a b c', real=True) + M = Matrix([[a, 1, 0, 0, 0, 0], + [0, a, 0, 0, 0, 0], + [0, 0, b, 0, 0, 0], + [0, 0, 0, c, 1, 0], + [0, 0, 0, 0, c, 1], + [0, 0, 0, 0, 0, c]]) + # raises exception, sin(M) not supported. 
exp(M*I) also not supported + # https://github.com/sympy/sympy/issues/6218 + assert sin(M) == Matrix([[sin(a), cos(a), 0, 0, 0, 0], + [0, sin(a), 0, 0, 0, 0], + [0, 0, sin(b), 0, 0, 0], + [0, 0, 0, sin(c), cos(c), -sin(c)/2], + [0, 0, 0, 0, sin(c), cos(c)], + [0, 0, 0, 0, 0, sin(c)]]) + + +@XFAIL +def test_P35(): + M = pi/2*Matrix([[2, 1, 1], + [2, 3, 2], + [1, 1, 2]]) + # raises exception, sin(M) not supported. exp(M*I) also not supported + # https://github.com/sympy/sympy/issues/6218 + assert sin(M) == eye(3) + + +@XFAIL +def test_P36(): + M = Matrix([[10, 7], + [7, 17]]) + assert sqrt(M) == Matrix([[3, 1], + [1, 4]]) + + +def test_P37(): + M = Matrix([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + assert M**S.Half == Matrix([[1, R(1, 2), 0], + [0, 1, 0], + [0, 0, 1]]) + + +@XFAIL +def test_P38(): + M=Matrix([[0, 1, 0], + [0, 0, 0], + [0, 0, 0]]) + + with raises(AssertionError): + # raises ValueError: Matrix det == 0; not invertible + M**S.Half + # if it doesn't raise then this assertion will be + # raised and the test will be flagged as not XFAILing + assert None + +@XFAIL +def test_P39(): + """ + M=Matrix([ + [1, 1], + [2, 2], + [3, 3]]) + M.SVD() + """ + raise NotImplementedError("Singular value decomposition not implemented") + + +def test_P40(): + r, t = symbols('r t', real=True) + M = Matrix([r*cos(t), r*sin(t)]) + assert M.jacobian(Matrix([r, t])) == Matrix([[cos(t), -r*sin(t)], + [sin(t), r*cos(t)]]) + + +def test_P41(): + r, t = symbols('r t', real=True) + assert hessian(r**2*sin(t),(r,t)) == Matrix([[ 2*sin(t), 2*r*cos(t)], + [2*r*cos(t), -r**2*sin(t)]]) + + +def test_P42(): + assert wronskian([cos(x), sin(x)], x).simplify() == 1 + + +def test_P43(): + def __my_jacobian(M, Y): + return Matrix([M.diff(v).T for v in Y]).T + r, t = symbols('r t', real=True) + M = Matrix([r*cos(t), r*sin(t)]) + assert __my_jacobian(M,[r,t]) == Matrix([[cos(t), -r*sin(t)], + [sin(t), r*cos(t)]]) + + +def test_P44(): + def __my_hessian(f, Y): + V = Matrix([diff(f, v) for v in Y]) + return Matrix([V.T.diff(v) for v in Y]) + r, t = symbols('r t', real=True) + assert __my_hessian(r**2*sin(t), (r, t)) == Matrix([ + [ 2*sin(t), 2*r*cos(t)], + [2*r*cos(t), -r**2*sin(t)]]) + + +def test_P45(): + def __my_wronskian(Y, v): + M = Matrix([Matrix(Y).T.diff(x, n) for n in range(0, len(Y))]) + return M.det() + assert __my_wronskian([cos(x), sin(x)], x).simplify() == 1 + +# Q1-Q6 Tensor tests missing + + +@XFAIL +def test_R1(): + i, j, n = symbols('i j n', integer=True, positive=True) + xn = MatrixSymbol('xn', n, 1) + Sm = Sum((xn[i, 0] - Sum(xn[j, 0], (j, 0, n - 1))/n)**2, (i, 0, n - 1)) + # sum does not calculate + # Unknown result + Sm.doit() + raise NotImplementedError('Unknown result') + +@XFAIL +def test_R2(): + m, b = symbols('m b') + i, n = symbols('i n', integer=True, positive=True) + xn = MatrixSymbol('xn', n, 1) + yn = MatrixSymbol('yn', n, 1) + f = Sum((yn[i, 0] - m*xn[i, 0] - b)**2, (i, 0, n - 1)) + f1 = diff(f, m) + f2 = diff(f, b) + # raises TypeError: solveset() takes at most 2 arguments (3 given) + solveset((f1, f2), (m, b), domain=S.Reals) + + +@XFAIL +def test_R3(): + n, k = symbols('n k', integer=True, positive=True) + sk = ((-1)**k) * (binomial(2*n, k))**2 + Sm = Sum(sk, (k, 1, oo)) + T = Sm.doit() + T2 = T.combsimp() + # returns -((-1)**n*factorial(2*n) + # - (factorial(n))**2)*exp_polar(-I*pi)/(factorial(n))**2 + assert T2 == (-1)**n*binomial(2*n, n) + + +@XFAIL +def test_R4(): +# Macsyma indefinite sum test case: +#(c15) /* Check whether the full Gosper algorithm is implemented +# => 1/2^(n + 1) 
binomial(n, k - 1) */ +#closedform(indefsum(binomial(n, k)/2^n - binomial(n + 1, k)/2^(n + 1), k)); +#Time= 2690 msecs +# (- n + k - 1) binomial(n + 1, k) +#(d15) - -------------------------------- +# n +# 2 2 (n + 1) +# +#(c16) factcomb(makefact(%)); +#Time= 220 msecs +# n! +#(d16) ---------------- +# n +# 2 k! 2 (n - k)! +# Might be possible after fixing https://github.com/sympy/sympy/pull/1879 + raise NotImplementedError("Indefinite sum not supported") + + +@XFAIL +def test_R5(): + a, b, c, n, k = symbols('a b c n k', integer=True, positive=True) + sk = ((-1)**k)*(binomial(a + b, a + k) + *binomial(b + c, b + k)*binomial(c + a, c + k)) + Sm = Sum(sk, (k, 1, oo)) + T = Sm.doit() # hypergeometric series not calculated + assert T == factorial(a+b+c)/(factorial(a)*factorial(b)*factorial(c)) + + +def test_R6(): + n, k = symbols('n k', integer=True, positive=True) + gn = MatrixSymbol('gn', n + 2, 1) + Sm = Sum(gn[k, 0] - gn[k - 1, 0], (k, 1, n + 1)) + assert Sm.doit() == -gn[0, 0] + gn[n + 1, 0] + + +def test_R7(): + n, k = symbols('n k', integer=True, positive=True) + T = Sum(k**3,(k,1,n)).doit() + assert T.factor() == n**2*(n + 1)**2/4 + +@XFAIL +def test_R8(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(k**2*binomial(n, k), (k, 1, n)) + T = Sm.doit() #returns Piecewise function + assert T.combsimp() == n*(n + 1)*2**(n - 2) + + +def test_R9(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(binomial(n, k - 1)/k, (k, 1, n + 1)) + assert Sm.doit().simplify() == (2**(n + 1) - 1)/(n + 1) + + +@XFAIL +def test_R10(): + n, m, r, k = symbols('n m r k', integer=True, positive=True) + Sm = Sum(binomial(n, k)*binomial(m, r - k), (k, 0, r)) + T = Sm.doit() + T2 = T.combsimp().rewrite(factorial) + assert T2 == factorial(m + n)/(factorial(r)*factorial(m + n - r)) + assert T2 == binomial(m + n, r).rewrite(factorial) + # rewrite(binomial) is not working. 
+ # https://github.com/sympy/sympy/issues/7135 + T3 = T2.rewrite(binomial) + assert T3 == binomial(m + n, r) + + +@XFAIL +def test_R11(): + n, k = symbols('n k', integer=True, positive=True) + sk = binomial(n, k)*fibonacci(k) + Sm = Sum(sk, (k, 0, n)) + T = Sm.doit() + # Fibonacci simplification not implemented + # https://github.com/sympy/sympy/issues/7134 + assert T == fibonacci(2*n) + + +@XFAIL +def test_R12(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(fibonacci(k)**2, (k, 0, n)) + T = Sm.doit() + assert T == fibonacci(n)*fibonacci(n + 1) + + +@XFAIL +def test_R13(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(sin(k*x), (k, 1, n)) + T = Sm.doit() # Sum is not calculated + assert T.simplify() == cot(x/2)/2 - cos(x*(2*n + 1)/2)/(2*sin(x/2)) + + +@XFAIL +def test_R14(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(sin((2*k - 1)*x), (k, 1, n)) + T = Sm.doit() # Sum is not calculated + assert T.simplify() == sin(n*x)**2/sin(x) + + +@XFAIL +def test_R15(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(binomial(n - k, k), (k, 0, floor(n/2))) + T = Sm.doit() # Sum is not calculated + assert T.simplify() == fibonacci(n + 1) + + +def test_R16(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/k**2 + 1/k**3, (k, 1, oo)) + assert Sm.doit() == zeta(3) + pi**2/6 + + +def test_R17(): + k = symbols('k', integer=True, positive=True) + assert abs(float(Sum(1/k**2 + 1/k**3, (k, 1, oo))) + - 2.8469909700078206) < 1e-15 + + +def test_R18(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/(2**k*k**2), (k, 1, oo)) + T = Sm.doit() + assert T.simplify() == -log(2)**2/2 + pi**2/12 + + +@slow +@XFAIL +def test_R19(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/((3*k + 1)*(3*k + 2)*(3*k + 3)), (k, 0, oo)) + T = Sm.doit() + # assert fails, T not simplified + assert T.simplify() == -log(3)/4 + sqrt(3)*pi/12 + + +@XFAIL +def test_R20(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(binomial(n, 4*k), (k, 0, oo)) + T = Sm.doit() + # assert fails, T not simplified + assert T.simplify() == 2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2 + + +@XFAIL +def test_R21(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/(sqrt(k*(k + 1)) * (sqrt(k) + sqrt(k + 1))), (k, 1, oo)) + T = Sm.doit() # Sum not calculated + assert T.simplify() == 1 + + +# test_R22 answer not available in Wester samples +# Sum(Sum(binomial(n, k)*binomial(n - k, n - 2*k)*x**n*y**(n - 2*k), +# (k, 0, floor(n/2))), (n, 0, oo)) with abs(x*y)<1? + + +@XFAIL +def test_R23(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(Sum((factorial(n)/(factorial(k)**2*factorial(n - 2*k)))* + (x/y)**k*(x*y)**(n - k), (n, 2*k, oo)), (k, 0, oo)) + # Missing how to express constraint abs(x*y)<1? 
+ T = Sm.doit() # Sum not calculated + assert T == -1/sqrt(x**2*y**2 - 4*x**2 - 2*x*y + 1) + + +def test_R24(): + m, k = symbols('m k', integer=True, positive=True) + Sm = Sum(Product(k/(2*k - 1), (k, 1, m)), (m, 2, oo)) + assert Sm.doit() == pi/2 + + +def test_S1(): + k = symbols('k', integer=True, positive=True) + Pr = Product(gamma(k/3), (k, 1, 8)) + assert Pr.doit().simplify() == 640*sqrt(3)*pi**3/6561 + + +def test_S2(): + n, k = symbols('n k', integer=True, positive=True) + assert Product(k, (k, 1, n)).doit() == factorial(n) + + +def test_S3(): + n, k = symbols('n k', integer=True, positive=True) + assert Product(x**k, (k, 1, n)).doit().simplify() == x**(n*(n + 1)/2) + + +def test_S4(): + n, k = symbols('n k', integer=True, positive=True) + assert Product(1 + 1/k, (k, 1, n -1)).doit().simplify() == n + + +def test_S5(): + n, k = symbols('n k', integer=True, positive=True) + assert (Product((2*k - 1)/(2*k), (k, 1, n)).doit().gammasimp() == + gamma(n + S.Half)/(sqrt(pi)*gamma(n + 1))) + + +@XFAIL +def test_S6(): + n, k = symbols('n k', integer=True, positive=True) + # Product does not evaluate + assert (Product(x**2 -2*x*cos(k*pi/n) + 1, (k, 1, n - 1)).doit().simplify() + == (x**(2*n) - 1)/(x**2 - 1)) + + +@XFAIL +def test_S7(): + k = symbols('k', integer=True, positive=True) + Pr = Product((k**3 - 1)/(k**3 + 1), (k, 2, oo)) + T = Pr.doit() # Product does not evaluate + assert T.simplify() == R(2, 3) + + +@XFAIL +def test_S8(): + k = symbols('k', integer=True, positive=True) + Pr = Product(1 - 1/(2*k)**2, (k, 1, oo)) + T = Pr.doit() + # Product does not evaluate + assert T.simplify() == 2/pi + + +@XFAIL +def test_S9(): + k = symbols('k', integer=True, positive=True) + Pr = Product(1 + (-1)**(k + 1)/(2*k - 1), (k, 1, oo)) + T = Pr.doit() + # Product produces 0 + # https://github.com/sympy/sympy/issues/7133 + assert T.simplify() == sqrt(2) + + +@XFAIL +def test_S10(): + k = symbols('k', integer=True, positive=True) + Pr = Product((k*(k + 1) + 1 + I)/(k*(k + 1) + 1 - I), (k, 0, oo)) + T = Pr.doit() + # Product does not evaluate + assert T.simplify() == -1 + + +def test_T1(): + assert limit((1 + 1/n)**n, n, oo) == E + assert limit((1 - cos(x))/x**2, x, 0) == S.Half + + +def test_T2(): + assert limit((3**x + 5**x)**(1/x), x, oo) == 5 + + +def test_T3(): + assert limit(log(x)/(log(x) + sin(x)), x, oo) == 1 + + +def test_T4(): + assert limit((exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1)))) + - exp(x))/x, x, oo) == -exp(2) + + +def test_T5(): + assert limit(x*log(x)*log(x*exp(x) - x**2)**2/log(log(x**2 + + 2*exp(exp(3*x**3*log(x))))), x, oo) == R(1, 3) + + +def test_T6(): + assert limit(1/n * factorial(n)**(1/n), n, oo) == exp(-1) + + +def test_T7(): + limit(1/n * gamma(n + 1)**(1/n), n, oo) + + +def test_T8(): + a, z = symbols('a z', positive=True) + assert limit(gamma(z + a)/gamma(z)*exp(-a*log(z)), z, oo) == 1 + + +@XFAIL +def test_T9(): + z, k = symbols('z k', positive=True) + # raises NotImplementedError: + # Don't know how to calculate the mrv of '(1, k)' + assert limit(hyper((1, k), (1,), z/k), k, oo) == exp(z) + + +@XFAIL +def test_T10(): + # No longer raises PoleError, but should return euler-mascheroni constant + assert limit(zeta(x) - 1/(x - 1), x, 1) == integrate(-1/x + 1/floor(x), (x, 1, oo)) + +@XFAIL +def test_T11(): + n, k = symbols('n k', integer=True, positive=True) + # evaluates to 0 + assert limit(n**x/(x*product((1 + x/k), (k, 1, n))), n, oo) == gamma(x) + + +def test_T12(): + x, t = symbols('x t', real=True) + # Does not evaluate the limit but returns an expression with 
erf + assert limit(x * integrate(exp(-t**2), (t, 0, x))/(1 - exp(-x**2)), + x, 0) == 1 + + +def test_T13(): + x = symbols('x', real=True) + assert [limit(x/abs(x), x, 0, dir='-'), + limit(x/abs(x), x, 0, dir='+')] == [-1, 1] + + +def test_T14(): + x = symbols('x', real=True) + assert limit(atan(-log(x)), x, 0, dir='+') == pi/2 + + +def test_U1(): + x = symbols('x', real=True) + assert diff(abs(x), x) == sign(x) + + +def test_U2(): + f = Lambda(x, Piecewise((-x, x < 0), (x, x >= 0))) + assert diff(f(x), x) == Piecewise((-1, x < 0), (1, x >= 0)) + + +def test_U3(): + f = Lambda(x, Piecewise((x**2 - 1, x == 1), (x**3, x != 1))) + f1 = Lambda(x, diff(f(x), x)) + assert f1(x) == 3*x**2 + assert f1(1) == 3 + + +@XFAIL +def test_U4(): + n = symbols('n', integer=True, positive=True) + x = symbols('x', real=True) + d = diff(x**n, x, n) + assert d.rewrite(factorial) == factorial(n) + + +def test_U5(): + # issue 6681 + t = symbols('t') + ans = ( + Derivative(f(g(t)), g(t))*Derivative(g(t), (t, 2)) + + Derivative(f(g(t)), (g(t), 2))*Derivative(g(t), t)**2) + assert f(g(t)).diff(t, 2) == ans + assert ans.doit() == ans + + +def test_U6(): + h = Function('h') + T = integrate(f(y), (y, h(x), g(x))) + assert T.diff(x) == ( + f(g(x))*Derivative(g(x), x) - f(h(x))*Derivative(h(x), x)) + + +@XFAIL +def test_U7(): + p, t = symbols('p t', real=True) + # Exact differential => d(V(P, T)) => dV/dP DP + dV/dT DT + # raises ValueError: Since there is more than one variable in the + # expression, the variable(s) of differentiation must be supplied to + # differentiate f(p,t) + diff(f(p, t)) + + +def test_U8(): + x, y = symbols('x y', real=True) + eq = cos(x*y) + x + # If SymPy had implicit_diff() function this hack could be avoided + # TODO: Replace solve with solveset, current test fails for solveset + assert idiff(y - eq, y, x) == (-y*sin(x*y) + 1)/(x*sin(x*y) + 1) + + +def test_U9(): + # Wester sample case for Maple: + # O29 := diff(f(x, y), x) + diff(f(x, y), y); + # /d \ /d \ + # |-- f(x, y)| + |-- f(x, y)| + # \dx / \dy / + # + # O30 := factor(subs(f(x, y) = g(x^2 + y^2), %)); + # 2 2 + # 2 D(g)(x + y ) (x + y) + x, y = symbols('x y', real=True) + su = diff(f(x, y), x) + diff(f(x, y), y) + s2 = su.subs(f(x, y), g(x**2 + y**2)) + s3 = s2.doit().factor() + # Subs not performed, s3 = 2*(x + y)*Subs(Derivative( + # g(_xi_1), _xi_1), _xi_1, x**2 + y**2) + # Derivative(g(x*2 + y**2), x**2 + y**2) is not valid in SymPy, + # and probably will remain that way. You can take derivatives with respect + # to other expressions only if they are atomic, like a symbol or a + # function. + # D operator should be added to SymPy + # See https://github.com/sympy/sympy/issues/4719. 
+ assert s3 == (x + y)*Subs(Derivative(g(x), x), x, x**2 + y**2)*2 + + +def test_U10(): + # see issue 2519: + assert residue((z**3 + 5)/((z**4 - 1)*(z + 1)), z, -1) == R(-9, 4) + +@XFAIL +def test_U11(): + # assert (2*dx + dz) ^ (3*dx + dy + dz) ^ (dx + dy + 4*dz) == 8*dx ^ dy ^dz + raise NotImplementedError + + +@XFAIL +def test_U12(): + # Wester sample case: + # (c41) /* d(3 x^5 dy /\ dz + 5 x y^2 dz /\ dx + 8 z dx /\ dy) + # => (15 x^4 + 10 x y + 8) dx /\ dy /\ dz */ + # factor(ext_diff(3*x^5 * dy ~ dz + 5*x*y^2 * dz ~ dx + 8*z * dx ~ dy)); + # 4 + # (d41) (10 x y + 15 x + 8) dx dy dz + raise NotImplementedError( + "External diff of differential form not supported") + + +def test_U13(): + assert minimum(x**4 - x + 1, x) == -3*2**R(1,3)/8 + 1 + + +@XFAIL +def test_U14(): + #f = 1/(x**2 + y**2 + 1) + #assert [minimize(f), maximize(f)] == [0,1] + raise NotImplementedError("minimize(), maximize() not supported") + + +@XFAIL +def test_U15(): + raise NotImplementedError("minimize() not supported and also solve does \ +not support multivariate inequalities") + + +@XFAIL +def test_U16(): + raise NotImplementedError("minimize() not supported in SymPy and also \ +solve does not support multivariate inequalities") + + +@XFAIL +def test_U17(): + raise NotImplementedError("Linear programming, symbolic simplex not \ +supported in SymPy") + + +def test_V1(): + x = symbols('x', real=True) + assert integrate(abs(x), x) == Piecewise((-x**2/2, x <= 0), (x**2/2, True)) + + +def test_V2(): + assert integrate(Piecewise((-x, x < 0), (x, x >= 0)), x + ) == Piecewise((-x**2/2, x < 0), (x**2/2, True)) + + +def test_V3(): + assert integrate(1/(x**3 + 2),x).diff().simplify() == 1/(x**3 + 2) + + +def test_V4(): + assert integrate(2**x/sqrt(1 + 4**x), x) == asinh(2**x)/log(2) + + +@XFAIL +def test_V5(): + # Returns (-45*x**2 + 80*x - 41)/(5*sqrt(2*x - 1)*(4*x**2 - 4*x + 1)) + assert (integrate((3*x - 5)**2/(2*x - 1)**R(7, 2), x).simplify() == + (-41 + 80*x - 45*x**2)/(5*(2*x - 1)**R(5, 2))) + + +@XFAIL +def test_V6(): + # returns RootSum(40*_z**2 - 1, Lambda(_i, _i*log(-4*_i + exp(-m*x))))/m + assert (integrate(1/(2*exp(m*x) - 5*exp(-m*x)), x) == sqrt(10)*( + log(2*exp(m*x) - sqrt(10)) - log(2*exp(m*x) + sqrt(10)))/(20*m)) + + +def test_V7(): + r1 = integrate(sinh(x)**4/cosh(x)**2) + assert r1.simplify() == x*R(-3, 2) + sinh(x)**3/(2*cosh(x)) + 3*tanh(x)/2 + + +@XFAIL +def test_V8_V9(): +#Macsyma test case: +#(c27) /* This example involves several symbolic parameters +# => 1/sqrt(b^2 - a^2) log([sqrt(b^2 - a^2) tan(x/2) + a + b]/ +# [sqrt(b^2 - a^2) tan(x/2) - a - b]) (a^2 < b^2) +# [Gradshteyn and Ryzhik 2.553(3)] */ +#assume(b^2 > a^2)$ +#(c28) integrate(1/(a + b*cos(x)), x); +#(c29) trigsimp(ratsimp(diff(%, x))); +# 1 +#(d29) ------------ +# b cos(x) + a + raise NotImplementedError( + "Integrate with assumption not supported") + + +def test_V10(): + assert integrate(1/(3 + 3*cos(x) + 4*sin(x)), x) == log(4*tan(x/2) + 3)/4 + + +def test_V11(): + r1 = integrate(1/(4 + 3*cos(x) + 4*sin(x)), x) + r2 = factor(r1) + assert (logcombine(r2, force=True) == + log(((tan(x/2) + 1)/(tan(x/2) + 7))**R(1, 3))) + + +def test_V12(): + r1 = integrate(1/(5 + 3*cos(x) + 4*sin(x)), x) + assert r1 == -1/(tan(x/2) + 2) + + +@XFAIL +def test_V13(): + r1 = integrate(1/(6 + 3*cos(x) + 4*sin(x)), x) + # expression not simplified, returns: -sqrt(11)*I*log(tan(x/2) + 4/3 + # - sqrt(11)*I/3)/11 + sqrt(11)*I*log(tan(x/2) + 4/3 + sqrt(11)*I/3)/11 + assert r1.simplify() == 2*sqrt(11)*atan(sqrt(11)*(3*tan(x/2) + 4)/11)/11 + + +@slow +@XFAIL +def 
test_V14(): + r1 = integrate(log(abs(x**2 - y**2)), x) + # Piecewise result does not simplify to the desired result. + assert (r1.simplify() == x*log(abs(x**2 - y**2)) + + y*log(x + y) - y*log(x - y) - 2*x) + + +def test_V15(): + r1 = integrate(x*acot(x/y), x) + assert simplify(r1 - (x*y + (x**2 + y**2)*acot(x/y))/2) == 0 + + +@XFAIL +def test_V16(): + # Integral not calculated + assert integrate(cos(5*x)*Ci(2*x), x) == Ci(2*x)*sin(5*x)/5 - (Si(3*x) + Si(7*x))/10 + +@XFAIL +def test_V17(): + r1 = integrate((diff(f(x), x)*g(x) + - f(x)*diff(g(x), x))/(f(x)**2 - g(x)**2), x) + # integral not calculated + assert simplify(r1 - (f(x) - g(x))/(f(x) + g(x))/2) == 0 + + +@XFAIL +def test_W1(): + # The function has a pole at y. + # The integral has a Cauchy principal value of zero but SymPy returns -I*pi + # https://github.com/sympy/sympy/issues/7159 + assert integrate(1/(x - y), (x, y - 1, y + 1)) == 0 + + +@XFAIL +def test_W2(): + # The function has a pole at y. + # The integral is divergent but SymPy returns -2 + # https://github.com/sympy/sympy/issues/7160 + # Test case in Macsyma: + # (c6) errcatch(integrate(1/(x - a)^2, x, a - 1, a + 1)); + # Integral is divergent + assert integrate(1/(x - y)**2, (x, y - 1, y + 1)) is zoo + + +@XFAIL +@slow +def test_W3(): + # integral is not calculated + # https://github.com/sympy/sympy/issues/7161 + assert integrate(sqrt(x + 1/x - 2), (x, 0, 1)) == R(4, 3) + + +@XFAIL +@slow +def test_W4(): + # integral is not calculated + assert integrate(sqrt(x + 1/x - 2), (x, 1, 2)) == -2*sqrt(2)/3 + R(4, 3) + + +@XFAIL +@slow +def test_W5(): + # integral is not calculated + assert integrate(sqrt(x + 1/x - 2), (x, 0, 2)) == -2*sqrt(2)/3 + R(8, 3) + + +@XFAIL +@slow +def test_W6(): + # integral is not calculated + assert integrate(sqrt(2 - 2*cos(2*x))/2, (x, pi*R(-3, 4), -pi/4)) == sqrt(2) + + +def test_W7(): + a = symbols('a', positive=True) + r1 = integrate(cos(x)/(x**2 + a**2), (x, -oo, oo)) + assert r1.simplify() == pi*exp(-a)/a + + +@XFAIL +def test_W8(): + # Test case in Mathematica: + # In[19]:= Integrate[t^(a - 1)/(1 + t), {t, 0, Infinity}, + # Assumptions -> 0 < a < 1] + # Out[19]= Pi Csc[a Pi] + raise NotImplementedError( + "Integrate with assumption 0 < a < 1 not supported") + + +@XFAIL +@slow +def test_W9(): + # Integrand with a residue at infinity => -2 pi [sin(pi/5) + sin(2pi/5)] + # (principal value) [Levinson and Redheffer, p. 234] *) + r1 = integrate(5*x**3/(1 + x + x**2 + x**3 + x**4), (x, -oo, oo)) + r2 = r1.doit() + assert r2 == -2*pi*(sqrt(-sqrt(5)/8 + 5/8) + sqrt(sqrt(5)/8 + 5/8)) + + +@XFAIL +def test_W10(): + # integrate(1/[1 + x + x^2 + ... + x^(2 n)], x = -infinity..infinity) = + # 2 pi/(2 n + 1) [1 + cos(pi/[2 n + 1])] csc(2 pi/[2 n + 1]) + # [Levinson and Redheffer, p. 255] => 2 pi/5 [1 + cos(pi/5)] csc(2 pi/5) */ + r1 = integrate(x/(1 + x + x**2 + x**4), (x, -oo, oo)) + r2 = r1.doit() + assert r2 == 2*pi*(sqrt(5)/4 + 5/4)*csc(pi*R(2, 5))/5 + + +@XFAIL +def test_W11(): + # integral not calculated + assert (integrate(sqrt(1 - x**2)/(1 + x**2), (x, -1, 1)) == + pi*(-1 + sqrt(2))) + + +def test_W12(): + p = symbols('p', positive=True) + q = symbols('q', real=True) + r1 = integrate(x*exp(-p*x**2 + 2*q*x), (x, -oo, oo)) + assert r1.simplify() == sqrt(pi)*q*exp(q**2/p)/p**R(3, 2) + + +@XFAIL +def test_W13(): + # Integral not calculated. 
Expected result is 2*(Euler_mascheroni_constant) + r1 = integrate(1/log(x) + 1/(1 - x) - log(log(1/x)), (x, 0, 1)) + assert r1 == 2*EulerGamma + + +def test_W14(): + assert integrate(sin(x)/x*exp(2*I*x), (x, -oo, oo)) == 0 + + +@XFAIL +def test_W15(): + # integral not calculated + assert integrate(log(gamma(x))*cos(6*pi*x), (x, 0, 1)) == R(1, 12) + + +def test_W16(): + assert integrate((1 + x)**3*legendre_poly(1, x)*legendre_poly(2, x), + (x, -1, 1)) == R(36, 35) + + +def test_W17(): + a, b = symbols('a b', positive=True) + assert integrate(exp(-a*x)*besselj(0, b*x), + (x, 0, oo)) == 1/(b*sqrt(a**2/b**2 + 1)) + + +def test_W18(): + assert integrate((besselj(1, x)/x)**2, (x, 0, oo)) == 4/(3*pi) + + +@XFAIL +def test_W19(): + # Integral not calculated + # Expected result is (cos 7 - 1)/7 [Gradshteyn and Ryzhik 6.782(3)] + assert integrate(Ci(x)*besselj(0, 2*sqrt(7*x)), (x, 0, oo)) == (cos(7) - 1)/7 + + +@XFAIL +def test_W20(): + # integral not calculated + assert (integrate(x**2*polylog(3, 1/(x + 1)), (x, 0, 1)) == + -pi**2/36 - R(17, 108) + zeta(3)/4 + + (-pi**2/2 - 4*log(2) + log(2)**2 + 35/3)*log(2)/9) + + +def test_W21(): + assert abs(N(integrate(x**2*polylog(3, 1/(x + 1)), (x, 0, 1))) + - 0.210882859565594) < 1e-15 + + +def test_W22(): + t, u = symbols('t u', real=True) + s = Lambda(x, Piecewise((1, And(x >= 1, x <= 2)), (0, True))) + assert integrate(s(t)*cos(t), (t, 0, u)) == Piecewise( + (0, u < 0), + (-sin(Min(1, u)) + sin(Min(2, u)), True)) + + +@slow +def test_W23(): + a, b = symbols('a b', positive=True) + r1 = integrate(integrate(x/(x**2 + y**2), (x, a, b)), (y, -oo, oo)) + assert r1.collect(pi).cancel() == -pi*a + pi*b + + +def test_W23b(): + # like W23 but limits are reversed + a, b = symbols('a b', positive=True) + r2 = integrate(integrate(x/(x**2 + y**2), (y, -oo, oo)), (x, a, b)) + assert r2.collect(pi) == pi*(-a + b) + + +@XFAIL +@tooslow +def test_W24(): + # Not that slow, but does not fully evaluate so simplify is slow. 
+ # Maybe also require doit() + x, y = symbols('x y', real=True) + r1 = integrate(integrate(sqrt(x**2 + y**2), (x, 0, 1)), (y, 0, 1)) + assert (r1 - (sqrt(2) + asinh(1))/3).simplify() == 0 + + +@XFAIL +@tooslow +def test_W25(): + a, x, y = symbols('a x y', real=True) + i1 = integrate( + sin(a)*sin(y)/sqrt(1 - sin(a)**2*sin(x)**2*sin(y)**2), + (x, 0, pi/2)) + i2 = integrate(i1, (y, 0, pi/2)) + assert (i2 - pi*a/2).simplify() == 0 + + +def test_W26(): + x, y = symbols('x y', real=True) + assert integrate(integrate(abs(y - x**2), (y, 0, 2)), + (x, -1, 1)) == R(46, 15) + + +def test_W27(): + a, b, c = symbols('a b c') + assert integrate(integrate(integrate(1, (z, 0, c*(1 - x/a - y/b))), + (y, 0, b*(1 - x/a))), + (x, 0, a)) == a*b*c/6 + + +def test_X1(): + v, c = symbols('v c', real=True) + assert (series(1/sqrt(1 - (v/c)**2), v, x0=0, n=8) == + 5*v**6/(16*c**6) + 3*v**4/(8*c**4) + v**2/(2*c**2) + 1 + O(v**8)) + + +def test_X2(): + v, c = symbols('v c', real=True) + s1 = series(1/sqrt(1 - (v/c)**2), v, x0=0, n=8) + assert (1/s1**2).series(v, x0=0, n=8) == -v**2/c**2 + 1 + O(v**8) + + +def test_X3(): + s1 = (sin(x).series()/cos(x).series()).series() + s2 = tan(x).series() + assert s2 == x + x**3/3 + 2*x**5/15 + O(x**6) + assert s1 == s2 + + +def test_X4(): + s1 = log(sin(x)/x).series() + assert s1 == -x**2/6 - x**4/180 + O(x**6) + assert log(series(sin(x)/x)).series() == s1 + + +@XFAIL +def test_X5(): + # test case in Mathematica syntax: + # In[21]:= (* => [a f'(a d) + g(b d) + integrate(h(c y), y = 0..d)] + # + [a^2 f''(a d) + b g'(b d) + h(c d)] (x - d) *) + # In[22]:= D[f[a*x], x] + g[b*x] + Integrate[h[c*y], {y, 0, x}] + # Out[22]= g[b x] + Integrate[h[c y], {y, 0, x}] + a f'[a x] + # In[23]:= Series[%, {x, d, 1}] + # Out[23]= (g[b d] + Integrate[h[c y], {y, 0, d}] + a f'[a d]) + + # 2 2 + # (h[c d] + b g'[b d] + a f''[a d]) (-d + x) + O[-d + x] + h = Function('h') + a, b, c, d = symbols('a b c d', real=True) + # series() raises NotImplementedError: + # The _eval_nseries method should be added to to give terms up to O(x**n) at x=0 + series(diff(f(a*x), x) + g(b*x) + integrate(h(c*y), (y, 0, x)), + x, x0=d, n=2) + # assert missing, until exception is removed + + +def test_X6(): + # Taylor series of nonscalar objects (noncommutative multiplication) + # expected result => (B A - A B) t^2/2 + O(t^3) [Stanly Steinberg] + a, b = symbols('a b', commutative=False, scalar=False) + assert (series(exp((a + b)*x) - exp(a*x) * exp(b*x), x, x0=0, n=3) == + x**2*(-a*b/2 + b*a/2) + O(x**3)) + + +def test_X7(): + # => sum( Bernoulli[k]/k! x^(k - 2), k = 1..infinity ) + # = 1/x^2 - 1/(2 x) + 1/12 - x^2/720 + x^4/30240 + O(x^6) + # [Levinson and Redheffer, p. 
173] + assert (series(1/(x*(exp(x) - 1)), x, 0, 7) == x**(-2) - 1/(2*x) + + R(1, 12) - x**2/720 + x**4/30240 - x**6/1209600 + O(x**7)) + + +def test_X8(): + # Puiseux series (terms with fractional degree): + # => 1/sqrt(x - 3/2 pi) + (x - 3/2 pi)^(3/2) / 12 + O([x - 3/2 pi]^(7/2)) + + # see issue 7167: + x = symbols('x', real=True) + assert (series(sqrt(sec(x)), x, x0=pi*3/2, n=4) == + 1/sqrt(x - pi*R(3, 2)) + (x - pi*R(3, 2))**R(3, 2)/12 + + (x - pi*R(3, 2))**R(7, 2)/160 + O((x - pi*R(3, 2))**4, (x, pi*R(3, 2)))) + + +def test_X9(): + assert (series(x**x, x, x0=0, n=4) == 1 + x*log(x) + x**2*log(x)**2/2 + + x**3*log(x)**3/6 + O(x**4*log(x)**4)) + + +def test_X10(): + z, w = symbols('z w') + assert (series(log(sinh(z)) + log(cosh(z + w)), z, x0=0, n=2) == + log(cosh(w)) + log(z) + z*sinh(w)/cosh(w) + O(z**2)) + + +def test_X11(): + z, w = symbols('z w') + assert (series(log(sinh(z) * cosh(z + w)), z, x0=0, n=2) == + log(cosh(w)) + log(z) + z*sinh(w)/cosh(w) + O(z**2)) + + +@XFAIL +def test_X12(): + # Look at the generalized Taylor series around x = 1 + # Result => (x - 1)^a/e^b [1 - (a + 2 b) (x - 1) / 2 + O((x - 1)^2)] + a, b, x = symbols('a b x', real=True) + # series returns O(log(x-1)**2) + # https://github.com/sympy/sympy/issues/7168 + assert (series(log(x)**a*exp(-b*x), x, x0=1, n=2) == + (x - 1)**a/exp(b)*(1 - (a + 2*b)*(x - 1)/2 + O((x - 1)**2))) + + +def test_X13(): + assert series(sqrt(2*x**2 + 1), x, x0=oo, n=1) == sqrt(2)*x + O(1/x, (x, oo)) + + +@XFAIL +def test_X14(): + # Wallis' product => 1/sqrt(pi n) + ... [Knopp, p. 385] + assert series(1/2**(2*n)*binomial(2*n, n), + n, x==oo, n=1) == 1/(sqrt(pi)*sqrt(n)) + O(1/x, (x, oo)) + + +@SKIP("https://github.com/sympy/sympy/issues/7164") +def test_X15(): + # => 0!/x - 1!/x^2 + 2!/x^3 - 3!/x^4 + O(1/x^5) [Knopp, p. 544] + x, t = symbols('x t', real=True) + # raises RuntimeError: maximum recursion depth exceeded + # https://github.com/sympy/sympy/issues/7164 + # 2019-02-17: Raises + # PoleError: + # Asymptotic expansion of Ei around [-oo] is not implemented. + e1 = integrate(exp(-t)/t, (t, x, oo)) + assert (series(e1, x, x0=oo, n=5) == + 6/x**4 + 2/x**3 - 1/x**2 + 1/x + O(x**(-5), (x, oo))) + + +def test_X16(): + # Multivariate Taylor series expansion => 1 - (x^2 + 2 x y + y^2)/2 + O(x^4) + assert (series(cos(x + y), x + y, x0=0, n=4) == 1 - (x + y)**2/2 + + O(x**4 + x**3*y + x**2*y**2 + x*y**3 + y**4, x, y)) + + +@XFAIL +def test_X17(): + # Power series (compute the general formula) + # (c41) powerseries(log(sin(x)/x), x, 0); + # /aquarius/data2/opt/local/macsyma_422/library1/trgred.so being loaded. + # inf + # ==== i1 2 i1 2 i1 + # \ (- 1) 2 bern(2 i1) x + # (d41) > ------------------------------ + # / 2 i1 (2 i1)! + # ==== + # i1 = 1 + # fps does not calculate + assert fps(log(sin(x)/x)) == \ + Sum((-1)**k*2**(2*k - 1)*bernoulli(2*k)*x**(2*k)/(k*factorial(2*k)), (k, 1, oo)) + + +@XFAIL +def test_X18(): + # Power series (compute the general formula). Maple FPS: + # > FormalPowerSeries(exp(-x)*sin(x), x = 0); + # infinity + # ----- (1/2 k) k + # \ 2 sin(3/4 k Pi) x + # ) ------------------------- + # / k! + # ----- + # + # Now, SymPy returns + # oo + # _____ + # \ ` + # \ / k k\ + # \ k |I*(-1 - I) I*(-1 + I) | + # \ x *|----------- - -----------| + # / \ 2 2 / + # / ------------------------------ + # / k! 
+ # /____, + # k = 0 + k = Dummy('k') + assert fps(exp(-x)*sin(x)) == \ + Sum(2**(S.Half*k)*sin(R(3, 4)*k*pi)*x**k/factorial(k), (k, 0, oo)) + + +@XFAIL +def test_X19(): + # (c45) /* Derive an explicit Taylor series solution of y as a function of + # x from the following implicit relation: + # y = x - 1 + (x - 1)^2/2 + 2/3 (x - 1)^3 + (x - 1)^4 + + # 17/10 (x - 1)^5 + ... + # */ + # x = sin(y) + cos(y); + # Time= 0 msecs + # (d45) x = sin(y) + cos(y) + # + # (c46) taylor_revert(%, y, 7); + raise NotImplementedError("Solve using series not supported. \ +Inverse Taylor series expansion also not supported") + + +@XFAIL +def test_X20(): + # Pade (rational function) approximation => (2 - x)/(2 + x) + # > numapprox[pade](exp(-x), x = 0, [1, 1]); + # bytes used=9019816, alloc=3669344, time=13.12 + # 1 - 1/2 x + # --------- + # 1 + 1/2 x + # mpmath support numeric Pade approximant but there is + # no symbolic implementation in SymPy + # https://en.wikipedia.org/wiki/Pad%C3%A9_approximant + raise NotImplementedError("Symbolic Pade approximant not supported") + + +def test_X21(): + """ + Test whether `fourier_series` of x periodical on the [-p, p] interval equals + `- (2 p / pi) sum( (-1)^n / n sin(n pi x / p), n = 1..infinity )`. + """ + p = symbols('p', positive=True) + n = symbols('n', positive=True, integer=True) + s = fourier_series(x, (x, -p, p)) + + # All cosine coefficients are equal to 0 + assert s.an.formula == 0 + + # Check for sine coefficients + assert s.bn.formula.subs(s.bn.variables[0], 0) == 0 + assert s.bn.formula.subs(s.bn.variables[0], n) == \ + -2*p/pi * (-1)**n / n * sin(n*pi*x/p) + + +@XFAIL +def test_X22(): + # (c52) /* => p / 2 + # - (2 p / pi^2) sum( [1 - (-1)^n] cos(n pi x / p) / n^2, + # n = 1..infinity ) */ + # fourier_series(abs(x), x, p); + # p + # (e52) a = - + # 0 2 + # + # %nn + # (2 (- 1) - 2) p + # (e53) a = ------------------ + # %nn 2 2 + # %pi %nn + # + # (e54) b = 0 + # %nn + # + # Time= 5290 msecs + # inf %nn %pi %nn x + # ==== (2 (- 1) - 2) cos(---------) + # \ p + # p > ------------------------------- + # / 2 + # ==== %nn + # %nn = 1 p + # (d54) ----------------------------------------- + - + # 2 2 + # %pi + raise NotImplementedError("Fourier series not supported") + + +def test_Y1(): + t = symbols('t', positive=True) + w = symbols('w', real=True) + s = symbols('s') + F, _, _ = laplace_transform(cos((w - 1)*t), t, s) + assert F == s/(s**2 + (w - 1)**2) + + +def test_Y2(): + t = symbols('t', positive=True) + w = symbols('w', real=True) + s = symbols('s') + f = inverse_laplace_transform(s/(s**2 + (w - 1)**2), s, t, simplify=True) + assert f == cos(t*(w - 1)) + + +def test_Y3(): + t = symbols('t', positive=True) + w = symbols('w', real=True) + s = symbols('s') + F, _, _ = laplace_transform(sinh(w*t)*cosh(w*t), t, s, simplify=True) + assert F == w/(s**2 - 4*w**2) + + +def test_Y4(): + t = symbols('t', positive=True) + s = symbols('s') + F, _, _ = laplace_transform(erf(3/sqrt(t)), t, s, simplify=True) + assert F == 1/s - exp(-6*sqrt(s))/s + + +def test_Y5_Y6(): +# Solve y'' + y = 4 [H(t - 1) - H(t - 2)], y(0) = 1, y'(0) = 0 where H is the +# Heaviside (unit step) function (the RHS describes a pulse of magnitude 4 and +# duration 1). See David A. Sanchez, Richard C. Allen, Jr. and Walter T. +# Kyner, _Differential Equations: An Introduction_, Addison-Wesley Publishing +# Company, 1983, p. 211. 
First, take the Laplace transform of the ODE +# => s^2 Y(s) - s + Y(s) = 4/s [e^(-s) - e^(-2 s)] +# where Y(s) is the Laplace transform of y(t) + t = symbols('t', real=True) + s = symbols('s') + y = Function('y') + Y = Function('Y') + F = laplace_correspondence(laplace_transform(diff(y(t), t, 2) + y(t) + - 4*(Heaviside(t - 1) - Heaviside(t - 2)), + t, s, noconds=True), {y: Y}) + D = ( + -F + s**2*Y(s) - s*y(0) + Y(s) - Subs(Derivative(y(t), t), t, 0) - + 4*exp(-s)/s + 4*exp(-2*s)/s) + assert D == 0 +# Now, solve for Y(s) and then take the inverse Laplace transform +# => Y(s) = s/(s^2 + 1) + 4 [1/s - s/(s^2 + 1)] [e^(-s) - e^(-2 s)] +# => y(t) = cos t + 4 {[1 - cos(t - 1)] H(t - 1) - [1 - cos(t - 2)] H(t - 2)} + Yf = solve(F, Y(s))[0] + Yf = laplace_initial_conds(Yf, t, {y: [1, 0]}) + assert Yf == (s**2*exp(2*s) + 4*exp(s) - 4)*exp(-2*s)/(s*(s**2 + 1)) + yf = inverse_laplace_transform(Yf, s, t) + yf = yf.collect(Heaviside(t-1)).collect(Heaviside(t-2)) + assert yf == ( + (4 - 4*cos(t - 1))*Heaviside(t - 1) + + (4*cos(t - 2) - 4)*Heaviside(t - 2) + + cos(t)*Heaviside(t)) + + +@XFAIL +def test_Y7(): + # What is the Laplace transform of an infinite square wave? + # => 1/s + 2 sum( (-1)^n e^(- s n a)/s, n = 1..infinity ) + # [Sanchez, Allen and Kyner, p. 213] + t = symbols('t', positive=True) + a = symbols('a', real=True) + s = symbols('s') + F, _, _ = laplace_transform(1 + 2*Sum((-1)**n*Heaviside(t - n*a), + (n, 1, oo)), t, s) + # returns 2*LaplaceTransform(Sum((-1)**n*Heaviside(-a*n + t), + # (n, 1, oo)), t, s) + 1/s + # https://github.com/sympy/sympy/issues/7177 + assert F == 2*Sum((-1)**n*exp(-a*n*s)/s, (n, 1, oo)) + 1/s + + +@XFAIL +def test_Y8(): + assert fourier_transform(1, x, z) == DiracDelta(z) + + +def test_Y9(): + assert (fourier_transform(exp(-9*x**2), x, z) == + sqrt(pi)*exp(-pi**2*z**2/9)/3) + + +def test_Y10(): + assert (fourier_transform(abs(x)*exp(-3*abs(x)), x, z).cancel() == + (-8*pi**2*z**2 + 18)/(16*pi**4*z**4 + 72*pi**2*z**2 + 81)) + + +@SKIP("https://github.com/sympy/sympy/issues/7181") +@slow +def test_Y11(): + # => pi cot(pi s) (0 < Re s < 1) [Gradshteyn and Ryzhik 17.43(5)] + x, s = symbols('x s') + # raises RuntimeError: maximum recursion depth exceeded + # https://github.com/sympy/sympy/issues/7181 + # Update 2019-02-17 raises: + # TypeError: cannot unpack non-iterable MellinTransform object + F, _, _ = mellin_transform(1/(1 - x), x, s) + assert F == pi*cot(pi*s) + + +@XFAIL +def test_Y12(): + # => 2^(s - 4) gamma(s/2)/gamma(4 - s/2) (0 < Re s < 1) + # [Gradshteyn and Ryzhik 17.43(16)] + x, s = symbols('x s') + # returns Wrong value -2**(s - 4)*gamma(s/2 - 3)/gamma(-s/2 + 1) + # https://github.com/sympy/sympy/issues/7182 + F, _, _ = mellin_transform(besselj(3, x)/x**3, x, s) + assert F == -2**(s - 4)*gamma(s/2)/gamma(-s/2 + 4) + + +@XFAIL +def test_Y13(): +# Z[H(t - m T)] => z/[z^m (z - 1)] (H is the Heaviside (unit step) function) z + raise NotImplementedError("z-transform not supported") + + +@XFAIL +def test_Y14(): +# Z[H(t - m T)] => z/[z^m (z - 1)] (H is the Heaviside (unit step) function) + raise NotImplementedError("z-transform not supported") + + +def test_Z1(): + r = Function('r') + assert (rsolve(r(n + 2) - 2*r(n + 1) + r(n) - 2, r(n), + {r(0): 1, r(1): m}).simplify() == n**2 + n*(m - 2) + 1) + + +def test_Z2(): + r = Function('r') + assert (rsolve(r(n) - (5*r(n - 1) - 6*r(n - 2)), r(n), {r(0): 0, r(1): 1}) + == -2**n + 3**n) + + +def test_Z3(): + # => r(n) = Fibonacci[n + 1] [Cohen, p. 
83] + r = Function('r') + # recurrence solution is correct, Wester expects it to be simplified to + # fibonacci(n+1), but that is quite hard + expected = ((S(1)/2 - sqrt(5)/2)**n*(S(1)/2 - sqrt(5)/10) + + (S(1)/2 + sqrt(5)/2)**n*(sqrt(5)/10 + S(1)/2)) + sol = rsolve(r(n) - (r(n - 1) + r(n - 2)), r(n), {r(1): 1, r(2): 2}) + assert sol == expected + + +@XFAIL +def test_Z4(): +# => [c^(n+1) [c^(n+1) - 2 c - 2] + (n+1) c^2 + 2 c - n] / [(c-1)^3 (c+1)] +# [Joan Z. Yu and Robert Israel in sci.math.symbolic] + r = Function('r') + c = symbols('c') + # raises ValueError: Polynomial or rational function expected, + # got '(c**2 - c**n)/(c - c**n) + s = rsolve(r(n) - ((1 + c - c**(n-1) - c**(n+1))/(1 - c**n)*r(n - 1) + - c*(1 - c**(n-2))/(1 - c**(n-1))*r(n - 2) + 1), + r(n), {r(1): 1, r(2): (2 + 2*c + c**2)/(1 + c)}) + assert (s - (c*(n + 1)*(c*(n + 1) - 2*c - 2) + + (n + 1)*c**2 + 2*c - n)/((c-1)**3*(c+1)) == 0) + + +@XFAIL +def test_Z5(): + # Second order ODE with initial conditions---solve directly + # transform: f(t) = sin(2 t)/8 - t cos(2 t)/4 + C1, C2 = symbols('C1 C2') + # initial conditions not supported, this is a manual workaround + # https://github.com/sympy/sympy/issues/4720 + eq = Derivative(f(x), x, 2) + 4*f(x) - sin(2*x) + sol = dsolve(eq, f(x)) + f0 = Lambda(x, sol.rhs) + assert f0(x) == C2*sin(2*x) + (C1 - x/4)*cos(2*x) + f1 = Lambda(x, diff(f0(x), x)) + # TODO: Replace solve with solveset, when it works for solveset + const_dict = solve((f0(0), f1(0))) + result = f0(x).subs(C1, const_dict[C1]).subs(C2, const_dict[C2]) + assert result == -x*cos(2*x)/4 + sin(2*x)/8 + # Result is OK, but ODE solving with initial conditions should be + # supported without all this manual work + raise NotImplementedError('ODE solving with initial conditions \ +not supported') + + +@XFAIL +def test_Z6(): + # Second order ODE with initial conditions---solve using Laplace + # transform: f(t) = sin(2 t)/8 - t cos(2 t)/4 + t = symbols('t', positive=True) + s = symbols('s') + eq = Derivative(f(t), t, 2) + 4*f(t) - sin(2*t) + F, _, _ = laplace_transform(eq, t, s) + # Laplace transform for diff() not calculated + # https://github.com/sympy/sympy/issues/7176 + assert (F == s**2*LaplaceTransform(f(t), t, s) + + 4*LaplaceTransform(f(t), t, s) - 2/(s**2 + 4)) + # rest of test case not implemented diff --git a/phivenv/Lib/site-packages/sympy/utilities/tests/test_xxe.py b/phivenv/Lib/site-packages/sympy/utilities/tests/test_xxe.py new file mode 100644 index 0000000000000000000000000000000000000000..3936e8aa135dde5f22c71548e2f90ed58ac25cb8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tests/test_xxe.py @@ -0,0 +1,3 @@ +# A test file for XXE injection +# Username: Test +# Password: Test diff --git a/phivenv/Lib/site-packages/sympy/utilities/timeutils.py b/phivenv/Lib/site-packages/sympy/utilities/timeutils.py new file mode 100644 index 0000000000000000000000000000000000000000..1caae825103335f3e8afedeaf741cff2b0414fb2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/timeutils.py @@ -0,0 +1,75 @@ +"""Simple tools for timing functions' execution, when IPython is not available. """ + + +import timeit +import math + + +_scales = [1e0, 1e3, 1e6, 1e9] +_units = ['s', 'ms', '\N{GREEK SMALL LETTER MU}s', 'ns'] + + +def timed(func, setup="pass", limit=None): + """Adaptively measure execution time of a function. 
""" + timer = timeit.Timer(func, setup=setup) + repeat, number = 3, 1 + + for i in range(1, 10): + if timer.timeit(number) >= 0.2: + break + elif limit is not None and number >= limit: + break + else: + number *= 10 + + time = min(timer.repeat(repeat, number)) / number + + if time > 0.0: + order = min(-int(math.floor(math.log10(time)) // 3), 3) + else: + order = 3 + + return (number, time, time*_scales[order], _units[order]) + + +# Code for doing inline timings of recursive algorithms. + +def __do_timings(): + import os + res = os.getenv('SYMPY_TIMINGS', '') + res = [x.strip() for x in res.split(',')] + return set(res) + +_do_timings = __do_timings() +_timestack = None + + +def _print_timestack(stack, level=1): + print('-'*level, '%.2f %s%s' % (stack[2], stack[0], stack[3])) + for s in stack[1]: + _print_timestack(s, level + 1) + + +def timethis(name): + def decorator(func): + if name not in _do_timings: + return func + + def wrapper(*args, **kwargs): + from time import time + global _timestack + oldtimestack = _timestack + _timestack = [func.func_name, [], 0, args] + t1 = time() + r = func(*args, **kwargs) + t2 = time() + _timestack[2] = t2 - t1 + if oldtimestack is not None: + oldtimestack[1].append(_timestack) + _timestack = oldtimestack + else: + _print_timestack(_timestack) + _timestack = None + return r + return wrapper + return decorator diff --git a/phivenv/Lib/site-packages/sympy/utilities/tmpfiles.py b/phivenv/Lib/site-packages/sympy/utilities/tmpfiles.py new file mode 100644 index 0000000000000000000000000000000000000000..81c97efbf9d67cc20f25ff251abc94c2f5bafe06 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/utilities/tmpfiles.py @@ -0,0 +1,12 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.tmpfiles has been renamed to sympy.testing.tmpfiles. +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.tmpfiles submodule is deprecated. 
Use sympy.testing.tmpfiles instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.testing.tmpfiles import * # noqa:F401,F403 diff --git a/phivenv/Lib/site-packages/sympy/vector/__init__.py b/phivenv/Lib/site-packages/sympy/vector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6757bbeb35022481b1cf183373ecccd19779faa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/__init__.py @@ -0,0 +1,50 @@ +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import (Vector, VectorAdd, VectorMul, + BaseVector, VectorZero, Cross, Dot, cross, dot) +from sympy.vector.dyadic import (Dyadic, DyadicAdd, DyadicMul, + BaseDyadic, DyadicZero) +from sympy.vector.scalar import BaseScalar +from sympy.vector.deloperator import Del +from sympy.vector.functions import (express, matrix_to_vector, + laplacian, is_conservative, + is_solenoidal, scalar_potential, + directional_derivative, + scalar_potential_difference) +from sympy.vector.point import Point +from sympy.vector.orienters import (AxisOrienter, BodyOrienter, + SpaceOrienter, QuaternionOrienter) +from sympy.vector.operators import Gradient, Divergence, Curl, Laplacian, gradient, curl, divergence +from sympy.vector.implicitregion import ImplicitRegion +from sympy.vector.parametricregion import (ParametricRegion, parametric_region_list) +from sympy.vector.integrals import (ParametricIntegral, vector_integrate) +from sympy.vector.kind import VectorKind + +__all__ = [ + 'Vector', 'VectorAdd', 'VectorMul', 'BaseVector', 'VectorZero', 'Cross', + 'Dot', 'cross', 'dot', + + 'VectorKind', + + 'Dyadic', 'DyadicAdd', 'DyadicMul', 'BaseDyadic', 'DyadicZero', + + 'BaseScalar', + + 'Del', + + 'CoordSys3D', + + 'express', 'matrix_to_vector', 'laplacian', 'is_conservative', + 'is_solenoidal', 'scalar_potential', 'directional_derivative', + 'scalar_potential_difference', + + 'Point', + + 'AxisOrienter', 'BodyOrienter', 'SpaceOrienter', 'QuaternionOrienter', + + 'Gradient', 'Divergence', 'Curl', 'Laplacian', 'gradient', 'curl', + 'divergence', + + 'ParametricRegion', 'parametric_region_list', 'ImplicitRegion', + + 'ParametricIntegral', 'vector_integrate', +] diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f5e3ada306cdece33018145fe7511a6cf1ea858 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/basisdependent.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/basisdependent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48754ee342f358ef3d24c9f6019aeed13e1c587c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/basisdependent.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/coordsysrect.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/coordsysrect.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b639e6e98d1d50c6273a9d0fac4a8da651e99eac Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/coordsysrect.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/deloperator.cpython-39.pyc 
b/phivenv/Lib/site-packages/sympy/vector/__pycache__/deloperator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a54d2f34b6af4d23ecbac121763da1104355186 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/deloperator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/dyadic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/dyadic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60043827f53455d073451a2d3612e19261df0c90 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/dyadic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..697091a91e8d398a4ff2b0c726cfd98f3b848e7f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/implicitregion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/implicitregion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d062d977cc839dd3588d6b33bc00e28825e96eed Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/implicitregion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/integrals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/integrals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d14c6939acb4d7e8a288519ccaead22a39345e3f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/integrals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/kind.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/kind.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca87fe9ed377c5740ff7dd861d8b338a3fad83d6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/kind.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/operators.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..737948268673778c834cbc1be07f4a170340a73a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/operators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/orienters.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/orienters.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd02e52f05501f4a3c861b2aef9ae1c0b7b1d6d6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/orienters.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/parametricregion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/parametricregion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc1e538f933be62987baa39e926c29adad3c328 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/parametricregion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/point.cpython-39.pyc 
b/phivenv/Lib/site-packages/sympy/vector/__pycache__/point.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce2fafcdb3bae3568ce17711bf655598e4ec1fd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/point.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/scalar.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/scalar.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696ae8271de6596c9931da13c07255e506fb33b9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/scalar.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/__pycache__/vector.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/__pycache__/vector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a56b900161e97d2e1fcfb43348a3d661894fde5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/__pycache__/vector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/basisdependent.py b/phivenv/Lib/site-packages/sympy/vector/basisdependent.py new file mode 100644 index 0000000000000000000000000000000000000000..53e4efc0bf839fb5a5de2d1af1487683fabd8cf1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/basisdependent.py @@ -0,0 +1,374 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from sympy.simplify import simplify as simp, trigsimp as tsimp # type: ignore +from sympy.core.decorators import call_highest_priority, _sympifyit +from sympy.core.assumptions import StdFactKB +from sympy.core.function import diff as df +from sympy.integrals.integrals import Integral +from sympy.polys.polytools import factor as fctr +from sympy.core import S, Add, Mul +from sympy.core.expr import Expr + +if TYPE_CHECKING: + from sympy.vector.vector import BaseVector + + +class BasisDependent(Expr): + """ + Super class containing functionality common to vectors and + dyadics. + Named so because the representation of these quantities in + sympy.vector is dependent on the basis they are expressed in. + """ + + zero: BasisDependentZero + + @call_highest_priority('__radd__') + def __add__(self, other): + return self._add_func(self, other) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self._add_func(other, self) + + @call_highest_priority('__rsub__') + def __sub__(self, other): + return self._add_func(self, -other) + + @call_highest_priority('__sub__') + def __rsub__(self, other): + return self._add_func(other, -self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return self._mul_func(self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self._mul_func(other, self) + + def __neg__(self): + return self._mul_func(S.NegativeOne, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self._div_helper(other) + + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + return TypeError("Invalid divisor for division") + + def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False): + """ + Implements the SymPy evalf routine for this quantity. 
+ + evalf's documentation + ===================== + + """ + options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict, + 'quad':quad, 'verbose':verbose} + vec = self.zero + for k, v in self.components.items(): + vec += v.evalf(n, **options) * k + return vec + + evalf.__doc__ += Expr.evalf.__doc__ # type: ignore + + n = evalf # type: ignore + + def simplify(self, **kwargs): + """ + Implements the SymPy simplify routine for this quantity. + + simplify's documentation + ======================== + + """ + simp_components = [simp(v, **kwargs) * k for + k, v in self.components.items()] + return self._add_func(*simp_components) + + simplify.__doc__ += simp.__doc__ # type: ignore + + def trigsimp(self, **opts): + """ + Implements the SymPy trigsimp routine, for this quantity. + + trigsimp's documentation + ======================== + + """ + trig_components = [tsimp(v, **opts) * k for + k, v in self.components.items()] + return self._add_func(*trig_components) + + trigsimp.__doc__ += tsimp.__doc__ # type: ignore + + def _eval_simplify(self, **kwargs): + return self.simplify(**kwargs) + + def _eval_trigsimp(self, **opts): + return self.trigsimp(**opts) + + def _eval_derivative(self, wrt): + return self.diff(wrt) + + def _eval_Integral(self, *symbols, **assumptions): + integral_components = [Integral(v, *symbols, **assumptions) * k + for k, v in self.components.items()] + return self._add_func(*integral_components) + + def as_numer_denom(self): + """ + Returns the expression as a tuple wrt the following + transformation - + + expression -> a/b -> a, b + + """ + return self, S.One + + def factor(self, *args, **kwargs): + """ + Implements the SymPy factor routine, on the scalar parts + of a basis-dependent expression. + + factor's documentation + ======================== + + """ + fctr_components = [fctr(v, *args, **kwargs) * k for + k, v in self.components.items()] + return self._add_func(*fctr_components) + + factor.__doc__ += fctr.__doc__ # type: ignore + + def as_coeff_Mul(self, rational=False): + """Efficiently extract the coefficient of a product.""" + return (S.One, self) + + def as_coeff_add(self, *deps): + """Efficiently extract the coefficient of a summation.""" + return 0, tuple(x * self.components[x] for x in self.components) + + def diff(self, *args, **kwargs): + """ + Implements the SymPy diff routine, for vectors. + + diff's documentation + ======================== + + """ + for x in args: + if isinstance(x, BasisDependent): + raise TypeError("Invalid arg for differentiation") + diff_components = [df(v, *args, **kwargs) * k for + k, v in self.components.items()] + return self._add_func(*diff_components) + + diff.__doc__ += df.__doc__ # type: ignore + + def doit(self, **hints): + """Calls .doit() on each term in the Dyadic""" + doit_components = [self.components[x].doit(**hints) * x + for x in self.components] + return self._add_func(*doit_components) + + +class BasisDependentAdd(BasisDependent, Add): + """ + Denotes sum of basis dependent quantities such that they cannot + be expressed as base or Mul instances. 
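+
+    For example (``N`` being any coordinate system), a sum of scaled base
+    vectors is held as such an instance, with ``components`` mapping each
+    base vector to its measure number:
+
+    >>> from sympy.vector import CoordSys3D
+    >>> N = CoordSys3D('N')
+    >>> s = 2*N.i + 3*N.j  # s.components is {N.i: 2, N.j: 3}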
+ """ + + def __new__(cls, *args, **options): + components = {} + + # Check each arg and simultaneously learn the components + for arg in args: + if not isinstance(arg, cls._expr_type): + if isinstance(arg, Mul): + arg = cls._mul_func(*(arg.args)) + elif isinstance(arg, Add): + arg = cls._add_func(*(arg.args)) + else: + raise TypeError(str(arg) + + " cannot be interpreted correctly") + # If argument is zero, ignore + if arg == cls.zero: + continue + # Else, update components accordingly + for x in arg.components: + components[x] = components.get(x, 0) + arg.components[x] + + temp = list(components.keys()) + for x in temp: + if components[x] == 0: + del components[x] + + # Handle case of zero vector + if len(components) == 0: + return cls.zero + + # Build object + newargs = [x * components[x] for x in components] + obj = super().__new__(cls, *newargs, **options) + if isinstance(obj, Mul): + return cls._mul_func(*obj.args) + assumptions = {'commutative': True} + obj._assumptions = StdFactKB(assumptions) + obj._components = components + obj._sys = (list(components.keys()))[0]._sys + + return obj + + +class BasisDependentMul(BasisDependent, Mul): + """ + Denotes product of base- basis dependent quantity with a scalar. + """ + + def __new__(cls, *args, **options): + obj = cls._new(*args, **options) + return obj + + def _new_rawargs(self, *args): + # XXX: This is needed because Add.flatten() uses it but the default + # implementation does not work for Vectors because they assign + # attributes outside of .args. + return type(self)(*args) + + @classmethod + def _new(cls, *args, **options): + from sympy.vector import Cross, Dot, Curl, Gradient + count = 0 + measure_number = S.One + zeroflag = False + extra_args = [] + + # Determine the component and check arguments + # Also keep a count to ensure two vectors aren't + # being multiplied + for arg in args: + if isinstance(arg, cls._zero_func): + count += 1 + zeroflag = True + elif arg == S.Zero: + zeroflag = True + elif isinstance(arg, (cls._base_func, cls._mul_func)): + count += 1 + expr = arg._base_instance + measure_number *= arg._measure_number + elif isinstance(arg, cls._add_func): + count += 1 + expr = arg + elif isinstance(arg, (Cross, Dot, Curl, Gradient)): + extra_args.append(arg) + else: + measure_number *= arg + # Make sure incompatible types weren't multiplied + if count > 1: + raise ValueError("Invalid multiplication") + elif count == 0: + return Mul(*args, **options) + # Handle zero vector case + if zeroflag: + return cls.zero + + # If one of the args was a VectorAdd, return an + # appropriate VectorAdd instance + if isinstance(expr, cls._add_func): + newargs = [cls._mul_func(measure_number, x) for + x in expr.args] + return cls._add_func(*newargs) + + obj = super().__new__(cls, measure_number, + expr._base_instance, + *extra_args, + **options) + if isinstance(obj, Add): + return cls._add_func(*obj.args) + obj._base_instance = expr._base_instance + obj._measure_number = measure_number + assumptions = {'commutative': True} + obj._assumptions = StdFactKB(assumptions) + obj._components = {expr._base_instance: measure_number} + obj._sys = expr._base_instance._sys + + return obj + + def _sympystr(self, printer): + measure_str = printer._print(self._measure_number) + if ('(' in measure_str or '-' in measure_str or + '+' in measure_str): + measure_str = '(' + measure_str + ')' + return measure_str + '*' + printer._print(self._base_instance) + + +class BasisDependentZero(BasisDependent): + """ + Class to denote a zero basis dependent instance. 
+ """ + components: dict['BaseVector', Expr] = {} + _latex_form: str + + def __new__(cls): + obj = super().__new__(cls) + # Pre-compute a specific hash value for the zero vector + # Use the same one always + obj._hash = (S.Zero, cls).__hash__() + return obj + + def __hash__(self): + return self._hash + + @call_highest_priority('__req__') + def __eq__(self, other): + return isinstance(other, self._zero_func) + + __req__ = __eq__ + + @call_highest_priority('__radd__') + def __add__(self, other): + if isinstance(other, self._expr_type): + return other + else: + raise TypeError("Invalid argument types for addition") + + @call_highest_priority('__add__') + def __radd__(self, other): + if isinstance(other, self._expr_type): + return other + else: + raise TypeError("Invalid argument types for addition") + + @call_highest_priority('__rsub__') + def __sub__(self, other): + if isinstance(other, self._expr_type): + return -other + else: + raise TypeError("Invalid argument types for subtraction") + + @call_highest_priority('__sub__') + def __rsub__(self, other): + if isinstance(other, self._expr_type): + return other + else: + raise TypeError("Invalid argument types for subtraction") + + def __neg__(self): + return self + + def normalize(self): + """ + Returns the normalized version of this vector. + """ + return self + + def _sympystr(self, printer): + return '0' diff --git a/phivenv/Lib/site-packages/sympy/vector/coordsysrect.py b/phivenv/Lib/site-packages/sympy/vector/coordsysrect.py new file mode 100644 index 0000000000000000000000000000000000000000..55539fb19dc4221de69437111f44d6a6cc70b3e4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/coordsysrect.py @@ -0,0 +1,1031 @@ +from collections.abc import Callable + +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core import S, Dummy, Lambda +from sympy.core.symbol import Str +from sympy.core.symbol import symbols +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.matrices.matrixbase import MatrixBase +from sympy.solvers import solve +from sympy.vector.scalar import BaseScalar +from sympy.core.containers import Tuple +from sympy.core.function import diff +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos, sin) +from sympy.matrices.dense import eye +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import trigsimp +import sympy.vector +from sympy.vector.orienters import (Orienter, AxisOrienter, BodyOrienter, + SpaceOrienter, QuaternionOrienter) + + +class CoordSys3D(Basic): + """ + Represents a coordinate system in 3-D space. + """ + + def __new__(cls, name, transformation=None, parent=None, location=None, + rotation_matrix=None, vector_names=None, variable_names=None): + """ + The orientation/location parameters are necessary if this system + is being defined at a certain orientation or location wrt another. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + transformation : Lambda, Tuple, str + Transformation defined by transformation equations or chosen + from predefined ones. + + location : Vector + The position vector of the new system's origin wrt the parent + instance. + + rotation_matrix : SymPy ImmutableMatrix + The rotation matrix of the new coordinate system with respect + to the parent. In other words, the output of + new_system.rotation_matrix(parent). 
+ + parent : CoordSys3D + The coordinate system wrt which the orientation/location + (or both) is being defined. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + """ + + name = str(name) + Vector = sympy.vector.Vector + Point = sympy.vector.Point + + if not isinstance(name, str): + raise TypeError("name should be a string") + + if transformation is not None: + if (location is not None) or (rotation_matrix is not None): + raise ValueError("specify either `transformation` or " + "`location`/`rotation_matrix`") + if isinstance(transformation, (Tuple, tuple, list)): + if isinstance(transformation[0], MatrixBase): + rotation_matrix = transformation[0] + location = transformation[1] + else: + transformation = Lambda(transformation[0], + transformation[1]) + elif isinstance(transformation, Callable): + x1, x2, x3 = symbols('x1 x2 x3', cls=Dummy) + transformation = Lambda((x1, x2, x3), + transformation(x1, x2, x3)) + elif isinstance(transformation, str): + transformation = Str(transformation) + elif isinstance(transformation, (Str, Lambda)): + pass + else: + raise TypeError("transformation: " + "wrong type {}".format(type(transformation))) + + # If orientation information has been provided, store + # the rotation matrix accordingly + if rotation_matrix is None: + rotation_matrix = ImmutableDenseMatrix(eye(3)) + else: + if not isinstance(rotation_matrix, MatrixBase): + raise TypeError("rotation_matrix should be an Immutable" + + "Matrix instance") + rotation_matrix = rotation_matrix.as_immutable() + + # If location information is not given, adjust the default + # location as Vector.zero + if parent is not None: + if not isinstance(parent, CoordSys3D): + raise TypeError("parent should be a " + + "CoordSys3D/None") + if location is None: + location = Vector.zero + else: + if not isinstance(location, Vector): + raise TypeError("location should be a Vector") + # Check that location does not contain base + # scalars + for x in location.free_symbols: + if isinstance(x, BaseScalar): + raise ValueError("location should not contain" + + " BaseScalars") + origin = parent.origin.locate_new(name + '.origin', + location) + else: + location = Vector.zero + origin = Point(name + '.origin') + + if transformation is None: + transformation = Tuple(rotation_matrix, location) + + if isinstance(transformation, Tuple): + lambda_transformation = CoordSys3D._compose_rotation_and_translation( + transformation[0], + transformation[1], + parent + ) + r, l = transformation + l = l._projections + lambda_lame = CoordSys3D._get_lame_coeff('cartesian') + lambda_inverse = lambda x, y, z: r.inv()*Matrix( + [x-l[0], y-l[1], z-l[2]]) + elif isinstance(transformation, Str): + trname = transformation.name + lambda_transformation = CoordSys3D._get_transformation_lambdas(trname) + if parent is not None: + if parent.lame_coefficients() != (S.One, S.One, S.One): + raise ValueError('Parent for pre-defined coordinate ' + 'system should be Cartesian.') + lambda_lame = CoordSys3D._get_lame_coeff(trname) + lambda_inverse = CoordSys3D._set_inv_trans_equations(trname) + elif isinstance(transformation, Lambda): + if not CoordSys3D._check_orthogonality(transformation): + raise ValueError("The transformation equation does not " + "create orthogonal coordinate system") + lambda_transformation = transformation + lambda_lame = CoordSys3D._calculate_lame_coeff(lambda_transformation) + 
lambda_inverse = None + else: + lambda_transformation = lambda x, y, z: transformation(x, y, z) + lambda_lame = CoordSys3D._get_lame_coeff(transformation) + lambda_inverse = None + + if variable_names is None: + if isinstance(transformation, Lambda): + variable_names = ["x1", "x2", "x3"] + elif isinstance(transformation, Str): + if transformation.name == 'spherical': + variable_names = ["r", "theta", "phi"] + elif transformation.name == 'cylindrical': + variable_names = ["r", "theta", "z"] + else: + variable_names = ["x", "y", "z"] + else: + variable_names = ["x", "y", "z"] + if vector_names is None: + vector_names = ["i", "j", "k"] + + # All systems that are defined as 'roots' are unequal, unless + # they have the same name. + # Systems defined at same orientation/position wrt the same + # 'parent' are equal, irrespective of the name. + # This is true even if the same orientation is provided via + # different methods like Axis/Body/Space/Quaternion. + # However, coincident systems may be seen as unequal if + # positioned/oriented wrt different parents, even though + # they may actually be 'coincident' wrt the root system. + if parent is not None: + obj = super().__new__( + cls, Str(name), transformation, parent) + else: + obj = super().__new__( + cls, Str(name), transformation) + obj._name = name + # Initialize the base vectors + + _check_strings('vector_names', vector_names) + vector_names = list(vector_names) + latex_vects = [(r'\mathbf{\hat{%s}_{%s}}' % (x, name)) for + x in vector_names] + pretty_vects = ['%s_%s' % (x, name) for x in vector_names] + + obj._vector_names = vector_names + + v1 = BaseVector(0, obj, pretty_vects[0], latex_vects[0]) + v2 = BaseVector(1, obj, pretty_vects[1], latex_vects[1]) + v3 = BaseVector(2, obj, pretty_vects[2], latex_vects[2]) + + obj._base_vectors = (v1, v2, v3) + + # Initialize the base scalars + + _check_strings('variable_names', vector_names) + variable_names = list(variable_names) + latex_scalars = [(r"\mathbf{{%s}_{%s}}" % (x, name)) for + x in variable_names] + pretty_scalars = ['%s_%s' % (x, name) for x in variable_names] + + obj._variable_names = variable_names + obj._vector_names = vector_names + + x1 = BaseScalar(0, obj, pretty_scalars[0], latex_scalars[0]) + x2 = BaseScalar(1, obj, pretty_scalars[1], latex_scalars[1]) + x3 = BaseScalar(2, obj, pretty_scalars[2], latex_scalars[2]) + + obj._base_scalars = (x1, x2, x3) + + obj._transformation = transformation + obj._transformation_lambda = lambda_transformation + obj._lame_coefficients = lambda_lame(x1, x2, x3) + obj._transformation_from_parent_lambda = lambda_inverse + + setattr(obj, variable_names[0], x1) + setattr(obj, variable_names[1], x2) + setattr(obj, variable_names[2], x3) + + setattr(obj, vector_names[0], v1) + setattr(obj, vector_names[1], v2) + setattr(obj, vector_names[2], v3) + + # Assign params + obj._parent = parent + if obj._parent is not None: + obj._root = obj._parent._root + else: + obj._root = obj + + obj._parent_rotation_matrix = rotation_matrix + obj._origin = origin + + # Return the instance + return obj + + def _sympystr(self, printer): + return self._name + + def __iter__(self): + return iter(self.base_vectors()) + + @staticmethod + def _check_orthogonality(equations): + """ + Helper method for _connect_to_cartesian. 
It checks if + set of transformation equations create orthogonal curvilinear + coordinate system + + Parameters + ========== + + equations : Lambda + Lambda of transformation equations + + """ + + x1, x2, x3 = symbols("x1, x2, x3", cls=Dummy) + equations = equations(x1, x2, x3) + v1 = Matrix([diff(equations[0], x1), + diff(equations[1], x1), diff(equations[2], x1)]) + + v2 = Matrix([diff(equations[0], x2), + diff(equations[1], x2), diff(equations[2], x2)]) + + v3 = Matrix([diff(equations[0], x3), + diff(equations[1], x3), diff(equations[2], x3)]) + + if any(simplify(i[0] + i[1] + i[2]) == 0 for i in (v1, v2, v3)): + return False + else: + if simplify(v1.dot(v2)) == 0 and simplify(v2.dot(v3)) == 0 \ + and simplify(v3.dot(v1)) == 0: + return True + else: + return False + + @staticmethod + def _set_inv_trans_equations(curv_coord_name): + """ + Store information about inverse transformation equations for + pre-defined coordinate systems. + + Parameters + ========== + + curv_coord_name : str + Name of coordinate system + + """ + if curv_coord_name == 'cartesian': + return lambda x, y, z: (x, y, z) + + if curv_coord_name == 'spherical': + return lambda x, y, z: ( + sqrt(x**2 + y**2 + z**2), + acos(z/sqrt(x**2 + y**2 + z**2)), + atan2(y, x) + ) + if curv_coord_name == 'cylindrical': + return lambda x, y, z: ( + sqrt(x**2 + y**2), + atan2(y, x), + z + ) + raise ValueError('Wrong set of parameters.' + 'Type of coordinate system is defined') + + def _calculate_inv_trans_equations(self): + """ + Helper method for set_coordinate_type. It calculates inverse + transformation equations for given transformations equations. + + """ + x1, x2, x3 = symbols("x1, x2, x3", cls=Dummy, reals=True) + x, y, z = symbols("x, y, z", cls=Dummy) + + equations = self._transformation(x1, x2, x3) + + solved = solve([equations[0] - x, + equations[1] - y, + equations[2] - z], (x1, x2, x3), dict=True)[0] + solved = solved[x1], solved[x2], solved[x3] + self._transformation_from_parent_lambda = \ + lambda x1, x2, x3: tuple(i.subs(list(zip((x, y, z), (x1, x2, x3)))) for i in solved) + + @staticmethod + def _get_lame_coeff(curv_coord_name): + """ + Store information about Lame coefficients for pre-defined + coordinate systems. + + Parameters + ========== + + curv_coord_name : str + Name of coordinate system + + """ + if isinstance(curv_coord_name, str): + if curv_coord_name == 'cartesian': + return lambda x, y, z: (S.One, S.One, S.One) + if curv_coord_name == 'spherical': + return lambda r, theta, phi: (S.One, r, r*sin(theta)) + if curv_coord_name == 'cylindrical': + return lambda r, theta, h: (S.One, r, S.One) + raise ValueError('Wrong set of parameters.' + ' Type of coordinate system is not defined') + return CoordSys3D._calculate_lame_coefficients(curv_coord_name) + + @staticmethod + def _calculate_lame_coeff(equations): + """ + It calculates Lame coefficients + for given transformations equations. + + Parameters + ========== + + equations : Lambda + Lambda of transformation equations. + + """ + return lambda x1, x2, x3: ( + sqrt(diff(equations(x1, x2, x3)[0], x1)**2 + + diff(equations(x1, x2, x3)[1], x1)**2 + + diff(equations(x1, x2, x3)[2], x1)**2), + sqrt(diff(equations(x1, x2, x3)[0], x2)**2 + + diff(equations(x1, x2, x3)[1], x2)**2 + + diff(equations(x1, x2, x3)[2], x2)**2), + sqrt(diff(equations(x1, x2, x3)[0], x3)**2 + + diff(equations(x1, x2, x3)[1], x3)**2 + + diff(equations(x1, x2, x3)[2], x3)**2) + ) + + def _inverse_rotation_matrix(self): + """ + Returns inverse rotation matrix. 
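+        For a proper (orthonormal) rotation matrix the inverse coincides
+        with the transpose, which is why ``rotation_matrix`` uses ``.T``
+        when walking the system tree; ``simplify`` is applied here because
+        the symbolic inverse may not be in canonical form.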
+ """ + return simplify(self._parent_rotation_matrix**-1) + + @staticmethod + def _get_transformation_lambdas(curv_coord_name): + """ + Store information about transformation equations for pre-defined + coordinate systems. + + Parameters + ========== + + curv_coord_name : str + Name of coordinate system + + """ + if isinstance(curv_coord_name, str): + if curv_coord_name == 'cartesian': + return lambda x, y, z: (x, y, z) + if curv_coord_name == 'spherical': + return lambda r, theta, phi: ( + r*sin(theta)*cos(phi), + r*sin(theta)*sin(phi), + r*cos(theta) + ) + if curv_coord_name == 'cylindrical': + return lambda r, theta, h: ( + r*cos(theta), + r*sin(theta), + h + ) + raise ValueError('Wrong set of parameters.' + 'Type of coordinate system is defined') + + @classmethod + def _rotation_trans_equations(cls, matrix, equations): + """ + Returns the transformation equations obtained from rotation matrix. + + Parameters + ========== + + matrix : Matrix + Rotation matrix + + equations : tuple + Transformation equations + + """ + return tuple(matrix * Matrix(equations)) + + @property + def origin(self): + return self._origin + + def base_vectors(self): + return self._base_vectors + + def base_scalars(self): + return self._base_scalars + + def lame_coefficients(self): + return self._lame_coefficients + + def transformation_to_parent(self): + return self._transformation_lambda(*self.base_scalars()) + + def transformation_from_parent(self): + if self._parent is None: + raise ValueError("no parent coordinate system, use " + "`transformation_from_parent_function()`") + return self._transformation_from_parent_lambda( + *self._parent.base_scalars()) + + def transformation_from_parent_function(self): + return self._transformation_from_parent_lambda + + def rotation_matrix(self, other): + """ + Returns the direction cosine matrix(DCM), also known as the + 'rotation matrix' of this coordinate system with respect to + another system. + + If v_a is a vector defined in system 'A' (in matrix format) + and v_b is the same vector defined in system 'B', then + v_a = A.rotation_matrix(B) * v_b. + + A SymPy Matrix is returned. + + Parameters + ========== + + other : CoordSys3D + The system which the DCM is generated to. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = CoordSys3D('N') + >>> A = N.orient_new_axis('A', q1, N.i) + >>> N.rotation_matrix(A) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + """ + from sympy.vector.functions import _path + if not isinstance(other, CoordSys3D): + raise TypeError(str(other) + + " is not a CoordSys3D") + # Handle special cases + if other == self: + return eye(3) + elif other == self._parent: + return self._parent_rotation_matrix + elif other._parent == self: + return other._parent_rotation_matrix.T + # Else, use tree to calculate position + rootindex, path = _path(self, other) + result = eye(3) + for i in range(rootindex): + result *= path[i]._parent_rotation_matrix + for i in range(rootindex + 1, len(path)): + result *= path[i]._parent_rotation_matrix.T + return result + + @cacheit + def position_wrt(self, other): + """ + Returns the position vector of the origin of this coordinate + system with respect to another Point/CoordSys3D. + + Parameters + ========== + + other : Point/CoordSys3D + If other is a Point, the position of this system's origin + wrt it is returned. If its an instance of CoordSyRect, + the position wrt its origin is returned. 
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> N1 = N.locate_new('N1', 10 * N.i) + >>> N.position_wrt(N1) + (-10)*N.i + + """ + return self.origin.position_wrt(other) + + def scalar_map(self, other): + """ + Returns a dictionary which expresses the coordinate variables + (base scalars) of this frame in terms of the variables of + otherframe. + + Parameters + ========== + + otherframe : CoordSys3D + The other system to map the variables to. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import Symbol + >>> A = CoordSys3D('A') + >>> q = Symbol('q') + >>> B = A.orient_new_axis('B', q, A.k) + >>> A.scalar_map(B) + {A.x: B.x*cos(q) - B.y*sin(q), A.y: B.x*sin(q) + B.y*cos(q), A.z: B.z} + + """ + + origin_coords = tuple(self.position_wrt(other).to_matrix(other)) + relocated_scalars = [x - origin_coords[i] + for i, x in enumerate(other.base_scalars())] + + vars_matrix = (self.rotation_matrix(other) * + Matrix(relocated_scalars)) + return {x: trigsimp(vars_matrix[i]) + for i, x in enumerate(self.base_scalars())} + + def locate_new(self, name, position, vector_names=None, + variable_names=None): + """ + Returns a CoordSys3D with its origin located at the given + position wrt this coordinate system's origin. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + position : Vector + The position vector of the new system's origin wrt this + one. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> A = CoordSys3D('A') + >>> B = A.locate_new('B', 10 * A.i) + >>> B.origin.position_wrt(A.origin) + 10*A.i + + """ + if variable_names is None: + variable_names = self._variable_names + if vector_names is None: + vector_names = self._vector_names + + return CoordSys3D(name, location=position, + vector_names=vector_names, + variable_names=variable_names, + parent=self) + + def orient_new(self, name, orienters, location=None, + vector_names=None, variable_names=None): + """ + Creates a new CoordSys3D oriented in the user-specified way + with respect to this system. + + Please refer to the documentation of the orienter classes + for more information about the orientation procedure. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + orienters : iterable/Orienter + An Orienter or an iterable of Orienters for orienting the + new coordinate system. + If an Orienter is provided, it is applied to get the new + system. + If an iterable is provided, the orienters will be applied + in the order in which they appear in the iterable. + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. 
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = CoordSys3D('N') + + Using an AxisOrienter + + >>> from sympy.vector import AxisOrienter + >>> axis_orienter = AxisOrienter(q1, N.i + 2 * N.j) + >>> A = N.orient_new('A', (axis_orienter, )) + + Using a BodyOrienter + + >>> from sympy.vector import BodyOrienter + >>> body_orienter = BodyOrienter(q1, q2, q3, '123') + >>> B = N.orient_new('B', (body_orienter, )) + + Using a SpaceOrienter + + >>> from sympy.vector import SpaceOrienter + >>> space_orienter = SpaceOrienter(q1, q2, q3, '312') + >>> C = N.orient_new('C', (space_orienter, )) + + Using a QuaternionOrienter + + >>> from sympy.vector import QuaternionOrienter + >>> q_orienter = QuaternionOrienter(q0, q1, q2, q3) + >>> D = N.orient_new('D', (q_orienter, )) + """ + if variable_names is None: + variable_names = self._variable_names + if vector_names is None: + vector_names = self._vector_names + + if isinstance(orienters, Orienter): + if isinstance(orienters, AxisOrienter): + final_matrix = orienters.rotation_matrix(self) + else: + final_matrix = orienters.rotation_matrix() + # TODO: trigsimp is needed here so that the matrix becomes + # canonical (scalar_map also calls trigsimp; without this, you can + # end up with the same CoordinateSystem that compares differently + # due to a differently formatted matrix). However, this is + # probably not so good for performance. + final_matrix = trigsimp(final_matrix) + else: + final_matrix = Matrix(eye(3)) + for orienter in orienters: + if isinstance(orienter, AxisOrienter): + final_matrix *= orienter.rotation_matrix(self) + else: + final_matrix *= orienter.rotation_matrix() + + return CoordSys3D(name, rotation_matrix=final_matrix, + vector_names=vector_names, + variable_names=variable_names, + location=location, + parent=self) + + def orient_new_axis(self, name, angle, axis, location=None, + vector_names=None, variable_names=None): + """ + Axis rotation is a rotation about an arbitrary axis by + some angle. The angle is supplied as a SymPy expr scalar, and + the axis is supplied as a Vector. + + Parameters + ========== + + name : string + The name of the new coordinate system + + angle : Expr + The angle by which the new system is to be rotated + + axis : Vector + The axis around which the rotation has to be performed + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = CoordSys3D('N') + >>> B = N.orient_new_axis('B', q1, N.i + 2 * N.j) + + """ + if variable_names is None: + variable_names = self._variable_names + if vector_names is None: + vector_names = self._vector_names + + orienter = AxisOrienter(angle, axis) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def orient_new_body(self, name, angle1, angle2, angle3, + rotation_order, location=None, + vector_names=None, variable_names=None): + """ + Body orientation takes this coordinate system through three + successive simple rotations. 
+ + Body fixed rotations include both Euler Angles and + Tait-Bryan Angles, see https://en.wikipedia.org/wiki/Euler_angles. + + Parameters + ========== + + name : string + The name of the new coordinate system + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + A 'Body' fixed rotation is described by three angles and + three body-fixed rotation axes. To orient a coordinate system D + with respect to N, each sequential rotation is always about + the orthogonal unit vectors fixed to D. For example, a '123' + rotation will specify rotations about N.i, then D.j, then + D.k. (Initially, D.i is same as N.i) + Therefore, + + >>> D = N.orient_new_body('D', q1, q2, q3, '123') + + is same as + + >>> D = N.orient_new_axis('D', q1, N.i) + >>> D = D.orient_new_axis('D', q2, D.j) + >>> D = D.orient_new_axis('D', q3, D.k) + + Acceptable rotation orders are of length 3, expressed in XYZ or + 123, and cannot have a rotation about about an axis twice in a row. + + >>> B = N.orient_new_body('B', q1, q2, q3, '123') + >>> B = N.orient_new_body('B', q1, q2, 0, 'ZXZ') + >>> B = N.orient_new_body('B', 0, 0, 0, 'XYX') + + """ + + orienter = BodyOrienter(angle1, angle2, angle3, rotation_order) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def orient_new_space(self, name, angle1, angle2, angle3, + rotation_order, location=None, + vector_names=None, variable_names=None): + """ + Space rotation is similar to Body rotation, but the rotations + are applied in the opposite order. + + Parameters + ========== + + name : string + The name of the new coordinate system + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + See Also + ======== + + CoordSys3D.orient_new_body : method to orient via Euler + angles + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + To orient a coordinate system D with respect to N, each + sequential rotation is always about N's orthogonal unit vectors. + For example, a '123' rotation will specify rotations about + N.i, then N.j, then N.k. 
+ Therefore, + + >>> D = N.orient_new_space('D', q1, q2, q3, '312') + + is same as + + >>> B = N.orient_new_axis('B', q1, N.i) + >>> C = B.orient_new_axis('C', q2, N.j) + >>> D = C.orient_new_axis('D', q3, N.k) + + """ + + orienter = SpaceOrienter(angle1, angle2, angle3, rotation_order) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def orient_new_quaternion(self, name, q0, q1, q2, q3, location=None, + vector_names=None, variable_names=None): + """ + Quaternion orientation orients the new CoordSys3D with + Quaternions, defined as a finite rotation about lambda, a unit + vector, by some amount theta. + + This orientation is described by four parameters: + + q0 = cos(theta/2) + + q1 = lambda_x sin(theta/2) + + q2 = lambda_y sin(theta/2) + + q3 = lambda_z sin(theta/2) + + Quaternion does not take in a rotation order. + + Parameters + ========== + + name : string + The name of the new coordinate system + + q0, q1, q2, q3 : Expr + The quaternions to rotate the coordinate system by + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = CoordSys3D('N') + >>> B = N.orient_new_quaternion('B', q0, q1, q2, q3) + + """ + + orienter = QuaternionOrienter(q0, q1, q2, q3) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def create_new(self, name, transformation, variable_names=None, vector_names=None): + """ + Returns a CoordSys3D which is connected to self by transformation. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + transformation : Lambda, Tuple, str + Transformation defined by transformation equations or chosen + from predefined ones. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. 
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> a = CoordSys3D('a') + >>> b = a.create_new('b', transformation='spherical') + >>> b.transformation_to_parent() + (b.r*sin(b.theta)*cos(b.phi), b.r*sin(b.phi)*sin(b.theta), b.r*cos(b.theta)) + >>> b.transformation_from_parent() + (sqrt(a.x**2 + a.y**2 + a.z**2), acos(a.z/sqrt(a.x**2 + a.y**2 + a.z**2)), atan2(a.y, a.x)) + + """ + return CoordSys3D(name, parent=self, transformation=transformation, + variable_names=variable_names, vector_names=vector_names) + + def __init__(self, name, location=None, rotation_matrix=None, + parent=None, vector_names=None, variable_names=None, + latex_vects=None, pretty_vects=None, latex_scalars=None, + pretty_scalars=None, transformation=None): + # Dummy initializer for setting docstring + pass + + __init__.__doc__ = __new__.__doc__ + + @staticmethod + def _compose_rotation_and_translation(rot, translation, parent): + r = lambda x, y, z: CoordSys3D._rotation_trans_equations(rot, (x, y, z)) + if parent is None: + return r + + dx, dy, dz = [translation.dot(i) for i in parent.base_vectors()] + t = lambda x, y, z: ( + x + dx, + y + dy, + z + dz, + ) + return lambda x, y, z: t(*r(x, y, z)) + + +def _check_strings(arg_name, arg): + errorstr = arg_name + " must be an iterable of 3 string-types" + if len(arg) != 3: + raise ValueError(errorstr) + for s in arg: + if not isinstance(s, str): + raise TypeError(errorstr) + + +# Delayed import to avoid cyclic import problems: +from sympy.vector.vector import BaseVector diff --git a/phivenv/Lib/site-packages/sympy/vector/deloperator.py b/phivenv/Lib/site-packages/sympy/vector/deloperator.py new file mode 100644 index 0000000000000000000000000000000000000000..51c3c0caf42b5e5d372bd65907d8bae2bd563562 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/deloperator.py @@ -0,0 +1,121 @@ +from sympy.core import Basic +from sympy.vector.operators import gradient, divergence, curl + + +class Del(Basic): + """ + Represents the vector differential operator, usually represented in + mathematical expressions as the 'nabla' symbol. + """ + + def __new__(cls): + obj = super().__new__(cls) + obj._name = "delop" + return obj + + def gradient(self, scalar_field, doit=False): + """ + Returns the gradient of the given scalar field, as a + Vector instance. + + Parameters + ========== + + scalar_field : SymPy expression + The scalar field to calculate the gradient of. + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> C = CoordSys3D('C') + >>> delop = Del() + >>> delop.gradient(9) + 0 + >>> delop(C.x*C.y*C.z).doit() + C.y*C.z*C.i + C.x*C.z*C.j + C.x*C.y*C.k + + """ + + return gradient(scalar_field, doit=doit) + + __call__ = gradient + __call__.__doc__ = gradient.__doc__ + + def dot(self, vect, doit=False): + """ + Represents the dot product between this operator and a given + vector - equal to the divergence of the vector field. + + Parameters + ========== + + vect : Vector + The vector whose divergence is to be calculated. + + doit : bool + If True, the result is returned after calling .doit() on + each component. 
Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> delop = Del() + >>> C = CoordSys3D('C') + >>> delop.dot(C.x*C.i) + Derivative(C.x, C.x) + >>> v = C.x*C.y*C.z * (C.i + C.j + C.k) + >>> (delop & v).doit() + C.x*C.y + C.x*C.z + C.y*C.z + + """ + return divergence(vect, doit=doit) + + __and__ = dot + __and__.__doc__ = dot.__doc__ + + def cross(self, vect, doit=False): + """ + Represents the cross product between this operator and a given + vector - equal to the curl of the vector field. + + Parameters + ========== + + vect : Vector + The vector whose curl is to be calculated. + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> C = CoordSys3D('C') + >>> delop = Del() + >>> v = C.x*C.y*C.z * (C.i + C.j + C.k) + >>> delop.cross(v, doit = True) + (-C.x*C.y + C.x*C.z)*C.i + (C.x*C.y - C.y*C.z)*C.j + + (-C.x*C.z + C.y*C.z)*C.k + >>> (delop ^ C.i).doit() + 0 + + """ + + return curl(vect, doit=doit) + + __xor__ = cross + __xor__.__doc__ = cross.__doc__ + + def _sympystr(self, printer): + return self._name diff --git a/phivenv/Lib/site-packages/sympy/vector/dyadic.py b/phivenv/Lib/site-packages/sympy/vector/dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..980c6e6dad90ac095b7bd6d4228f507a7831b39f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/dyadic.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from sympy.vector.basisdependent import (BasisDependent, BasisDependentAdd, + BasisDependentMul, BasisDependentZero) +from sympy.core import S, Pow +from sympy.core.expr import AtomicExpr +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +import sympy.vector + + +class Dyadic(BasisDependent): + """ + Super class for all Dyadic-classes. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dyadic_tensor + .. [2] Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 + McGraw-Hill + + """ + + _op_priority = 13.0 + + _expr_type: type[Dyadic] + _mul_func: type[Dyadic] + _add_func: type[Dyadic] + _zero_func: type[Dyadic] + _base_func: type[Dyadic] + zero: DyadicZero + + @property + def components(self): + """ + Returns the components of this dyadic in the form of a + Python dictionary mapping BaseDyadic instances to the + corresponding measure numbers. + + """ + # The '_components' attribute is defined according to the + # subclass of Dyadic the instance belongs to. + return self._components + + def dot(self, other): + """ + Returns the dot product(also called inner product) of this + Dyadic, with another Dyadic or Vector. + If 'other' is a Dyadic, this returns a Dyadic. Else, it returns + a Vector (unless an error is encountered). 
+ + Parameters + ========== + + other : Dyadic/Vector + The other Dyadic or Vector to take the inner product with + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> D1 = N.i.outer(N.j) + >>> D2 = N.j.outer(N.j) + >>> D1.dot(D2) + (N.i|N.j) + >>> D1.dot(N.j) + N.i + + """ + + Vector = sympy.vector.Vector + if isinstance(other, BasisDependentZero): + return Vector.zero + elif isinstance(other, Vector): + outvec = Vector.zero + for k, v in self.components.items(): + vect_dot = k.args[1].dot(other) + outvec += vect_dot * v * k.args[0] + return outvec + elif isinstance(other, Dyadic): + outdyad = Dyadic.zero + for k1, v1 in self.components.items(): + for k2, v2 in other.components.items(): + vect_dot = k1.args[1].dot(k2.args[0]) + outer_product = k1.args[0].outer(k2.args[1]) + outdyad += vect_dot * v1 * v2 * outer_product + return outdyad + else: + raise TypeError("Inner product is not defined for " + + str(type(other)) + " and Dyadics.") + + def __and__(self, other): + return self.dot(other) + + __and__.__doc__ = dot.__doc__ + + def cross(self, other): + """ + Returns the cross product between this Dyadic, and a Vector, as a + Vector instance. + + Parameters + ========== + + other : Vector + The Vector that we are crossing this Dyadic with + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> d = N.i.outer(N.i) + >>> d.cross(N.j) + (N.i|N.k) + + """ + + Vector = sympy.vector.Vector + if other == Vector.zero: + return Dyadic.zero + elif isinstance(other, Vector): + outdyad = Dyadic.zero + for k, v in self.components.items(): + cross_product = k.args[1].cross(other) + outer = k.args[0].outer(cross_product) + outdyad += v * outer + return outdyad + else: + raise TypeError(str(type(other)) + " not supported for " + + "cross with dyadics") + + def __xor__(self, other): + return self.cross(other) + + __xor__.__doc__ = cross.__doc__ + + def to_matrix(self, system, second_system=None): + """ + Returns the matrix form of the dyadic with respect to one or two + coordinate systems. + + Parameters + ========== + + system : CoordSys3D + The coordinate system that the rows and columns of the matrix + correspond to. If a second system is provided, this + only corresponds to the rows of the matrix. + second_system : CoordSys3D, optional, default=None + The coordinate system that the columns of the matrix correspond + to. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> v = N.i + 2*N.j + >>> d = v.outer(N.i) + >>> d.to_matrix(N) + Matrix([ + [1, 0, 0], + [2, 0, 0], + [0, 0, 0]]) + >>> from sympy import Symbol + >>> q = Symbol('q') + >>> P = N.orient_new_axis('P', q, N.k) + >>> d.to_matrix(N, P) + Matrix([ + [ cos(q), -sin(q), 0], + [2*cos(q), -2*sin(q), 0], + [ 0, 0, 0]]) + + """ + + if second_system is None: + second_system = system + + return Matrix([i.dot(self).dot(j) for i in system for j in + second_system]).reshape(3, 3) + + def _div_helper(one, other): + """ Helper for division involving dyadics """ + if isinstance(one, Dyadic) and isinstance(other, Dyadic): + raise TypeError("Cannot divide two dyadics") + elif isinstance(one, Dyadic): + return DyadicMul(one, Pow(other, S.NegativeOne)) + else: + raise TypeError("Cannot divide by a dyadic") + + +class BaseDyadic(Dyadic, AtomicExpr): + """ + Class to denote a base dyadic tensor component. 
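+
+    A base dyadic is normally obtained via ``outer`` on two base vectors
+    (``N`` below is just an example system):
+
+    >>> from sympy.vector import CoordSys3D
+    >>> N = CoordSys3D('N')
+    >>> d = N.i.outer(N.j)  # the base dyadic (N.i|N.j)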
+ """ + + def __new__(cls, vector1, vector2): + Vector = sympy.vector.Vector + BaseVector = sympy.vector.BaseVector + VectorZero = sympy.vector.VectorZero + # Verify arguments + if not isinstance(vector1, (BaseVector, VectorZero)) or \ + not isinstance(vector2, (BaseVector, VectorZero)): + raise TypeError("BaseDyadic cannot be composed of non-base " + + "vectors") + # Handle special case of zero vector + elif vector1 == Vector.zero or vector2 == Vector.zero: + return Dyadic.zero + # Initialize instance + obj = super().__new__(cls, vector1, vector2) + obj._base_instance = obj + obj._measure_number = 1 + obj._components = {obj: S.One} + obj._sys = vector1._sys + obj._pretty_form = ('(' + vector1._pretty_form + '|' + + vector2._pretty_form + ')') + obj._latex_form = (r'\left(' + vector1._latex_form + r"{\middle|}" + + vector2._latex_form + r'\right)') + + return obj + + def _sympystr(self, printer): + return "({}|{})".format( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _sympyrepr(self, printer): + return "BaseDyadic({}, {})".format( + printer._print(self.args[0]), printer._print(self.args[1])) + + +class DyadicMul(BasisDependentMul, Dyadic): + """ Products of scalars and BaseDyadics """ + + def __new__(cls, *args, **options): + obj = BasisDependentMul.__new__(cls, *args, **options) + return obj + + @property + def base_dyadic(self): + """ The BaseDyadic involved in the product. """ + return self._base_instance + + @property + def measure_number(self): + """ The scalar expression involved in the definition of + this DyadicMul. + """ + return self._measure_number + + +class DyadicAdd(BasisDependentAdd, Dyadic): + """ Class to hold dyadic sums """ + + def __new__(cls, *args, **options): + obj = BasisDependentAdd.__new__(cls, *args, **options) + return obj + + def _sympystr(self, printer): + items = list(self.components.items()) + items.sort(key=lambda x: x[0].__str__()) + return " + ".join(printer._print(k * v) for k, v in items) + + +class DyadicZero(BasisDependentZero, Dyadic): + """ + Class to denote a zero dyadic + """ + + _op_priority = 13.1 + _pretty_form = '(0|0)' + _latex_form = r'(\mathbf{\hat{0}}|\mathbf{\hat{0}})' + + def __new__(cls): + obj = BasisDependentZero.__new__(cls) + return obj + + +Dyadic._expr_type = Dyadic +Dyadic._mul_func = DyadicMul +Dyadic._add_func = DyadicAdd +Dyadic._zero_func = DyadicZero +Dyadic._base_func = BaseDyadic +Dyadic.zero = DyadicZero() diff --git a/phivenv/Lib/site-packages/sympy/vector/functions.py b/phivenv/Lib/site-packages/sympy/vector/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b78df8ae2e182f3e571ca7fa8bfabd39bf99d26e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/functions.py @@ -0,0 +1,513 @@ +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.deloperator import Del +from sympy.vector.scalar import BaseScalar +from sympy.vector.vector import Vector, BaseVector +from sympy.vector.operators import gradient, curl, divergence +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.integrals.integrals import integrate +from sympy.core import sympify +from sympy.vector.dyadic import Dyadic + + +def express(expr, system, system2=None, variables=False): + """ + Global function for 'express' functionality. + + Re-expresses a Vector, Dyadic or scalar(sympyfiable) in the given + coordinate system. 
+ + If 'variables' is True, then the coordinate variables (base scalars) + of other coordinate systems present in the vector/scalar field or + dyadic are also substituted in terms of the base scalars of the + given system. + + Parameters + ========== + + expr : Vector/Dyadic/scalar(sympyfiable) + The expression to re-express in CoordSys3D 'system' + + system: CoordSys3D + The coordinate system the expr is to be expressed in + + system2: CoordSys3D + The other coordinate system required for re-expression + (only for a Dyadic Expr) + + variables : boolean + Specifies whether to substitute the coordinate variables present + in expr, in terms of those of parameter system + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import Symbol, cos, sin + >>> N = CoordSys3D('N') + >>> q = Symbol('q') + >>> B = N.orient_new_axis('B', q, N.k) + >>> from sympy.vector import express + >>> express(B.i, N) + (cos(q))*N.i + (sin(q))*N.j + >>> express(N.x, B, variables=True) + B.x*cos(q) - B.y*sin(q) + >>> d = N.i.outer(N.i) + >>> express(d, B, N) == (cos(q))*(B.i|N.i) + (-sin(q))*(B.j|N.i) + True + + """ + + if expr in (0, Vector.zero): + return expr + + if not isinstance(system, CoordSys3D): + raise TypeError("system should be a CoordSys3D \ + instance") + + if isinstance(expr, Vector): + if system2 is not None: + raise ValueError("system2 should not be provided for \ + Vectors") + # Given expr is a Vector + if variables: + # If variables attribute is True, substitute + # the coordinate variables in the Vector + system_list = {x.system for x in expr.atoms(BaseScalar, BaseVector)} - {system} + subs_dict = {} + for f in system_list: + subs_dict.update(f.scalar_map(system)) + expr = expr.subs(subs_dict) + # Re-express in this coordinate system + outvec = Vector.zero + parts = expr.separate() + for x in parts: + if x != system: + temp = system.rotation_matrix(x) * parts[x].to_matrix(x) + outvec += matrix_to_vector(temp, system) + else: + outvec += parts[x] + return outvec + + elif isinstance(expr, Dyadic): + if system2 is None: + system2 = system + if not isinstance(system2, CoordSys3D): + raise TypeError("system2 should be a CoordSys3D \ + instance") + outdyad = Dyadic.zero + var = variables + for k, v in expr.components.items(): + outdyad += (express(v, system, variables=var) * + (express(k.args[0], system, variables=var) | + express(k.args[1], system2, variables=var))) + + return outdyad + + else: + if system2 is not None: + raise ValueError("system2 should not be provided for \ + Vectors") + if variables: + # Given expr is a scalar field + system_set = set() + expr = sympify(expr) + # Substitute all the coordinate variables + for x in expr.atoms(BaseScalar): + if x.system != system: + system_set.add(x.system) + subs_dict = {} + for f in system_set: + subs_dict.update(f.scalar_map(system)) + return expr.subs(subs_dict) + return expr + + +def directional_derivative(field, direction_vector): + """ + Returns the directional derivative of a scalar or vector field computed + along a given vector in coordinate system which parameters are expressed. + + Parameters + ========== + + field : Vector or Scalar + The scalar or vector field to compute the directional derivative of + + direction_vector : Vector + The vector to calculated directional derivative along them. 
+ + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, directional_derivative + >>> R = CoordSys3D('R') + >>> f1 = R.x*R.y*R.z + >>> v1 = 3*R.i + 4*R.j + R.k + >>> directional_derivative(f1, v1) + R.x*R.y + 4*R.x*R.z + 3*R.y*R.z + >>> f2 = 5*R.x**2*R.z + >>> directional_derivative(f2, v1) + 5*R.x**2 + 30*R.x*R.z + + """ + from sympy.vector.operators import _get_coord_systems + coord_sys = _get_coord_systems(field) + if len(coord_sys) > 0: + # TODO: This gets a random coordinate system in case of multiple ones: + coord_sys = next(iter(coord_sys)) + field = express(field, coord_sys, variables=True) + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + out = Vector.dot(direction_vector, i) * diff(field, x) + out += Vector.dot(direction_vector, j) * diff(field, y) + out += Vector.dot(direction_vector, k) * diff(field, z) + if out == 0 and isinstance(field, Vector): + out = Vector.zero + return out + elif isinstance(field, Vector): + return Vector.zero + else: + return S.Zero + + +def laplacian(expr): + """ + Return the laplacian of the given field computed in terms of + the base scalars of the given coordinate system. + + Parameters + ========== + + expr : SymPy Expr or Vector + expr denotes a scalar or vector field. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, laplacian + >>> R = CoordSys3D('R') + >>> f = R.x**2*R.y**5*R.z + >>> laplacian(f) + 20*R.x**2*R.y**3*R.z + 2*R.y**5*R.z + >>> f = R.x**2*R.i + R.y**3*R.j + R.z**4*R.k + >>> laplacian(f) + 2*R.i + 6*R.y*R.j + 12*R.z**2*R.k + + """ + + delop = Del() + if expr.is_Vector: + return (gradient(divergence(expr)) - curl(curl(expr))).doit() + return delop.dot(delop(expr)).doit() + + +def is_conservative(field): + """ + Checks if a field is conservative. + + Parameters + ========== + + field : Vector + The field to check for conservative property + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import is_conservative + >>> R = CoordSys3D('R') + >>> is_conservative(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + True + >>> is_conservative(R.z*R.j) + False + + """ + + # Field is conservative irrespective of system + # Take the first coordinate system in the result of the + # separate method of Vector + if not isinstance(field, Vector): + raise TypeError("field should be a Vector") + if field == Vector.zero: + return True + return curl(field).simplify() == Vector.zero + + +def is_solenoidal(field): + """ + Checks if a field is solenoidal. + + Parameters + ========== + + field : Vector + The field to check for solenoidal property + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import is_solenoidal + >>> R = CoordSys3D('R') + >>> is_solenoidal(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + True + >>> is_solenoidal(R.y * R.j) + False + + """ + + # Field is solenoidal irrespective of system + # Take the first coordinate system in the result of the + # separate method in Vector + if not isinstance(field, Vector): + raise TypeError("field should be a Vector") + if field == Vector.zero: + return True + return divergence(field).simplify() is S.Zero + + +def scalar_potential(field, coord_sys): + """ + Returns the scalar potential function of a field in a given + coordinate system (without the added integration constant). 
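+
+    The returned function ``V`` satisfies ``gradient(V) == field`` (no
+    negative sign is introduced, unlike the usual physics convention for
+    potentials). As a small illustration:
+
+    >>> from sympy.vector import CoordSys3D, scalar_potential
+    >>> R = CoordSys3D('R')
+    >>> V = scalar_potential(2*R.x*R.i + 2*R.y*R.j, R)  # V == R.x**2 + R.y**2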
+ + Parameters + ========== + + field : Vector + The vector field whose scalar potential function is to be + calculated + + coord_sys : CoordSys3D + The coordinate system to do the calculation in + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import scalar_potential, gradient + >>> R = CoordSys3D('R') + >>> scalar_potential(R.k, R) == R.z + True + >>> scalar_field = 2*R.x**2*R.y*R.z + >>> grad_field = gradient(scalar_field) + >>> scalar_potential(grad_field, R) + 2*R.x**2*R.y*R.z + + """ + + # Check whether field is conservative + if not is_conservative(field): + raise ValueError("Field is not conservative") + if field == Vector.zero: + return S.Zero + # Express the field exntirely in coord_sys + # Substitute coordinate variables also + if not isinstance(coord_sys, CoordSys3D): + raise TypeError("coord_sys must be a CoordSys3D") + field = express(field, coord_sys, variables=True) + dimensions = coord_sys.base_vectors() + scalars = coord_sys.base_scalars() + # Calculate scalar potential function + temp_function = integrate(field.dot(dimensions[0]), scalars[0]) + for i, dim in enumerate(dimensions[1:]): + partial_diff = diff(temp_function, scalars[i + 1]) + partial_diff = field.dot(dim) - partial_diff + temp_function += integrate(partial_diff, scalars[i + 1]) + return temp_function + + +def scalar_potential_difference(field, coord_sys, point1, point2): + """ + Returns the scalar potential difference between two points in a + certain coordinate system, wrt a given field. + + If a scalar field is provided, its values at the two points are + considered. If a conservative vector field is provided, the values + of its scalar potential function at the two points are used. + + Returns (potential at point2) - (potential at point1) + + The position vectors of the two Points are calculated wrt the + origin of the coordinate system provided. 
+ + Parameters + ========== + + field : Vector/Expr + The field to calculate wrt + + coord_sys : CoordSys3D + The coordinate system to do the calculations in + + point1 : Point + The initial Point in given coordinate system + + position2 : Point + The second Point in the given coordinate system + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import scalar_potential_difference + >>> R = CoordSys3D('R') + >>> P = R.origin.locate_new('P', R.x*R.i + R.y*R.j + R.z*R.k) + >>> vectfield = 4*R.x*R.y*R.i + 2*R.x**2*R.j + >>> scalar_potential_difference(vectfield, R, R.origin, P) + 2*R.x**2*R.y + >>> Q = R.origin.locate_new('O', 3*R.i + R.j + 2*R.k) + >>> scalar_potential_difference(vectfield, R, P, Q) + -2*R.x**2*R.y + 18 + + """ + + if not isinstance(coord_sys, CoordSys3D): + raise TypeError("coord_sys must be a CoordSys3D") + if isinstance(field, Vector): + # Get the scalar potential function + scalar_fn = scalar_potential(field, coord_sys) + else: + # Field is a scalar + scalar_fn = field + # Express positions in required coordinate system + origin = coord_sys.origin + position1 = express(point1.position_wrt(origin), coord_sys, + variables=True) + position2 = express(point2.position_wrt(origin), coord_sys, + variables=True) + # Get the two positions as substitution dicts for coordinate variables + subs_dict1 = {} + subs_dict2 = {} + scalars = coord_sys.base_scalars() + for i, x in enumerate(coord_sys.base_vectors()): + subs_dict1[scalars[i]] = x.dot(position1) + subs_dict2[scalars[i]] = x.dot(position2) + return scalar_fn.subs(subs_dict2) - scalar_fn.subs(subs_dict1) + + +def matrix_to_vector(matrix, system): + """ + Converts a vector in matrix form to a Vector instance. + + It is assumed that the elements of the Matrix represent the + measure numbers of the components of the vector along basis + vectors of 'system'. + + Parameters + ========== + + matrix : SymPy Matrix, Dimensions: (3, 1) + The matrix to be converted to a vector + + system : CoordSys3D + The coordinate system the vector is to be defined in + + Examples + ======== + + >>> from sympy import ImmutableMatrix as Matrix + >>> m = Matrix([1, 2, 3]) + >>> from sympy.vector import CoordSys3D, matrix_to_vector + >>> C = CoordSys3D('C') + >>> v = matrix_to_vector(m, C) + >>> v + C.i + 2*C.j + 3*C.k + >>> v.to_matrix(C) == m + True + + """ + + outvec = Vector.zero + vects = system.base_vectors() + for i, x in enumerate(matrix): + outvec += x * vects[i] + return outvec + + +def _path(from_object, to_object): + """ + Calculates the 'path' of objects starting from 'from_object' + to 'to_object', along with the index of the first common + ancestor in the tree. + + Returns (index, list) tuple. + """ + + if from_object._root != to_object._root: + raise ValueError("No connecting path found between " + + str(from_object) + " and " + str(to_object)) + + other_path = [] + obj = to_object + while obj._parent is not None: + other_path.append(obj) + obj = obj._parent + other_path.append(obj) + object_set = set(other_path) + from_path = [] + obj = from_object + while obj not in object_set: + from_path.append(obj) + obj = obj._parent + index = len(from_path) + from_path.extend(other_path[other_path.index(obj)::-1]) + return index, from_path + + +def orthogonalize(*vlist, orthonormal=False): + """ + Takes a sequence of independent vectors and orthogonalizes them + using the Gram - Schmidt process. Returns a list of + orthogonal or orthonormal vectors. 
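+
+    A small supplementary check (hedged; it reuses the vectors from the
+    Examples section below): vectors returned by the Gram-Schmidt loop are
+    pairwise orthogonal, which ``dot`` confirms directly.
+
+    >>> from sympy.vector import CoordSys3D
+    >>> from sympy.vector.functions import orthogonalize
+    >>> C = CoordSys3D('C')
+    >>> u1, u2 = orthogonalize(C.i + 2*C.j, 2*C.i + 3*C.j)
+    >>> u1.dot(u2)
+    0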
+ + Parameters + ========== + + vlist : sequence of independent vectors to be made orthogonal. + + orthonormal : Optional parameter + Set to True if the vectors returned should be + orthonormal. + Default: False + + Examples + ======== + + >>> from sympy.vector.coordsysrect import CoordSys3D + >>> from sympy.vector.functions import orthogonalize + >>> C = CoordSys3D('C') + >>> i, j, k = C.base_vectors() + >>> v1 = i + 2*j + >>> v2 = 2*i + 3*j + >>> orthogonalize(v1, v2) + [C.i + 2*C.j, 2/5*C.i + (-1/5)*C.j] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gram-Schmidt_process + + """ + + if not all(isinstance(vec, Vector) for vec in vlist): + raise TypeError('Each element must be of Type Vector') + + ortho_vlist = [] + for i, term in enumerate(vlist): + for j in range(i): + term -= ortho_vlist[j].projection(vlist[i]) + # TODO : The following line introduces a performance issue + # and needs to be changed once a good solution for issue #10279 is + # found. + if term.equals(Vector.zero): + raise ValueError("Vector set not linearly independent") + ortho_vlist.append(term) + + if orthonormal: + ortho_vlist = [vec.normalize() for vec in ortho_vlist] + + return ortho_vlist diff --git a/phivenv/Lib/site-packages/sympy/vector/implicitregion.py b/phivenv/Lib/site-packages/sympy/vector/implicitregion.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2d55a1be8b1eaca71d08b632a94886a2b0269c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/implicitregion.py @@ -0,0 +1,506 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.polytools import gcd +from sympy.sets.sets import Complement +from sympy.core import Basic, Tuple, diff, expand, Eq, Integer +from sympy.core.sorting import ordered +from sympy.core.symbol import _symbol +from sympy.solvers import solveset, nonlinsolve, diophantine +from sympy.polys import total_degree +from sympy.geometry import Point +from sympy.ntheory.factor_ import core + + +class ImplicitRegion(Basic): + """ + Represents an implicit region in space. + + Examples + ======== + + >>> from sympy import Eq + >>> from sympy.abc import x, y, z, t + >>> from sympy.vector import ImplicitRegion + + >>> ImplicitRegion((x, y), x**2 + y**2 - 4) + ImplicitRegion((x, y), x**2 + y**2 - 4) + >>> ImplicitRegion((x, y), Eq(y*x, 1)) + ImplicitRegion((x, y), x*y - 1) + + >>> parabola = ImplicitRegion((x, y), y**2 - 4*x) + >>> parabola.degree + 2 + >>> parabola.equation + -4*x + y**2 + >>> parabola.rational_parametrization(t) + (4/t**2, 4/t) + + >>> r = ImplicitRegion((x, y, z), Eq(z, x**2 + y**2)) + >>> r.variables + (x, y, z) + >>> r.singular_points() + EmptySet + >>> r.regular_point() + (-10, -10, 200) + + Parameters + ========== + + variables : tuple to map variables in implicit equation to base scalars. + + equation : An expression or Eq denoting the implicit equation of the region. 
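+
+    A supplementary consistency check (a minimal sketch, using only the
+    methods shown above): points produced by ``rational_parametrization``
+    satisfy the implicit equation identically.
+
+    >>> from sympy import simplify
+    >>> circle = ImplicitRegion((x, y), x**2 + y**2 - 4)
+    >>> px, py = circle.rational_parametrization(t)
+    >>> simplify(px**2 + py**2 - 4)
+    0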
+ + """ + def __new__(cls, variables, equation): + if not isinstance(variables, Tuple): + variables = Tuple(*variables) + + if isinstance(equation, Eq): + equation = equation.lhs - equation.rhs + + return super().__new__(cls, variables, equation) + + @property + def variables(self): + return self.args[0] + + @property + def equation(self): + return self.args[1] + + @property + def degree(self): + return total_degree(self.equation) + + def regular_point(self): + """ + Returns a point on the implicit region. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.vector import ImplicitRegion + >>> circle = ImplicitRegion((x, y), (x + 2)**2 + (y - 3)**2 - 16) + >>> circle.regular_point() + (-2, -1) + >>> parabola = ImplicitRegion((x, y), x**2 - 4*y) + >>> parabola.regular_point() + (0, 0) + >>> r = ImplicitRegion((x, y, z), (x + y + z)**4) + >>> r.regular_point() + (-10, -10, 20) + + References + ========== + + - Erik Hillgarter, "Rational Points on Conics", Diploma Thesis, RISC-Linz, + J. Kepler Universitat Linz, 1996. Available: + https://www3.risc.jku.at/publications/download/risc_1355/Rational%20Points%20on%20Conics.pdf + + """ + equation = self.equation + + if len(self.variables) == 1: + return (list(solveset(equation, self.variables[0], domain=S.Reals))[0],) + elif len(self.variables) == 2: + + if self.degree == 2: + coeffs = a, b, c, d, e, f = conic_coeff(self.variables, equation) + + if b**2 == 4*a*c: + x_reg, y_reg = self._regular_point_parabola(*coeffs) + else: + x_reg, y_reg = self._regular_point_ellipse(*coeffs) + return x_reg, y_reg + + if len(self.variables) == 3: + x, y, z = self.variables + + for x_reg in range(-10, 10): + for y_reg in range(-10, 10): + if not solveset(equation.subs({x: x_reg, y: y_reg}), self.variables[2], domain=S.Reals).is_empty: + return (x_reg, y_reg, list(solveset(equation.subs({x: x_reg, y: y_reg})))[0]) + + if len(self.singular_points()) != 0: + return list[self.singular_points()][0] + + raise NotImplementedError() + + def _regular_point_parabola(self, a, b, c, d, e, f): + ok = (a, d) != (0, 0) and (c, e) != (0, 0) and b**2 == 4*a*c and (a, c) != (0, 0) + + if not ok: + raise ValueError("Rational Point on the conic does not exist") + + if a != 0: + d_dash, f_dash = (4*a*e - 2*b*d, 4*a*f - d**2) + if d_dash != 0: + y_reg = -f_dash/d_dash + x_reg = -(d + b*y_reg)/(2*a) + else: + ok = False + elif c != 0: + d_dash, f_dash = (4*c*d - 2*b*e, 4*c*f - e**2) + if d_dash != 0: + x_reg = -f_dash/d_dash + y_reg = -(e + b*x_reg)/(2*c) + else: + ok = False + + if ok: + return x_reg, y_reg + else: + raise ValueError("Rational Point on the conic does not exist") + + def _regular_point_ellipse(self, a, b, c, d, e, f): + D = 4*a*c - b**2 + ok = D + + if not ok: + raise ValueError("Rational Point on the conic does not exist") + + if a == 0 and c == 0: + K = -1 + L = 4*(d*e - b*f) + elif c != 0: + K = D + L = 4*c**2*d**2 - 4*b*c*d*e + 4*a*c*e**2 + 4*b**2*c*f - 16*a*c**2*f + else: + K = D + L = 4*a**2*e**2 - 4*b*a*d*e + 4*b**2*a*f + + ok = L != 0 and not(K > 0 and L < 0) + if not ok: + raise ValueError("Rational Point on the conic does not exist") + + K = Rational(K).limit_denominator(10**12) + L = Rational(L).limit_denominator(10**12) + + k1, k2 = K.p, K.q + l1, l2 = L.p, L.q + g = gcd(k2, l2) + + a1 = (l2*k2)/g + b1 = (k1*l2)/g + c1 = -(l1*k2)/g + a2 = sign(a1)*core(abs(a1), 2) + r1 = sqrt(a1/a2) + b2 = sign(b1)*core(abs(b1), 2) + r2 = sqrt(b1/b2) + c2 = sign(c1)*core(abs(c1), 2) + r3 = sqrt(c1/c2) + + g = gcd(gcd(a2, b2), c2) + a2 = a2/g + b2 = 
b2/g + c2 = c2/g + + g1 = gcd(a2, b2) + a2 = a2/g1 + b2 = b2/g1 + c2 = c2*g1 + + g2 = gcd(a2,c2) + a2 = a2/g2 + b2 = b2*g2 + c2 = c2/g2 + + g3 = gcd(b2, c2) + a2 = a2*g3 + b2 = b2/g3 + c2 = c2/g3 + + x, y, z = symbols("x y z") + eq = a2*x**2 + b2*y**2 + c2*z**2 + + solutions = diophantine(eq) + + if len(solutions) == 0: + raise ValueError("Rational Point on the conic does not exist") + + flag = False + for sol in solutions: + syms = Tuple(*sol).free_symbols + rep = dict.fromkeys(syms, 3) + sol_z = sol[2] + + if sol_z == 0: + flag = True + continue + + if not isinstance(sol_z, (int, Integer)): + syms_z = sol_z.free_symbols + + if len(syms_z) == 1: + p = next(iter(syms_z)) + p_values = Complement(S.Integers, solveset(Eq(sol_z, 0), p, S.Integers)) + rep[p] = next(iter(p_values)) + + if len(syms_z) == 2: + p, q = list(ordered(syms_z)) + + for i in S.Integers: + subs_sol_z = sol_z.subs(p, i) + q_values = Complement(S.Integers, solveset(Eq(subs_sol_z, 0), q, S.Integers)) + + if not q_values.is_empty: + rep[p] = i + rep[q] = next(iter(q_values)) + break + + if len(syms) != 0: + x, y, z = tuple(s.subs(rep) for s in sol) + else: + x, y, z = sol + flag = False + break + + if flag: + raise ValueError("Rational Point on the conic does not exist") + + x = (x*g3)/r1 + y = (y*g2)/r2 + z = (z*g1)/r3 + x = x/z + y = y/z + + if a == 0 and c == 0: + x_reg = (x + y - 2*e)/(2*b) + y_reg = (x - y - 2*d)/(2*b) + elif c != 0: + x_reg = (x - 2*d*c + b*e)/K + y_reg = (y - b*x_reg - e)/(2*c) + else: + y_reg = (x - 2*e*a + b*d)/K + x_reg = (y - b*y_reg - d)/(2*a) + + return x_reg, y_reg + + def singular_points(self): + """ + Returns a set of singular points of the region. + + The singular points are those points on the region + where all partial derivatives vanish. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.vector import ImplicitRegion + >>> I = ImplicitRegion((x, y), (y-1)**2 -x**3 + 2*x**2 -x) + >>> I.singular_points() + {(1, 1)} + + """ + eq_list = [self.equation] + for var in self.variables: + eq_list += [diff(self.equation, var)] + + return nonlinsolve(eq_list, list(self.variables)) + + def multiplicity(self, point): + """ + Returns the multiplicity of a singular point on the region. + + A singular point (x,y) of region is said to be of multiplicity m + if all the partial derivatives off to order m - 1 vanish there. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.vector import ImplicitRegion + >>> I = ImplicitRegion((x, y, z), x**2 + y**3 - z**4) + >>> I.singular_points() + {(0, 0, 0)} + >>> I.multiplicity((0, 0, 0)) + 2 + + """ + if isinstance(point, Point): + point = point.args + + modified_eq = self.equation + + for i, var in enumerate(self.variables): + modified_eq = modified_eq.subs(var, var + point[i]) + modified_eq = expand(modified_eq) + + if len(modified_eq.args) != 0: + terms = modified_eq.args + m = min(total_degree(term) for term in terms) + else: + terms = modified_eq + m = total_degree(terms) + + return m + + def rational_parametrization(self, parameters=('t', 's'), reg_point=None): + """ + Returns the rational parametrization of implicit region. 
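+
+        A supplementary illustration (a small sketch, using only the methods
+        defined above): the nodal cubic that also appears in the Examples
+        below has a double point at the origin, i.e. a point of multiplicity
+        ``degree - 1``, which is exactly the kind of point this routine
+        exploits.
+
+        >>> from sympy.abc import x, y
+        >>> from sympy.vector import ImplicitRegion
+        >>> nodal_cubic = ImplicitRegion((x, y), x**3 + x**2 - y**2)
+        >>> nodal_cubic.singular_points()
+        {(0, 0)}
+        >>> nodal_cubic.multiplicity((0, 0))
+        2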
+ + Examples + ======== + + >>> from sympy import Eq + >>> from sympy.abc import x, y, z, s, t + >>> from sympy.vector import ImplicitRegion + + >>> parabola = ImplicitRegion((x, y), y**2 - 4*x) + >>> parabola.rational_parametrization() + (4/t**2, 4/t) + + >>> circle = ImplicitRegion((x, y), Eq(x**2 + y**2, 4)) + >>> circle.rational_parametrization() + (4*t/(t**2 + 1), 4*t**2/(t**2 + 1) - 2) + + >>> I = ImplicitRegion((x, y), x**3 + x**2 - y**2) + >>> I.rational_parametrization() + (t**2 - 1, t*(t**2 - 1)) + + >>> cubic_curve = ImplicitRegion((x, y), x**3 + x**2 - y**2) + >>> cubic_curve.rational_parametrization(parameters=(t)) + (t**2 - 1, t*(t**2 - 1)) + + >>> sphere = ImplicitRegion((x, y, z), x**2 + y**2 + z**2 - 4) + >>> sphere.rational_parametrization(parameters=(t, s)) + (-2 + 4/(s**2 + t**2 + 1), 4*s/(s**2 + t**2 + 1), 4*t/(s**2 + t**2 + 1)) + + For some conics, regular_point() is unable to find a point on the curve. + To calculate the parametric representation in such cases, the user needs + to determine a point on the region and pass it using reg_point. + + >>> c = ImplicitRegion((x, y), (x - 1/2)**2 + (y)**2 - (1/4)**2) + >>> c.rational_parametrization(reg_point=(3/4, 0)) + (0.75 - 0.5/(t**2 + 1), -0.5*t/(t**2 + 1)) + + References + ========== + + - Christoph M. Hoffmann, "Conversion Methods between Parametric and + Implicit Curves and Surfaces", Purdue e-Pubs, 1990. Available: + https://docs.lib.purdue.edu/cgi/viewcontent.cgi?article=1827&context=cstech + + """ + equation = self.equation + degree = self.degree + + if degree == 1: + if len(self.variables) == 1: + return (equation,) + elif len(self.variables) == 2: + x, y = self.variables + y_par = list(solveset(equation, y))[0] + return x, y_par + else: + raise NotImplementedError() + + point = () + + # Finding the (n - 1) fold point of the monoid of degree + if degree == 2: + # For degree 2 curves, either a regular point or a singular point can be used.
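+            # (Supplementary note, offered with some hedging: a plane curve of
+            # degree n with a point of multiplicity n - 1 is a "monoid"; any
+            # line through that point meets the curve in exactly one further
+            # point, which is what makes a rational parametrization by the
+            # slope of such lines possible. See the Hoffmann reference above.)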
+ if reg_point is not None: + # Using point provided by the user as regular point + point = reg_point + else: + if len(self.singular_points()) != 0: + point = list(self.singular_points())[0] + else: + point = self.regular_point() + + if len(self.singular_points()) != 0: + singular_points = self.singular_points() + for spoint in singular_points: + syms = Tuple(*spoint).free_symbols + rep = dict.fromkeys(syms, 2) + + if len(syms) != 0: + spoint = tuple(s.subs(rep) for s in spoint) + + if self.multiplicity(spoint) == degree - 1: + point = spoint + break + + if len(point) == 0: + # The region in not a monoid + raise NotImplementedError() + + modified_eq = equation + + # Shifting the region such that fold point moves to origin + for i, var in enumerate(self.variables): + modified_eq = modified_eq.subs(var, var + point[i]) + modified_eq = expand(modified_eq) + + hn = hn_1 = 0 + for term in modified_eq.args: + if total_degree(term) == degree: + hn += term + else: + hn_1 += term + + hn_1 = -1*hn_1 + + if not isinstance(parameters, tuple): + parameters = (parameters,) + + if len(self.variables) == 2: + + parameter1 = parameters[0] + if parameter1 == 's': + # To avoid name conflict between parameters + s = _symbol('s_', real=True) + else: + s = _symbol('s', real=True) + t = _symbol(parameter1, real=True) + + hn = hn.subs({self.variables[0]: s, self.variables[1]: t}) + hn_1 = hn_1.subs({self.variables[0]: s, self.variables[1]: t}) + + x_par = (s*(hn_1/hn)).subs(s, 1) + point[0] + y_par = (t*(hn_1/hn)).subs(s, 1) + point[1] + + return x_par, y_par + + elif len(self.variables) == 3: + + parameter1, parameter2 = parameters + if 'r' in parameters: + # To avoid name conflict between parameters + r = _symbol('r_', real=True) + else: + r = _symbol('r', real=True) + s = _symbol(parameter2, real=True) + t = _symbol(parameter1, real=True) + + hn = hn.subs({self.variables[0]: r, self.variables[1]: s, self.variables[2]: t}) + hn_1 = hn_1.subs({self.variables[0]: r, self.variables[1]: s, self.variables[2]: t}) + + x_par = (r*(hn_1/hn)).subs(r, 1) + point[0] + y_par = (s*(hn_1/hn)).subs(r, 1) + point[1] + z_par = (t*(hn_1/hn)).subs(r, 1) + point[2] + + return x_par, y_par, z_par + + raise NotImplementedError() + +def conic_coeff(variables, equation): + if total_degree(equation) != 2: + raise ValueError() + x = variables[0] + y = variables[1] + + equation = expand(equation) + a = equation.coeff(x**2) + b = equation.coeff(x*y) + c = equation.coeff(y**2) + d = equation.coeff(x, 1).coeff(y, 0) + e = equation.coeff(y, 1).coeff(x, 0) + f = equation.coeff(x, 0).coeff(y, 0) + return a, b, c, d, e, f diff --git a/phivenv/Lib/site-packages/sympy/vector/integrals.py b/phivenv/Lib/site-packages/sympy/vector/integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..a6451c182f214b20b1105eb0a4dc243455c9d126 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/integrals.py @@ -0,0 +1,206 @@ +from sympy.core import Basic, diff +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.matrices import Matrix +from sympy.integrals import Integral, integrate +from sympy.geometry.entity import GeometryEntity +from sympy.simplify.simplify import simplify +from sympy.utilities.iterables import topological_sort +from sympy.vector import (CoordSys3D, Vector, ParametricRegion, + parametric_region_list, ImplicitRegion) +from sympy.vector.operators import _get_coord_systems + + +class ParametricIntegral(Basic): + """ + Represents integral of a scalar or vector field + over a 
Parametric Region + + Examples + ======== + + >>> from sympy import cos, sin, pi + >>> from sympy.vector import CoordSys3D, ParametricRegion, ParametricIntegral + >>> from sympy.abc import r, t, theta, phi + + >>> C = CoordSys3D('C') + >>> curve = ParametricRegion((3*t - 2, t + 1), (t, 1, 2)) + >>> ParametricIntegral(C.x, curve) + 5*sqrt(10)/2 + >>> length = ParametricIntegral(1, curve) + >>> length + sqrt(10) + >>> semisphere = ParametricRegion((2*sin(phi)*cos(theta), 2*sin(phi)*sin(theta), 2*cos(phi)),\ + (theta, 0, 2*pi), (phi, 0, pi/2)) + >>> ParametricIntegral(C.z, semisphere) + 8*pi + + >>> ParametricIntegral(C.j + C.k, ParametricRegion((r*cos(theta), r*sin(theta)), r, theta)) + 0 + + """ + + def __new__(cls, field, parametricregion): + + coord_set = _get_coord_systems(field) + + if len(coord_set) == 0: + coord_sys = CoordSys3D('C') + elif len(coord_set) > 1: + raise ValueError + else: + coord_sys = next(iter(coord_set)) + + if parametricregion.dimensions == 0: + return S.Zero + + base_vectors = coord_sys.base_vectors() + base_scalars = coord_sys.base_scalars() + + parametricfield = field + + r = Vector.zero + for i in range(len(parametricregion.definition)): + r += base_vectors[i]*parametricregion.definition[i] + + if len(coord_set) != 0: + for i in range(len(parametricregion.definition)): + parametricfield = parametricfield.subs(base_scalars[i], parametricregion.definition[i]) + + if parametricregion.dimensions == 1: + parameter = parametricregion.parameters[0] + + r_diff = diff(r, parameter) + lower, upper = parametricregion.limits[parameter][0], parametricregion.limits[parameter][1] + + if isinstance(parametricfield, Vector): + integrand = simplify(r_diff.dot(parametricfield)) + else: + integrand = simplify(r_diff.magnitude()*parametricfield) + + result = integrate(integrand, (parameter, lower, upper)) + + elif parametricregion.dimensions == 2: + u, v = cls._bounds_case(parametricregion.parameters, parametricregion.limits) + + r_u = diff(r, u) + r_v = diff(r, v) + normal_vector = simplify(r_u.cross(r_v)) + + if isinstance(parametricfield, Vector): + integrand = parametricfield.dot(normal_vector) + else: + integrand = parametricfield*normal_vector.magnitude() + + integrand = simplify(integrand) + + lower_u, upper_u = parametricregion.limits[u][0], parametricregion.limits[u][1] + lower_v, upper_v = parametricregion.limits[v][0], parametricregion.limits[v][1] + + result = integrate(integrand, (u, lower_u, upper_u), (v, lower_v, upper_v)) + + else: + variables = cls._bounds_case(parametricregion.parameters, parametricregion.limits) + coeff = Matrix(parametricregion.definition).jacobian(variables).det() + integrand = simplify(parametricfield*coeff) + + l = [(var, parametricregion.limits[var][0], parametricregion.limits[var][1]) for var in variables] + result = integrate(integrand, *l) + + if not isinstance(result, Integral): + return result + else: + return super().__new__(cls, field, parametricregion) + + @classmethod + def _bounds_case(cls, parameters, limits): + + V = list(limits.keys()) + E = [] + + for p in V: + lower_p = limits[p][0] + upper_p = limits[p][1] + + lower_p = lower_p.atoms() + upper_p = upper_p.atoms() + E.extend((p, q) for q in V if p != q and + (lower_p.issuperset({q}) or upper_p.issuperset({q}))) + + if not E: + return parameters + else: + return topological_sort((V, E), key=default_sort_key) + + @property + def field(self): + return self.args[0] + + @property + def parametricregion(self): + return self.args[1] + + +def vector_integrate(field, *region): + """ + 
Compute the integral of a vector/scalar field + over a a region or a set of parameters. + + Examples + ======== + >>> from sympy.vector import CoordSys3D, ParametricRegion, vector_integrate + >>> from sympy.abc import x, y, t + >>> C = CoordSys3D('C') + + >>> region = ParametricRegion((t, t**2), (t, 1, 5)) + >>> vector_integrate(C.x*C.i, region) + 12 + + Integrals over some objects of geometry module can also be calculated. + + >>> from sympy.geometry import Point, Circle, Triangle + >>> c = Circle(Point(0, 2), 5) + >>> vector_integrate(C.x**2 + C.y**2, c) + 290*pi + >>> triangle = Triangle(Point(-2, 3), Point(2, 3), Point(0, 5)) + >>> vector_integrate(3*C.x**2*C.y*C.i + C.j, triangle) + -8 + + Integrals over some simple implicit regions can be computed. But in most cases, + it takes too long to compute over them. This is due to the expressions of parametric + representation becoming large. + + >>> from sympy.vector import ImplicitRegion + >>> c2 = ImplicitRegion((x, y), (x - 2)**2 + (y - 1)**2 - 9) + >>> vector_integrate(1, c2) + 6*pi + + Integral of fields with respect to base scalars: + + >>> vector_integrate(12*C.y**3, (C.y, 1, 3)) + 240 + >>> vector_integrate(C.x**2*C.z, C.x) + C.x**3*C.z/3 + >>> vector_integrate(C.x*C.i - C.y*C.k, C.x) + (Integral(C.x, C.x))*C.i + (Integral(-C.y, C.x))*C.k + >>> _.doit() + C.x**2/2*C.i + (-C.x*C.y)*C.k + + """ + if len(region) == 1: + if isinstance(region[0], ParametricRegion): + return ParametricIntegral(field, region[0]) + + if isinstance(region[0], ImplicitRegion): + region = parametric_region_list(region[0])[0] + return vector_integrate(field, region) + + if isinstance(region[0], GeometryEntity): + regions_list = parametric_region_list(region[0]) + + result = 0 + for reg in regions_list: + result += vector_integrate(field, reg) + return result + + return integrate(field, *region) diff --git a/phivenv/Lib/site-packages/sympy/vector/kind.py b/phivenv/Lib/site-packages/sympy/vector/kind.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c04896b34c9c92c3fb340d94985df859e5877d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/kind.py @@ -0,0 +1,67 @@ +#sympy.vector.kind + +from sympy.core.kind import Kind, _NumberKind, NumberKind +from sympy.core.mul import Mul + +class VectorKind(Kind): + """ + Kind for all vector objects in SymPy. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is + :class:`sympy.core.kind.NumberKind`, + which means that the vector contains only numbers. + + Examples + ======== + + Any instance of Vector class has kind ``VectorKind``: + + >>> from sympy.vector.coordsysrect import CoordSys3D + >>> Sys = CoordSys3D('Sys') + >>> Sys.i.kind + VectorKind(NumberKind) + + Operations between instances of Vector keep also have the kind ``VectorKind``: + + >>> from sympy.core.add import Add + >>> v1 = Sys.i * 2 + Sys.j * 3 + Sys.k * 4 + >>> v2 = Sys.i * Sys.x + Sys.j * Sys.y + Sys.k * Sys.z + >>> v1.kind + VectorKind(NumberKind) + >>> v2.kind + VectorKind(NumberKind) + >>> Add(v1, v2).kind + VectorKind(NumberKind) + + Subclasses of Vector also have the kind ``VectorKind``, such as + Cross, VectorAdd, VectorMul or VectorZero. 
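+
+    A further small check (a sketch, assuming that a plain ``Symbol`` carries
+    ``NumberKind``): the multiplication dispatcher registered at the end of
+    this module keeps the vector kind when a vector is scaled by a symbolic
+    number.
+
+    >>> from sympy import Symbol
+    >>> (Symbol('a')*Sys.i).kind
+    VectorKind(NumberKind)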
+ + See Also + ======== + + sympy.core.kind.Kind + sympy.matrices.kind.MatrixKind + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "VectorKind(%s)" % self.element_kind + +@Mul._kind_dispatcher.register(_NumberKind, VectorKind) +def num_vec_mul(k1, k2): + """ + The result of a multiplication between a number and a Vector should be of VectorKind. + The element kind is selected by recursive dispatching. + """ + if not isinstance(k2, VectorKind): + k1, k2 = k2, k1 + elemk = Mul._kind_dispatcher(k1, k2.element_kind) + return VectorKind(elemk) diff --git a/phivenv/Lib/site-packages/sympy/vector/operators.py b/phivenv/Lib/site-packages/sympy/vector/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca42d20302f972cef66e0b4a35ac75606b2da94 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/operators.py @@ -0,0 +1,335 @@ +import collections +from sympy.core.expr import Expr +from sympy.core import sympify, S, preorder_traversal +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import Vector, VectorMul, VectorAdd, Cross, Dot +from sympy.core.function import Derivative +from sympy.core.add import Add +from sympy.core.mul import Mul + + +def _get_coord_systems(expr): + g = preorder_traversal(expr) + ret = set() + for i in g: + if isinstance(i, CoordSys3D): + ret.add(i) + g.skip() + return frozenset(ret) + + +def _split_mul_args_wrt_coordsys(expr): + d = collections.defaultdict(lambda: S.One) + for i in expr.args: + d[_get_coord_systems(i)] *= i + return list(d.values()) + + +class Gradient(Expr): + """ + Represents unevaluated Gradient. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Gradient + >>> R = CoordSys3D('R') + >>> s = R.x*R.y*R.z + >>> Gradient(s) + Gradient(R.x*R.y*R.z) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + return gradient(self._expr, doit=True) + + +class Divergence(Expr): + """ + Represents unevaluated Divergence. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Divergence + >>> R = CoordSys3D('R') + >>> v = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> Divergence(v) + Divergence(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + return divergence(self._expr, doit=True) + + +class Curl(Expr): + """ + Represents unevaluated Curl. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Curl + >>> R = CoordSys3D('R') + >>> v = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> Curl(v) + Curl(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + return curl(self._expr, doit=True) + + +def curl(vect, doit=True): + """ + Returns the curl of a vector field computed wrt the base scalars + of the given coordinate system. + + Parameters + ========== + + vect : Vector + The vector operand + + doit : bool + If True, the result is returned after calling .doit() on + each component. 
Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, curl + >>> R = CoordSys3D('R') + >>> v1 = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> curl(v1) + 0 + >>> v2 = R.x*R.y*R.z*R.i + >>> curl(v2) + R.x*R.y*R.j + (-R.x*R.z)*R.k + + """ + + coord_sys = _get_coord_systems(vect) + + if len(coord_sys) == 0: + return Vector.zero + elif len(coord_sys) == 1: + coord_sys = next(iter(coord_sys)) + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + h1, h2, h3 = coord_sys.lame_coefficients() + vectx = vect.dot(i) + vecty = vect.dot(j) + vectz = vect.dot(k) + outvec = Vector.zero + outvec += (Derivative(vectz * h3, y) - + Derivative(vecty * h2, z)) * i / (h2 * h3) + outvec += (Derivative(vectx * h1, z) - + Derivative(vectz * h3, x)) * j / (h1 * h3) + outvec += (Derivative(vecty * h2, x) - + Derivative(vectx * h1, y)) * k / (h2 * h1) + + if doit: + return outvec.doit() + return outvec + else: + if isinstance(vect, (Add, VectorAdd)): + from sympy.vector import express + try: + cs = next(iter(coord_sys)) + args = [express(i, cs, variables=True) for i in vect.args] + except ValueError: + args = vect.args + return VectorAdd.fromiter(curl(i, doit=doit) for i in args) + elif isinstance(vect, (Mul, VectorMul)): + vector = [i for i in vect.args if isinstance(i, (Vector, Cross, Gradient))][0] + scalar = Mul.fromiter(i for i in vect.args if not isinstance(i, (Vector, Cross, Gradient))) + res = Cross(gradient(scalar), vector).doit() + scalar*curl(vector, doit=doit) + if doit: + return res.doit() + return res + elif isinstance(vect, (Cross, Curl, Gradient)): + return Curl(vect) + else: + raise ValueError("Invalid argument for curl") + + +def divergence(vect, doit=True): + """ + Returns the divergence of a vector field computed wrt the base + scalars of the given coordinate system. + + Parameters + ========== + + vector : Vector + The vector operand + + doit : bool + If True, the result is returned after calling .doit() on + each component. 
Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, divergence + >>> R = CoordSys3D('R') + >>> v1 = R.x*R.y*R.z * (R.i+R.j+R.k) + + >>> divergence(v1) + R.x*R.y + R.x*R.z + R.y*R.z + >>> v2 = 2*R.y*R.z*R.j + >>> divergence(v2) + 2*R.z + + """ + coord_sys = _get_coord_systems(vect) + if len(coord_sys) == 0: + return S.Zero + elif len(coord_sys) == 1: + if isinstance(vect, (Cross, Curl, Gradient)): + return Divergence(vect) + # TODO: is case of many coord systems, this gets a random one: + coord_sys = next(iter(coord_sys)) + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + h1, h2, h3 = coord_sys.lame_coefficients() + vx = _diff_conditional(vect.dot(i), x, h2, h3) \ + / (h1 * h2 * h3) + vy = _diff_conditional(vect.dot(j), y, h3, h1) \ + / (h1 * h2 * h3) + vz = _diff_conditional(vect.dot(k), z, h1, h2) \ + / (h1 * h2 * h3) + res = vx + vy + vz + if doit: + return res.doit() + return res + else: + if isinstance(vect, (Add, VectorAdd)): + return Add.fromiter(divergence(i, doit=doit) for i in vect.args) + elif isinstance(vect, (Mul, VectorMul)): + vector = [i for i in vect.args if isinstance(i, (Vector, Cross, Gradient))][0] + scalar = Mul.fromiter(i for i in vect.args if not isinstance(i, (Vector, Cross, Gradient))) + res = Dot(vector, gradient(scalar)) + scalar*divergence(vector, doit=doit) + if doit: + return res.doit() + return res + elif isinstance(vect, (Cross, Curl, Gradient)): + return Divergence(vect) + else: + raise ValueError("Invalid argument for divergence") + + +def gradient(scalar_field, doit=True): + """ + Returns the vector gradient of a scalar field computed wrt the + base scalars of the given coordinate system. + + Parameters + ========== + + scalar_field : SymPy Expr + The scalar field to compute the gradient of + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, gradient + >>> R = CoordSys3D('R') + >>> s1 = R.x*R.y*R.z + >>> gradient(s1) + R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> s2 = 5*R.x**2*R.z + >>> gradient(s2) + 10*R.x*R.z*R.i + 5*R.x**2*R.k + + """ + coord_sys = _get_coord_systems(scalar_field) + + if len(coord_sys) == 0: + return Vector.zero + elif len(coord_sys) == 1: + coord_sys = next(iter(coord_sys)) + h1, h2, h3 = coord_sys.lame_coefficients() + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + vx = Derivative(scalar_field, x) / h1 + vy = Derivative(scalar_field, y) / h2 + vz = Derivative(scalar_field, z) / h3 + + if doit: + return (vx * i + vy * j + vz * k).doit() + return vx * i + vy * j + vz * k + else: + if isinstance(scalar_field, (Add, VectorAdd)): + return VectorAdd.fromiter(gradient(i) for i in scalar_field.args) + if isinstance(scalar_field, (Mul, VectorMul)): + s = _split_mul_args_wrt_coordsys(scalar_field) + return VectorAdd.fromiter(scalar_field / i * gradient(i) for i in s) + return Gradient(scalar_field) + + +class Laplacian(Expr): + """ + Represents unevaluated Laplacian. 
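+
+    A supplementary identity check (a minimal sketch, using the evaluated
+    operators defined above): for a scalar field the Laplacian agrees with
+    the divergence of the gradient.
+
+    >>> from sympy.vector import CoordSys3D, divergence, gradient, laplacian
+    >>> R = CoordSys3D('R')
+    >>> f = R.x**2 + R.y**2 + R.z**2
+    >>> laplacian(f)
+    6
+    >>> divergence(gradient(f))
+    6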
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Laplacian + >>> R = CoordSys3D('R') + >>> v = 3*R.x**3*R.y**2*R.z**3 + >>> Laplacian(v) + Laplacian(3*R.x**3*R.y**2*R.z**3) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + from sympy.vector.functions import laplacian + return laplacian(self._expr) + + +def _diff_conditional(expr, base_scalar, coeff_1, coeff_2): + """ + First re-expresses expr in the system that base_scalar belongs to. + If base_scalar appears in the re-expressed form, differentiates + it wrt base_scalar. + Else, returns 0 + """ + from sympy.vector.functions import express + new_expr = express(expr, base_scalar.system, variables=True) + arg = coeff_1 * coeff_2 * new_expr + return Derivative(arg, base_scalar) if arg else S.Zero diff --git a/phivenv/Lib/site-packages/sympy/vector/orienters.py b/phivenv/Lib/site-packages/sympy/vector/orienters.py new file mode 100644 index 0000000000000000000000000000000000000000..0c22089e568bc817c943c1beecebde0fea46b6ae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/orienters.py @@ -0,0 +1,398 @@ +from sympy.core.basic import Basic +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import (eye, rot_axis1, rot_axis2, rot_axis3) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.core.cache import cacheit +from sympy.core.symbol import Str +import sympy.vector + + +class Orienter(Basic): + """ + Super-class for all orienter classes. + """ + + def rotation_matrix(self): + """ + The rotation matrix corresponding to this orienter + instance. + """ + return self._parent_orient + + +class AxisOrienter(Orienter): + """ + Class to denote an axis orienter. + """ + + def __new__(cls, angle, axis): + if not isinstance(axis, sympy.vector.Vector): + raise TypeError("axis should be a Vector") + angle = sympify(angle) + + obj = super().__new__(cls, angle, axis) + obj._angle = angle + obj._axis = axis + + return obj + + def __init__(self, angle, axis): + """ + Axis rotation is a rotation about an arbitrary axis by + some angle. The angle is supplied as a SymPy expr scalar, and + the axis is supplied as a Vector. + + Parameters + ========== + + angle : Expr + The angle by which the new system is to be rotated + + axis : Vector + The axis around which the rotation has to be performed + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = CoordSys3D('N') + >>> from sympy.vector import AxisOrienter + >>> orienter = AxisOrienter(q1, N.i + 2 * N.j) + >>> B = N.orient_new('B', (orienter, )) + + """ + # Dummy initializer for docstrings + pass + + @cacheit + def rotation_matrix(self, system): + """ + The rotation matrix corresponding to this orienter + instance. 
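+
+        A short usage sketch (hedged; the expected matrix matches the checks
+        in this module's test suite): a rotation about ``N.k`` produces the
+        familiar planar rotation matrix.
+
+        >>> from sympy import symbols, cos, sin
+        >>> from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix
+        >>> from sympy.vector import CoordSys3D, AxisOrienter
+        >>> q = symbols('q')
+        >>> N = CoordSys3D('N')
+        >>> AxisOrienter(q, N.k).rotation_matrix(N) == Matrix([
+        ...     [cos(q), sin(q), 0], [-sin(q), cos(q), 0], [0, 0, 1]])
+        True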
+ + Parameters + ========== + + system : CoordSys3D + The coordinate system wrt which the rotation matrix + is to be computed + """ + + axis = sympy.vector.express(self.axis, system).normalize() + axis = axis.to_matrix(system) + theta = self.angle + parent_orient = ((eye(3) - axis * axis.T) * cos(theta) + + Matrix([[0, -axis[2], axis[1]], + [axis[2], 0, -axis[0]], + [-axis[1], axis[0], 0]]) * sin(theta) + + axis * axis.T) + parent_orient = parent_orient.T + return parent_orient + + @property + def angle(self): + return self._angle + + @property + def axis(self): + return self._axis + + +class ThreeAngleOrienter(Orienter): + """ + Super-class for Body and Space orienters. + """ + + def __new__(cls, angle1, angle2, angle3, rot_order): + if isinstance(rot_order, Str): + rot_order = rot_order.name + + approved_orders = ('123', '231', '312', '132', '213', + '321', '121', '131', '212', '232', + '313', '323', '') + original_rot_order = rot_order + rot_order = str(rot_order).upper() + if not (len(rot_order) == 3): + raise TypeError('rot_order should be a str of length 3') + rot_order = [i.replace('X', '1') for i in rot_order] + rot_order = [i.replace('Y', '2') for i in rot_order] + rot_order = [i.replace('Z', '3') for i in rot_order] + rot_order = ''.join(rot_order) + if rot_order not in approved_orders: + raise TypeError('Invalid rot_type parameter') + a1 = int(rot_order[0]) + a2 = int(rot_order[1]) + a3 = int(rot_order[2]) + angle1 = sympify(angle1) + angle2 = sympify(angle2) + angle3 = sympify(angle3) + if cls._in_order: + parent_orient = (_rot(a1, angle1) * + _rot(a2, angle2) * + _rot(a3, angle3)) + else: + parent_orient = (_rot(a3, angle3) * + _rot(a2, angle2) * + _rot(a1, angle1)) + parent_orient = parent_orient.T + + obj = super().__new__( + cls, angle1, angle2, angle3, Str(rot_order)) + obj._angle1 = angle1 + obj._angle2 = angle2 + obj._angle3 = angle3 + obj._rot_order = original_rot_order + obj._parent_orient = parent_orient + + return obj + + @property + def angle1(self): + return self._angle1 + + @property + def angle2(self): + return self._angle2 + + @property + def angle3(self): + return self._angle3 + + @property + def rot_order(self): + return self._rot_order + + +class BodyOrienter(ThreeAngleOrienter): + """ + Class to denote a body-orienter. + """ + + _in_order = True + + def __new__(cls, angle1, angle2, angle3, rot_order): + obj = ThreeAngleOrienter.__new__(cls, angle1, angle2, angle3, + rot_order) + return obj + + def __init__(self, angle1, angle2, angle3, rot_order): + """ + Body orientation takes this coordinate system through three + successive simple rotations. + + Body fixed rotations include both Euler Angles and + Tait-Bryan Angles, see https://en.wikipedia.org/wiki/Euler_angles. + + Parameters + ========== + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, BodyOrienter + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + A 'Body' fixed rotation is described by three angles and + three body-fixed rotation axes. To orient a coordinate system D + with respect to N, each sequential rotation is always about + the orthogonal unit vectors fixed to D. For example, a '123' + rotation will specify rotations about N.i, then D.j, then + D.k. 
(Initially, D.i is same as N.i) + Therefore, + + >>> body_orienter = BodyOrienter(q1, q2, q3, '123') + >>> D = N.orient_new('D', (body_orienter, )) + + is same as + + >>> from sympy.vector import AxisOrienter + >>> axis_orienter1 = AxisOrienter(q1, N.i) + >>> D = N.orient_new('D', (axis_orienter1, )) + >>> axis_orienter2 = AxisOrienter(q2, D.j) + >>> D = D.orient_new('D', (axis_orienter2, )) + >>> axis_orienter3 = AxisOrienter(q3, D.k) + >>> D = D.orient_new('D', (axis_orienter3, )) + + Acceptable rotation orders are of length 3, expressed in XYZ or + 123, and cannot have a rotation about about an axis twice in a row. + + >>> body_orienter1 = BodyOrienter(q1, q2, q3, '123') + >>> body_orienter2 = BodyOrienter(q1, q2, 0, 'ZXZ') + >>> body_orienter3 = BodyOrienter(0, 0, 0, 'XYX') + + """ + # Dummy initializer for docstrings + pass + + +class SpaceOrienter(ThreeAngleOrienter): + """ + Class to denote a space-orienter. + """ + + _in_order = False + + def __new__(cls, angle1, angle2, angle3, rot_order): + obj = ThreeAngleOrienter.__new__(cls, angle1, angle2, angle3, + rot_order) + return obj + + def __init__(self, angle1, angle2, angle3, rot_order): + """ + Space rotation is similar to Body rotation, but the rotations + are applied in the opposite order. + + Parameters + ========== + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + See Also + ======== + + BodyOrienter : Orienter to orient systems wrt Euler angles. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, SpaceOrienter + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + To orient a coordinate system D with respect to N, each + sequential rotation is always about N's orthogonal unit vectors. + For example, a '123' rotation will specify rotations about + N.i, then N.j, then N.k. + Therefore, + + >>> space_orienter = SpaceOrienter(q1, q2, q3, '312') + >>> D = N.orient_new('D', (space_orienter, )) + + is same as + + >>> from sympy.vector import AxisOrienter + >>> axis_orienter1 = AxisOrienter(q1, N.i) + >>> B = N.orient_new('B', (axis_orienter1, )) + >>> axis_orienter2 = AxisOrienter(q2, N.j) + >>> C = B.orient_new('C', (axis_orienter2, )) + >>> axis_orienter3 = AxisOrienter(q3, N.k) + >>> D = C.orient_new('C', (axis_orienter3, )) + + """ + # Dummy initializer for docstrings + pass + + +class QuaternionOrienter(Orienter): + """ + Class to denote a quaternion-orienter. + """ + + def __new__(cls, q0, q1, q2, q3): + q0 = sympify(q0) + q1 = sympify(q1) + q2 = sympify(q2) + q3 = sympify(q3) + parent_orient = (Matrix([[q0 ** 2 + q1 ** 2 - q2 ** 2 - + q3 ** 2, + 2 * (q1 * q2 - q0 * q3), + 2 * (q0 * q2 + q1 * q3)], + [2 * (q1 * q2 + q0 * q3), + q0 ** 2 - q1 ** 2 + + q2 ** 2 - q3 ** 2, + 2 * (q2 * q3 - q0 * q1)], + [2 * (q1 * q3 - q0 * q2), + 2 * (q0 * q1 + q2 * q3), + q0 ** 2 - q1 ** 2 - + q2 ** 2 + q3 ** 2]])) + parent_orient = parent_orient.T + + obj = super().__new__(cls, q0, q1, q2, q3) + obj._q0 = q0 + obj._q1 = q1 + obj._q2 = q2 + obj._q3 = q3 + obj._parent_orient = parent_orient + + return obj + + def __init__(self, angle1, angle2, angle3, rot_order): + """ + Quaternion orientation orients the new CoordSys3D with + Quaternions, defined as a finite rotation about lambda, a unit + vector, by some amount theta. 
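+
+        (A hedged cross-check, anticipating the parameter convention listed
+        below: a quaternion with q0 = cos(theta/2) and q3 = sin(theta/2), the
+        other components zero, reproduces an axis rotation about ``k`` by
+        theta.)
+
+        >>> from sympy import symbols, cos, sin, trigsimp
+        >>> from sympy.vector import CoordSys3D, AxisOrienter, QuaternionOrienter
+        >>> theta = symbols('theta')
+        >>> N = CoordSys3D('N')
+        >>> quat = QuaternionOrienter(cos(theta/2), 0, 0, sin(theta/2))
+        >>> axis = AxisOrienter(theta, N.k)
+        >>> all(trigsimp(entry) == 0 for entry in
+        ...     (quat.rotation_matrix() - axis.rotation_matrix(N)))
+        True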
+ + This orientation is described by four parameters: + + q0 = cos(theta/2) + + q1 = lambda_x sin(theta/2) + + q2 = lambda_y sin(theta/2) + + q3 = lambda_z sin(theta/2) + + Quaternion does not take in a rotation order. + + Parameters + ========== + + q0, q1, q2, q3 : Expr + The quaternions to rotate the coordinate system by + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = CoordSys3D('N') + >>> from sympy.vector import QuaternionOrienter + >>> q_orienter = QuaternionOrienter(q0, q1, q2, q3) + >>> B = N.orient_new('B', (q_orienter, )) + + """ + # Dummy initializer for docstrings + pass + + @property + def q0(self): + return self._q0 + + @property + def q1(self): + return self._q1 + + @property + def q2(self): + return self._q2 + + @property + def q3(self): + return self._q3 + + +def _rot(axis, angle): + """DCM for simple axis 1, 2 or 3 rotations. """ + if axis == 1: + return Matrix(rot_axis1(angle).T) + elif axis == 2: + return Matrix(rot_axis2(angle).T) + elif axis == 3: + return Matrix(rot_axis3(angle).T) diff --git a/phivenv/Lib/site-packages/sympy/vector/parametricregion.py b/phivenv/Lib/site-packages/sympy/vector/parametricregion.py new file mode 100644 index 0000000000000000000000000000000000000000..5246769dabe208fe630f7c33b2da3ef4e11b3f67 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/parametricregion.py @@ -0,0 +1,189 @@ +from functools import singledispatch +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import tan +from sympy.simplify import trigsimp +from sympy.core import Basic, Tuple +from sympy.core.symbol import _symbol +from sympy.solvers import solve +from sympy.geometry import Point, Segment, Curve, Ellipse, Polygon +from sympy.vector import ImplicitRegion + + +class ParametricRegion(Basic): + """ + Represents a parametric region in space. + + Examples + ======== + + >>> from sympy import cos, sin, pi + >>> from sympy.abc import r, theta, t, a, b, x, y + >>> from sympy.vector import ParametricRegion + + >>> ParametricRegion((t, t**2), (t, -1, 2)) + ParametricRegion((t, t**2), (t, -1, 2)) + >>> ParametricRegion((x, y), (x, 3, 4), (y, 5, 6)) + ParametricRegion((x, y), (x, 3, 4), (y, 5, 6)) + >>> ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + >>> ParametricRegion((a*cos(t), b*sin(t)), t) + ParametricRegion((a*cos(t), b*sin(t)), t) + + >>> circle = ParametricRegion((r*cos(theta), r*sin(theta)), r, (theta, 0, pi)) + >>> circle.parameters + (r, theta) + >>> circle.definition + (r*cos(theta), r*sin(theta)) + >>> circle.limits + {theta: (0, pi)} + + Dimension of a parametric region determines whether a region is a curve, surface + or volume region. It does not represent its dimensions in space. + + >>> circle.dimensions + 1 + + Parameters + ========== + + definition : tuple to define base scalars in terms of parameters. + + bounds : Parameter or a tuple of length 3 to define parameter and corresponding lower and upper bound. 
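+
+    One more small example (hedged; it reuses the imports from the examples
+    above): when every parameter carries bounds, the region is
+    two-dimensional in the sense used by ``dimensions``.
+
+    >>> disc = ParametricRegion((r*cos(theta), r*sin(theta)), (r, 0, 2), (theta, 0, 2*pi))
+    >>> disc.dimensions
+    2
+    >>> disc.parameters
+    (r, theta)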
+ + """ + def __new__(cls, definition, *bounds): + parameters = () + limits = {} + + if not isinstance(bounds, Tuple): + bounds = Tuple(*bounds) + + for bound in bounds: + if isinstance(bound, (tuple, Tuple)): + if len(bound) != 3: + raise ValueError("Tuple should be in the form (parameter, lowerbound, upperbound)") + parameters += (bound[0],) + limits[bound[0]] = (bound[1], bound[2]) + else: + parameters += (bound,) + + if not isinstance(definition, (tuple, Tuple)): + definition = (definition,) + + obj = super().__new__(cls, Tuple(*definition), *bounds) + obj._parameters = parameters + obj._limits = limits + + return obj + + @property + def definition(self): + return self.args[0] + + @property + def limits(self): + return self._limits + + @property + def parameters(self): + return self._parameters + + @property + def dimensions(self): + return len(self.limits) + + +@singledispatch +def parametric_region_list(reg): + """ + Returns a list of ParametricRegion objects representing the geometric region. + + Examples + ======== + + >>> from sympy.abc import t + >>> from sympy.vector import parametric_region_list + >>> from sympy.geometry import Point, Curve, Ellipse, Segment, Polygon + + >>> p = Point(2, 5) + >>> parametric_region_list(p) + [ParametricRegion((2, 5))] + + >>> c = Curve((t**3, 4*t), (t, -3, 4)) + >>> parametric_region_list(c) + [ParametricRegion((t**3, 4*t), (t, -3, 4))] + + >>> e = Ellipse(Point(1, 3), 2, 3) + >>> parametric_region_list(e) + [ParametricRegion((2*cos(t) + 1, 3*sin(t) + 3), (t, 0, 2*pi))] + + >>> s = Segment(Point(1, 3), Point(2, 6)) + >>> parametric_region_list(s) + [ParametricRegion((t + 1, 3*t + 3), (t, 0, 1))] + + >>> p1, p2, p3, p4 = [(0, 1), (2, -3), (5, 3), (-2, 3)] + >>> poly = Polygon(p1, p2, p3, p4) + >>> parametric_region_list(poly) + [ParametricRegion((2*t, 1 - 4*t), (t, 0, 1)), ParametricRegion((3*t + 2, 6*t - 3), (t, 0, 1)),\ + ParametricRegion((5 - 7*t, 3), (t, 0, 1)), ParametricRegion((2*t - 2, 3 - 2*t), (t, 0, 1))] + + """ + raise ValueError("SymPy cannot determine parametric representation of the region.") + + +@parametric_region_list.register(Point) +def _(obj): + return [ParametricRegion(obj.args)] + + +@parametric_region_list.register(Curve) # type: ignore +def _(obj): + definition = obj.arbitrary_point(obj.parameter).args + bounds = obj.limits + return [ParametricRegion(definition, bounds)] + + +@parametric_region_list.register(Ellipse) # type: ignore +def _(obj, parameter='t'): + definition = obj.arbitrary_point(parameter).args + t = _symbol(parameter, real=True) + bounds = (t, 0, 2*pi) + return [ParametricRegion(definition, bounds)] + + +@parametric_region_list.register(Segment) # type: ignore +def _(obj, parameter='t'): + t = _symbol(parameter, real=True) + definition = obj.arbitrary_point(t).args + + for i in range(0, 3): + lower_bound = solve(definition[i] - obj.points[0].args[i], t) + upper_bound = solve(definition[i] - obj.points[1].args[i], t) + + if len(lower_bound) == 1 and len(upper_bound) == 1: + bounds = t, lower_bound[0], upper_bound[0] + break + + definition_tuple = obj.arbitrary_point(parameter).args + return [ParametricRegion(definition_tuple, bounds)] + + +@parametric_region_list.register(Polygon) # type: ignore +def _(obj, parameter='t'): + l = [parametric_region_list(side, parameter)[0] for side in obj.sides] + return l + + +@parametric_region_list.register(ImplicitRegion) # type: ignore +def _(obj, parameters=('t', 's')): + definition = obj.rational_parametrization(parameters) + bounds = [] + + for i in 
range(len(obj.variables) - 1): + # Each parameter is replaced by its tangent to simplify integration + parameter = _symbol(parameters[i], real=True) + definition = [trigsimp(elem.subs(parameter, tan(parameter/2))) for elem in definition] + bounds.append((parameter, 0, 2*pi),) + + definition = Tuple(*definition) + return [ParametricRegion(definition, *bounds)] diff --git a/phivenv/Lib/site-packages/sympy/vector/point.py b/phivenv/Lib/site-packages/sympy/vector/point.py new file mode 100644 index 0000000000000000000000000000000000000000..442ea4e8edc0e33c4f83f774ea3f11a01725ac3a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/point.py @@ -0,0 +1,148 @@ +from sympy.core.basic import Basic +from sympy.core.symbol import Str +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.functions import _path +from sympy.core.cache import cacheit + + +class Point(Basic): + """ + Represents a point in 3-D space. + """ + + def __new__(cls, name, position=Vector.zero, parent_point=None): + name = str(name) + # Check the args first + if not isinstance(position, Vector): + raise TypeError( + "position should be an instance of Vector, not %s" % type( + position)) + if (not isinstance(parent_point, Point) and + parent_point is not None): + raise TypeError( + "parent_point should be an instance of Point, not %s" % type( + parent_point)) + # Super class construction + if parent_point is None: + obj = super().__new__(cls, Str(name), position) + else: + obj = super().__new__(cls, Str(name), position, parent_point) + # Decide the object parameters + obj._name = name + obj._pos = position + if parent_point is None: + obj._parent = None + obj._root = obj + else: + obj._parent = parent_point + obj._root = parent_point._root + # Return object + return obj + + @cacheit + def position_wrt(self, other): + """ + Returns the position vector of this Point with respect to + another Point/CoordSys3D. + + Parameters + ========== + + other : Point/CoordSys3D + If other is a Point, the position of this Point wrt it is + returned. If its an instance of CoordSyRect, the position + wrt its origin is returned. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> p1 = N.origin.locate_new('p1', 10 * N.i) + >>> N.origin.position_wrt(p1) + (-10)*N.i + + """ + + if (not isinstance(other, Point) and + not isinstance(other, CoordSys3D)): + raise TypeError(str(other) + + "is not a Point or CoordSys3D") + if isinstance(other, CoordSys3D): + other = other.origin + # Handle special cases + if other == self: + return Vector.zero + elif other == self._parent: + return self._pos + elif other._parent == self: + return -1 * other._pos + # Else, use point tree to calculate position + rootindex, path = _path(self, other) + result = Vector.zero + for i in range(rootindex): + result += path[i]._pos + for i in range(rootindex + 1, len(path)): + result -= path[i]._pos + return result + + def locate_new(self, name, position): + """ + Returns a new Point located at the given position wrt this + Point. + Thus, the position vector of the new Point wrt this one will + be equal to the given 'position' parameter. 
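+
+        A supplementary sketch (hedged, using only the methods above):
+        positions of chained points compose, because ``position_wrt`` walks
+        the point tree built by ``locate_new``.
+
+        >>> from sympy.vector import CoordSys3D
+        >>> N = CoordSys3D('N')
+        >>> p1 = N.origin.locate_new('p1', 2*N.i)
+        >>> p2 = p1.locate_new('p2', 3*N.j)
+        >>> p2.position_wrt(N.origin) == 2*N.i + 3*N.j
+        True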
+ + Parameters + ========== + + name : str + Name of the new point + + position : Vector + The position vector of the new Point wrt this one + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> p1 = N.origin.locate_new('p1', 10 * N.i) + >>> p1.position_wrt(N.origin) + 10*N.i + + """ + return Point(name, position, self) + + def express_coordinates(self, coordinate_system): + """ + Returns the Cartesian/rectangular coordinates of this point + wrt the origin of the given CoordSys3D instance. + + Parameters + ========== + + coordinate_system : CoordSys3D + The coordinate system to express the coordinates of this + Point in. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> p1 = N.origin.locate_new('p1', 10 * N.i) + >>> p2 = p1.locate_new('p2', 5 * N.j) + >>> p2.express_coordinates(N) + (10, 5, 0) + + """ + + # Determine the position vector + pos_vect = self.position_wrt(coordinate_system.origin) + # Express it in the given coordinate system + return tuple(pos_vect.to_matrix(coordinate_system)) + + def _sympystr(self, printer): + return self._name diff --git a/phivenv/Lib/site-packages/sympy/vector/scalar.py b/phivenv/Lib/site-packages/sympy/vector/scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfb56cf177b9378a24f81ca1e6524fe048a5f94 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/scalar.py @@ -0,0 +1,72 @@ +from sympy.core import AtomicExpr, Symbol, S +from sympy.core.sympify import _sympify +from sympy.printing.pretty.stringpict import prettyForm +from sympy.printing.precedence import PRECEDENCE +from sympy.core.kind import NumberKind + + +class BaseScalar(AtomicExpr): + """ + A coordinate symbol/base scalar. + + Ideally, users should not instantiate this class. + + """ + + kind = NumberKind + + def __new__(cls, index, system, pretty_str=None, latex_str=None): + from sympy.vector.coordsysrect import CoordSys3D + if pretty_str is None: + pretty_str = "x{}".format(index) + elif isinstance(pretty_str, Symbol): + pretty_str = pretty_str.name + if latex_str is None: + latex_str = "x_{}".format(index) + elif isinstance(latex_str, Symbol): + latex_str = latex_str.name + + index = _sympify(index) + system = _sympify(system) + obj = super().__new__(cls, index, system) + if not isinstance(system, CoordSys3D): + raise TypeError("system should be a CoordSys3D") + if index not in range(0, 3): + raise ValueError("Invalid index specified.") + # The _id is used for equating purposes, and for hashing + obj._id = (index, system) + obj._name = obj.name = system._name + '.' 
+ system._variable_names[index] + obj._pretty_form = '' + pretty_str + obj._latex_form = latex_str + obj._system = system + + return obj + + is_commutative = True + is_symbol = True + + @property + def free_symbols(self): + return {self} + + _diff_wrt = True + + def _eval_derivative(self, s): + if self == s: + return S.One + return S.Zero + + def _latex(self, printer=None): + return self._latex_form + + def _pretty(self, printer=None): + return prettyForm(self._pretty_form) + + precedence = PRECEDENCE['Atom'] + + @property + def system(self): + return self._system + + def _sympystr(self, printer): + return self._name diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__init__.py b/phivenv/Lib/site-packages/sympy/vector/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..187cc654028996bae583054a4f367a59a8f38bc7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_coordsysrect.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_coordsysrect.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcc80179eb7565a5a599ea26f9793ce9a6cba0e8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_coordsysrect.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_dyadic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_dyadic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd956624b1ae9bb0124d5da66bcab3edcaa5f59 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_dyadic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_field_functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_field_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b412114f6bec936842fe5939086b6d0ba3bfbcec Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_field_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deb393558fac18224f07eb84bd3afd9e18e1a6c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_implicitregion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_implicitregion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dee15e7154b2bac7f8bfe95f90c0ae9da00523f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_implicitregion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_integrals.cpython-39.pyc 
b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_integrals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5df5ff0eecf8c451019f7ea5e928cd30ca068e93 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_integrals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_operators.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48d66e9b1f380b2f81a810dd1d55c24a9e1f43d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_operators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_parametricregion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_parametricregion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8139c4bc0514e07c17eb8b81691363c237c40c2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_parametricregion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_printing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_printing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be3d1e70b8d0998839bcecf9d91e5bf066accb0d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_printing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_vector.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_vector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3522f614d5455c792d530633455cfe9f32600cb6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/vector/tests/__pycache__/test_vector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_coordsysrect.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_coordsysrect.py new file mode 100644 index 0000000000000000000000000000000000000000..53eb8c89ec1643a71800efe3e370acff3cb6f9c0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_coordsysrect.py @@ -0,0 +1,464 @@ +from sympy.testing.pytest import raises +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.scalar import BaseScalar +from sympy.core.function import expand +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos, sin) +from sympy.matrices.dense import zeros +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.vector.functions import express +from sympy.vector.point import Point +from sympy.vector.vector import Vector +from sympy.vector.orienters import (AxisOrienter, BodyOrienter, + SpaceOrienter, QuaternionOrienter) + + +x, y, z = symbols('x y z') +a, b, c, q = symbols('a b c q') +q1, q2, q3, q4 = symbols('q1 q2 q3 q4') + + +def test_func_args(): + A = CoordSys3D('A') + assert A.x.func(*A.x.args) == A.x + expr = 3*A.x + 4*A.y + assert expr.func(*expr.args) == expr + assert A.i.func(*A.i.args) == A.i + v = A.x*A.i + A.y*A.j + A.z*A.k + assert 
v.func(*v.args) == v + assert A.origin.func(*A.origin.args) == A.origin + + +def test_coordsys3d_equivalence(): + A = CoordSys3D('A') + A1 = CoordSys3D('A') + assert A1 == A + B = CoordSys3D('B') + assert A != B + + +def test_orienters(): + A = CoordSys3D('A') + axis_orienter = AxisOrienter(a, A.k) + body_orienter = BodyOrienter(a, b, c, '123') + space_orienter = SpaceOrienter(a, b, c, '123') + q_orienter = QuaternionOrienter(q1, q2, q3, q4) + assert axis_orienter.rotation_matrix(A) == Matrix([ + [ cos(a), sin(a), 0], + [-sin(a), cos(a), 0], + [ 0, 0, 1]]) + assert body_orienter.rotation_matrix() == Matrix([ + [ cos(b)*cos(c), sin(a)*sin(b)*cos(c) + sin(c)*cos(a), + sin(a)*sin(c) - sin(b)*cos(a)*cos(c)], + [-sin(c)*cos(b), -sin(a)*sin(b)*sin(c) + cos(a)*cos(c), + sin(a)*cos(c) + sin(b)*sin(c)*cos(a)], + [ sin(b), -sin(a)*cos(b), + cos(a)*cos(b)]]) + assert space_orienter.rotation_matrix() == Matrix([ + [cos(b)*cos(c), sin(c)*cos(b), -sin(b)], + [sin(a)*sin(b)*cos(c) - sin(c)*cos(a), + sin(a)*sin(b)*sin(c) + cos(a)*cos(c), sin(a)*cos(b)], + [sin(a)*sin(c) + sin(b)*cos(a)*cos(c), -sin(a)*cos(c) + + sin(b)*sin(c)*cos(a), cos(a)*cos(b)]]) + assert q_orienter.rotation_matrix() == Matrix([ + [q1**2 + q2**2 - q3**2 - q4**2, 2*q1*q4 + 2*q2*q3, + -2*q1*q3 + 2*q2*q4], + [-2*q1*q4 + 2*q2*q3, q1**2 - q2**2 + q3**2 - q4**2, + 2*q1*q2 + 2*q3*q4], + [2*q1*q3 + 2*q2*q4, + -2*q1*q2 + 2*q3*q4, q1**2 - q2**2 - q3**2 + q4**2]]) + + +def test_coordinate_vars(): + """ + Tests the coordinate variables functionality with respect to + reorientation of coordinate systems. + """ + A = CoordSys3D('A') + # Note that the name given on the lhs is different from A.x._name + assert BaseScalar(0, A, 'A_x', r'\mathbf{{x}_{A}}') == A.x + assert BaseScalar(1, A, 'A_y', r'\mathbf{{y}_{A}}') == A.y + assert BaseScalar(2, A, 'A_z', r'\mathbf{{z}_{A}}') == A.z + assert BaseScalar(0, A, 'A_x', r'\mathbf{{x}_{A}}').__hash__() == A.x.__hash__() + assert isinstance(A.x, BaseScalar) and \ + isinstance(A.y, BaseScalar) and \ + isinstance(A.z, BaseScalar) + assert A.x*A.y == A.y*A.x + assert A.scalar_map(A) == {A.x: A.x, A.y: A.y, A.z: A.z} + assert A.x.system == A + assert A.x.diff(A.x) == 1 + B = A.orient_new_axis('B', q, A.k) + assert B.scalar_map(A) == {B.z: A.z, B.y: -A.x*sin(q) + A.y*cos(q), + B.x: A.x*cos(q) + A.y*sin(q)} + assert A.scalar_map(B) == {A.x: B.x*cos(q) - B.y*sin(q), + A.y: B.x*sin(q) + B.y*cos(q), A.z: B.z} + assert express(B.x, A, variables=True) == A.x*cos(q) + A.y*sin(q) + assert express(B.y, A, variables=True) == -A.x*sin(q) + A.y*cos(q) + assert express(B.z, A, variables=True) == A.z + assert expand(express(B.x*B.y*B.z, A, variables=True)) == \ + expand(A.z*(-A.x*sin(q) + A.y*cos(q))*(A.x*cos(q) + A.y*sin(q))) + assert express(B.x*B.i + B.y*B.j + B.z*B.k, A) == \ + (B.x*cos(q) - B.y*sin(q))*A.i + (B.x*sin(q) + \ + B.y*cos(q))*A.j + B.z*A.k + assert simplify(express(B.x*B.i + B.y*B.j + B.z*B.k, A, \ + variables=True)) == \ + A.x*A.i + A.y*A.j + A.z*A.k + assert express(A.x*A.i + A.y*A.j + A.z*A.k, B) == \ + (A.x*cos(q) + A.y*sin(q))*B.i + \ + (-A.x*sin(q) + A.y*cos(q))*B.j + A.z*B.k + assert simplify(express(A.x*A.i + A.y*A.j + A.z*A.k, B, \ + variables=True)) == \ + B.x*B.i + B.y*B.j + B.z*B.k + N = B.orient_new_axis('N', -q, B.k) + assert N.scalar_map(A) == \ + {N.x: A.x, N.z: A.z, N.y: A.y} + C = A.orient_new_axis('C', q, A.i + A.j + A.k) + mapping = A.scalar_map(C) + assert mapping[A.x].equals(C.x*(2*cos(q) + 1)/3 + + C.y*(-2*sin(q + pi/6) + 1)/3 + + C.z*(-2*cos(q + pi/3) + 1)/3) + assert 
mapping[A.y].equals(C.x*(-2*cos(q + pi/3) + 1)/3 + + C.y*(2*cos(q) + 1)/3 + + C.z*(-2*sin(q + pi/6) + 1)/3) + assert mapping[A.z].equals(C.x*(-2*sin(q + pi/6) + 1)/3 + + C.y*(-2*cos(q + pi/3) + 1)/3 + + C.z*(2*cos(q) + 1)/3) + D = A.locate_new('D', a*A.i + b*A.j + c*A.k) + assert D.scalar_map(A) == {D.z: A.z - c, D.x: A.x - a, D.y: A.y - b} + E = A.orient_new_axis('E', a, A.k, a*A.i + b*A.j + c*A.k) + assert A.scalar_map(E) == {A.z: E.z + c, + A.x: E.x*cos(a) - E.y*sin(a) + a, + A.y: E.x*sin(a) + E.y*cos(a) + b} + assert E.scalar_map(A) == {E.x: (A.x - a)*cos(a) + (A.y - b)*sin(a), + E.y: (-A.x + a)*sin(a) + (A.y - b)*cos(a), + E.z: A.z - c} + F = A.locate_new('F', Vector.zero) + assert A.scalar_map(F) == {A.z: F.z, A.x: F.x, A.y: F.y} + + +def test_rotation_matrix(): + N = CoordSys3D('N') + A = N.orient_new_axis('A', q1, N.k) + B = A.orient_new_axis('B', q2, A.i) + C = B.orient_new_axis('C', q3, B.j) + D = N.orient_new_axis('D', q4, N.j) + E = N.orient_new_space('E', q1, q2, q3, '123') + F = N.orient_new_quaternion('F', q1, q2, q3, q4) + G = N.orient_new_body('G', q1, q2, q3, '123') + assert N.rotation_matrix(C) == Matrix([ + [- sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), - sin(q1) * + cos(q2), sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], \ + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), \ + cos(q1) * cos(q2), sin(q1) * sin(q3) - sin(q2) * cos(q1) * \ + cos(q3)], [- sin(q3) * cos(q2), sin(q2), cos(q2) * cos(q3)]]) + test_mat = D.rotation_matrix(C) - Matrix( + [[cos(q1) * cos(q3) * cos(q4) - sin(q3) * (- sin(q4) * cos(q2) + + sin(q1) * sin(q2) * cos(q4)), - sin(q2) * sin(q4) - sin(q1) * + cos(q2) * cos(q4), sin(q3) * cos(q1) * cos(q4) + cos(q3) * \ + (- sin(q4) * cos(q2) + sin(q1) * sin(q2) * cos(q4))], \ + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * \ + cos(q2), sin(q1) * sin(q3) - sin(q2) * cos(q1) * cos(q3)], \ + [sin(q4) * cos(q1) * cos(q3) - sin(q3) * (cos(q2) * cos(q4) + \ + sin(q1) * sin(q2) * \ + sin(q4)), sin(q2) * + cos(q4) - sin(q1) * sin(q4) * cos(q2), sin(q3) * \ + sin(q4) * cos(q1) + cos(q3) * (cos(q2) * cos(q4) + \ + sin(q1) * sin(q2) * sin(q4))]]) + assert test_mat.expand() == zeros(3, 3) + assert E.rotation_matrix(N) == Matrix( + [[cos(q2)*cos(q3), sin(q3)*cos(q2), -sin(q2)], + [sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), \ + sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2)], \ + [sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), - \ + sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2)]]) + assert F.rotation_matrix(N) == Matrix([[ + q1**2 + q2**2 - q3**2 - q4**2, + 2*q1*q4 + 2*q2*q3, -2*q1*q3 + 2*q2*q4],[ -2*q1*q4 + 2*q2*q3, + q1**2 - q2**2 + q3**2 - q4**2, 2*q1*q2 + 2*q3*q4], + [2*q1*q3 + 2*q2*q4, + -2*q1*q2 + 2*q3*q4, + q1**2 - q2**2 - q3**2 + q4**2]]) + assert G.rotation_matrix(N) == Matrix([[ + cos(q2)*cos(q3), sin(q1)*sin(q2)*cos(q3) + sin(q3)*cos(q1), + sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3)], [ + -sin(q3)*cos(q2), -sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), + sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)],[ + sin(q2), -sin(q1)*cos(q2), cos(q1)*cos(q2)]]) + + +def test_vector_with_orientation(): + """ + Tests the effects of orientation of coordinate systems on + basic vector operations. 
+ """ + N = CoordSys3D('N') + A = N.orient_new_axis('A', q1, N.k) + B = A.orient_new_axis('B', q2, A.i) + C = B.orient_new_axis('C', q3, B.j) + + # Test to_matrix + v1 = a*N.i + b*N.j + c*N.k + assert v1.to_matrix(A) == Matrix([[ a*cos(q1) + b*sin(q1)], + [-a*sin(q1) + b*cos(q1)], + [ c]]) + + # Test dot + assert N.i.dot(A.i) == cos(q1) + assert N.i.dot(A.j) == -sin(q1) + assert N.i.dot(A.k) == 0 + assert N.j.dot(A.i) == sin(q1) + assert N.j.dot(A.j) == cos(q1) + assert N.j.dot(A.k) == 0 + assert N.k.dot(A.i) == 0 + assert N.k.dot(A.j) == 0 + assert N.k.dot(A.k) == 1 + + assert N.i.dot(A.i + A.j) == -sin(q1) + cos(q1) == \ + (A.i + A.j).dot(N.i) + + assert A.i.dot(C.i) == cos(q3) + assert A.i.dot(C.j) == 0 + assert A.i.dot(C.k) == sin(q3) + assert A.j.dot(C.i) == sin(q2)*sin(q3) + assert A.j.dot(C.j) == cos(q2) + assert A.j.dot(C.k) == -sin(q2)*cos(q3) + assert A.k.dot(C.i) == -cos(q2)*sin(q3) + assert A.k.dot(C.j) == sin(q2) + assert A.k.dot(C.k) == cos(q2)*cos(q3) + + # Test cross + assert N.i.cross(A.i) == sin(q1)*A.k + assert N.i.cross(A.j) == cos(q1)*A.k + assert N.i.cross(A.k) == -sin(q1)*A.i - cos(q1)*A.j + assert N.j.cross(A.i) == -cos(q1)*A.k + assert N.j.cross(A.j) == sin(q1)*A.k + assert N.j.cross(A.k) == cos(q1)*A.i - sin(q1)*A.j + assert N.k.cross(A.i) == A.j + assert N.k.cross(A.j) == -A.i + assert N.k.cross(A.k) == Vector.zero + + assert N.i.cross(A.i) == sin(q1)*A.k + assert N.i.cross(A.j) == cos(q1)*A.k + assert N.i.cross(A.i + A.j) == sin(q1)*A.k + cos(q1)*A.k + assert (A.i + A.j).cross(N.i) == (-sin(q1) - cos(q1))*N.k + + assert A.i.cross(C.i) == sin(q3)*C.j + assert A.i.cross(C.j) == -sin(q3)*C.i + cos(q3)*C.k + assert A.i.cross(C.k) == -cos(q3)*C.j + assert C.i.cross(A.i) == (-sin(q3)*cos(q2))*A.j + \ + (-sin(q2)*sin(q3))*A.k + assert C.j.cross(A.i) == (sin(q2))*A.j + (-cos(q2))*A.k + assert express(C.k.cross(A.i), C).trigsimp() == cos(q3)*C.j + + +def test_orient_new_methods(): + N = CoordSys3D('N') + orienter1 = AxisOrienter(q4, N.j) + orienter2 = SpaceOrienter(q1, q2, q3, '123') + orienter3 = QuaternionOrienter(q1, q2, q3, q4) + orienter4 = BodyOrienter(q1, q2, q3, '123') + D = N.orient_new('D', (orienter1, )) + E = N.orient_new('E', (orienter2, )) + F = N.orient_new('F', (orienter3, )) + G = N.orient_new('G', (orienter4, )) + assert D == N.orient_new_axis('D', q4, N.j) + assert E == N.orient_new_space('E', q1, q2, q3, '123') + assert F == N.orient_new_quaternion('F', q1, q2, q3, q4) + assert G == N.orient_new_body('G', q1, q2, q3, '123') + + +def test_locatenew_point(): + """ + Tests Point class, and locate_new method in CoordSys3D. 
+ """ + A = CoordSys3D('A') + assert isinstance(A.origin, Point) + v = a*A.i + b*A.j + c*A.k + C = A.locate_new('C', v) + assert C.origin.position_wrt(A) == \ + C.position_wrt(A) == \ + C.origin.position_wrt(A.origin) == v + assert A.origin.position_wrt(C) == \ + A.position_wrt(C) == \ + A.origin.position_wrt(C.origin) == -v + assert A.origin.express_coordinates(C) == (-a, -b, -c) + p = A.origin.locate_new('p', -v) + assert p.express_coordinates(A) == (-a, -b, -c) + assert p.position_wrt(C.origin) == p.position_wrt(C) == \ + -2 * v + p1 = p.locate_new('p1', 2*v) + assert p1.position_wrt(C.origin) == Vector.zero + assert p1.express_coordinates(C) == (0, 0, 0) + p2 = p.locate_new('p2', A.i) + assert p1.position_wrt(p2) == 2*v - A.i + assert p2.express_coordinates(C) == (-2*a + 1, -2*b, -2*c) + + +def test_create_new(): + a = CoordSys3D('a') + c = a.create_new('c', transformation='spherical') + assert c._parent == a + assert c.transformation_to_parent() == \ + (c.r*sin(c.theta)*cos(c.phi), c.r*sin(c.theta)*sin(c.phi), c.r*cos(c.theta)) + assert c.transformation_from_parent() == \ + (sqrt(a.x**2 + a.y**2 + a.z**2), acos(a.z/sqrt(a.x**2 + a.y**2 + a.z**2)), atan2(a.y, a.x)) + + +def test_evalf(): + A = CoordSys3D('A') + v = 3*A.i + 4*A.j + a*A.k + assert v.n() == v.evalf() + assert v.evalf(subs={a:1}) == v.subs(a, 1).evalf() + + +def test_lame_coefficients(): + a = CoordSys3D('a', 'spherical') + assert a.lame_coefficients() == (1, a.r, sin(a.theta)*a.r) + a = CoordSys3D('a') + assert a.lame_coefficients() == (1, 1, 1) + a = CoordSys3D('a', 'cartesian') + assert a.lame_coefficients() == (1, 1, 1) + a = CoordSys3D('a', 'cylindrical') + assert a.lame_coefficients() == (1, a.r, 1) + + +def test_transformation_equations(): + + x, y, z = symbols('x y z') + # Str + a = CoordSys3D('a', transformation='spherical', + variable_names=["r", "theta", "phi"]) + r, theta, phi = a.base_scalars() + + assert r == a.r + assert theta == a.theta + assert phi == a.phi + + raises(AttributeError, lambda: a.x) + raises(AttributeError, lambda: a.y) + raises(AttributeError, lambda: a.z) + + assert a.transformation_to_parent() == ( + r*sin(theta)*cos(phi), + r*sin(theta)*sin(phi), + r*cos(theta) + ) + assert a.lame_coefficients() == (1, r, r*sin(theta)) + assert a.transformation_from_parent_function()(x, y, z) == ( + sqrt(x ** 2 + y ** 2 + z ** 2), + acos((z) / sqrt(x**2 + y**2 + z**2)), + atan2(y, x) + ) + a = CoordSys3D('a', transformation='cylindrical', + variable_names=["r", "theta", "z"]) + r, theta, z = a.base_scalars() + assert a.transformation_to_parent() == ( + r*cos(theta), + r*sin(theta), + z + ) + assert a.lame_coefficients() == (1, a.r, 1) + assert a.transformation_from_parent_function()(x, y, z) == (sqrt(x**2 + y**2), + atan2(y, x), z) + + a = CoordSys3D('a', 'cartesian') + assert a.transformation_to_parent() == (a.x, a.y, a.z) + assert a.lame_coefficients() == (1, 1, 1) + assert a.transformation_from_parent_function()(x, y, z) == (x, y, z) + + # Variables and expressions + + # Cartesian with equation tuple: + x, y, z = symbols('x y z') + a = CoordSys3D('a', ((x, y, z), (x, y, z))) + a._calculate_inv_trans_equations() + assert a.transformation_to_parent() == (a.x1, a.x2, a.x3) + assert a.lame_coefficients() == (1, 1, 1) + assert a.transformation_from_parent_function()(x, y, z) == (x, y, z) + r, theta, z = symbols("r theta z") + + # Cylindrical with equation tuple: + a = CoordSys3D('a', [(r, theta, z), (r*cos(theta), r*sin(theta), z)], + variable_names=["r", "theta", "z"]) + r, theta, z = a.base_scalars() + 
assert a.transformation_to_parent() == ( + r*cos(theta), r*sin(theta), z + ) + assert a.lame_coefficients() == ( + sqrt(sin(theta)**2 + cos(theta)**2), + sqrt(r**2*sin(theta)**2 + r**2*cos(theta)**2), + 1 + ) # ==> this should simplify to (1, r, 1), tests are too slow with `simplify`. + + # Definitions with `lambda`: + + # Cartesian with `lambda` + a = CoordSys3D('a', lambda x, y, z: (x, y, z)) + assert a.transformation_to_parent() == (a.x1, a.x2, a.x3) + assert a.lame_coefficients() == (1, 1, 1) + a._calculate_inv_trans_equations() + assert a.transformation_from_parent_function()(x, y, z) == (x, y, z) + + # Spherical with `lambda` + a = CoordSys3D('a', lambda r, theta, phi: (r*sin(theta)*cos(phi), r*sin(theta)*sin(phi), r*cos(theta)), + variable_names=["r", "theta", "phi"]) + r, theta, phi = a.base_scalars() + assert a.transformation_to_parent() == ( + r*sin(theta)*cos(phi), r*sin(phi)*sin(theta), r*cos(theta) + ) + assert a.lame_coefficients() == ( + sqrt(sin(phi)**2*sin(theta)**2 + sin(theta)**2*cos(phi)**2 + cos(theta)**2), + sqrt(r**2*sin(phi)**2*cos(theta)**2 + r**2*sin(theta)**2 + r**2*cos(phi)**2*cos(theta)**2), + sqrt(r**2*sin(phi)**2*sin(theta)**2 + r**2*sin(theta)**2*cos(phi)**2) + ) # ==> this should simplify to (1, r, sin(theta)*r), `simplify` is too slow. + + # Cylindrical with `lambda` + a = CoordSys3D('a', lambda r, theta, z: + (r*cos(theta), r*sin(theta), z), + variable_names=["r", "theta", "z"] + ) + r, theta, z = a.base_scalars() + assert a.transformation_to_parent() == (r*cos(theta), r*sin(theta), z) + assert a.lame_coefficients() == ( + sqrt(sin(theta)**2 + cos(theta)**2), + sqrt(r**2*sin(theta)**2 + r**2*cos(theta)**2), + 1 + ) # ==> this should simplify to (1, a.x, 1) + + raises(TypeError, lambda: CoordSys3D('a', transformation={ + x: x*sin(y)*cos(z), y:x*sin(y)*sin(z), z: x*cos(y)})) + + +def test_check_orthogonality(): + x, y, z = symbols('x y z') + u,v = symbols('u, v') + a = CoordSys3D('a', transformation=((x, y, z), (x*sin(y)*cos(z), x*sin(y)*sin(z), x*cos(y)))) + assert a._check_orthogonality(a._transformation) is True + a = CoordSys3D('a', transformation=((x, y, z), (x * cos(y), x * sin(y), z))) + assert a._check_orthogonality(a._transformation) is True + a = CoordSys3D('a', transformation=((u, v, z), (cosh(u) * cos(v), sinh(u) * sin(v), z))) + assert a._check_orthogonality(a._transformation) is True + + raises(ValueError, lambda: CoordSys3D('a', transformation=((x, y, z), (x, x, z)))) + raises(ValueError, lambda: CoordSys3D('a', transformation=( + (x, y, z), (x*sin(y/2)*cos(z), x*sin(y)*sin(z), x*cos(y))))) + + +def test_rotation_trans_equations(): + a = CoordSys3D('a') + from sympy.core.symbol import symbols + q0 = symbols('q0') + assert a._rotation_trans_equations(a._parent_rotation_matrix, a.base_scalars()) == (a.x, a.y, a.z) + assert a._rotation_trans_equations(a._inverse_rotation_matrix(), a.base_scalars()) == (a.x, a.y, a.z) + b = a.orient_new_axis('b', 0, -a.k) + assert b._rotation_trans_equations(b._parent_rotation_matrix, b.base_scalars()) == (b.x, b.y, b.z) + assert b._rotation_trans_equations(b._inverse_rotation_matrix(), b.base_scalars()) == (b.x, b.y, b.z) + c = a.orient_new_axis('c', q0, -a.k) + assert c._rotation_trans_equations(c._parent_rotation_matrix, c.base_scalars()) == \ + (-sin(q0) * c.y + cos(q0) * c.x, sin(q0) * c.x + cos(q0) * c.y, c.z) + assert c._rotation_trans_equations(c._inverse_rotation_matrix(), c.base_scalars()) == \ + (sin(q0) * c.y + cos(q0) * c.x, -sin(q0) * c.x + cos(q0) * c.y, c.z) diff --git 
a/phivenv/Lib/site-packages/sympy/vector/tests/test_dyadic.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..2e396fcf2a81af897b59c0065f6b15f5c6933222 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_dyadic.py @@ -0,0 +1,134 @@ +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.vector import (CoordSys3D, Vector, Dyadic, + DyadicAdd, DyadicMul, DyadicZero, + BaseDyadic, express) + + +A = CoordSys3D('A') + + +def test_dyadic(): + a, b = symbols('a, b') + assert Dyadic.zero != 0 + assert isinstance(Dyadic.zero, DyadicZero) + assert BaseDyadic(A.i, A.j) != BaseDyadic(A.j, A.i) + assert (BaseDyadic(Vector.zero, A.i) == + BaseDyadic(A.i, Vector.zero) == Dyadic.zero) + + d1 = A.i | A.i + d2 = A.j | A.j + d3 = A.i | A.j + + assert isinstance(d1, BaseDyadic) + d_mul = a*d1 + assert isinstance(d_mul, DyadicMul) + assert d_mul.base_dyadic == d1 + assert d_mul.measure_number == a + assert isinstance(a*d1 + b*d3, DyadicAdd) + assert d1 == A.i.outer(A.i) + assert d3 == A.i.outer(A.j) + v1 = a*A.i - A.k + v2 = A.i + b*A.j + assert v1 | v2 == v1.outer(v2) == a * (A.i|A.i) + (a*b) * (A.i|A.j) +\ + - (A.k|A.i) - b * (A.k|A.j) + assert d1 * 0 == Dyadic.zero + assert d1 != Dyadic.zero + assert d1 * 2 == 2 * (A.i | A.i) + assert d1 / 2. == 0.5 * d1 + + assert d1.dot(0 * d1) == Vector.zero + assert d1 & d2 == Dyadic.zero + assert d1.dot(A.i) == A.i == d1 & A.i + + assert d1.cross(Vector.zero) == Dyadic.zero + assert d1.cross(A.i) == Dyadic.zero + assert d1 ^ A.j == d1.cross(A.j) + assert d1.cross(A.k) == - A.i | A.j + assert d2.cross(A.i) == - A.j | A.k == d2 ^ A.i + + assert A.i ^ d1 == Dyadic.zero + assert A.j.cross(d1) == - A.k | A.i == A.j ^ d1 + assert Vector.zero.cross(d1) == Dyadic.zero + assert A.k ^ d1 == A.j | A.i + assert A.i.dot(d1) == A.i & d1 == A.i + assert A.j.dot(d1) == Vector.zero + assert Vector.zero.dot(d1) == Vector.zero + assert A.j & d2 == A.j + + assert d1.dot(d3) == d1 & d3 == A.i | A.j == d3 + assert d3 & d1 == Dyadic.zero + + q = symbols('q') + B = A.orient_new_axis('B', q, A.k) + assert express(d1, B) == express(d1, B, B) + + expr1 = ((cos(q)**2) * (B.i | B.i) + (-sin(q) * cos(q)) * + (B.i | B.j) + (-sin(q) * cos(q)) * (B.j | B.i) + (sin(q)**2) * + (B.j | B.j)) + assert (express(d1, B) - expr1).simplify() == Dyadic.zero + + expr2 = (cos(q)) * (B.i | A.i) + (-sin(q)) * (B.j | A.i) + assert (express(d1, B, A) - expr2).simplify() == Dyadic.zero + + expr3 = (cos(q)) * (A.i | B.i) + (-sin(q)) * (A.i | B.j) + assert (express(d1, A, B) - expr3).simplify() == Dyadic.zero + + assert d1.to_matrix(A) == Matrix([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert d1.to_matrix(A, B) == Matrix([[cos(q), -sin(q), 0], + [0, 0, 0], + [0, 0, 0]]) + assert d3.to_matrix(A) == Matrix([[0, 1, 0], [0, 0, 0], [0, 0, 0]]) + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + v1 = a * A.i + b * A.j + c * A.k + v2 = d * A.i + e * A.j + f * A.k + d4 = v1.outer(v2) + assert d4.to_matrix(A) == Matrix([[a * d, a * e, a * f], + [b * d, b * e, b * f], + [c * d, c * e, c * f]]) + d5 = v1.outer(v1) + C = A.orient_new_axis('C', q, A.i) + for expected, actual in zip(C.rotation_matrix(A) * d5.to_matrix(A) * \ + C.rotation_matrix(A).T, d5.to_matrix(C)): + assert (expected - actual).simplify() == 0 + + +def 
test_dyadic_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = CoordSys3D('N') + + dy = N.i | N.i + test1 = (1 / x + 1 / y) * dy + assert (N.i & test1 & N.i) != (x + y) / (x * y) + test1 = test1.simplify() + assert test1.simplify() == simplify(test1) + assert (N.i & test1 & N.i) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * dy + test2 = test2.simplify() + assert (N.i & test2 & N.i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * dy + test3 = test3.simplify() + assert (N.i & test3 & N.i) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * dy + test4 = test4.simplify() + assert (N.i & test4 & N.i) == -2 * y + + +def test_dyadic_srepr(): + from sympy.printing.repr import srepr + N = CoordSys3D('N') + + dy = N.i | N.j + res = "BaseDyadic(CoordSys3D(Str('N'), Tuple(ImmutableDenseMatrix([["\ + "Integer(1), Integer(0), Integer(0)], [Integer(0), Integer(1), "\ + "Integer(0)], [Integer(0), Integer(0), Integer(1)]]), "\ + "VectorZero())).i, CoordSys3D(Str('N'), Tuple(ImmutableDenseMatrix("\ + "[[Integer(1), Integer(0), Integer(0)], [Integer(0), Integer(1), "\ + "Integer(0)], [Integer(0), Integer(0), Integer(1)]]), VectorZero())).j)" + assert srepr(dy) == res diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_field_functions.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_field_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..035c2ce0234b81069c5ad8dcb1c74f4de0164a8f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_field_functions.py @@ -0,0 +1,321 @@ +from sympy.core.function import Derivative +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.simplify import simplify +from sympy.core.symbol import symbols +from sympy.core import S +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.vector.vector import Dot +from sympy.vector.operators import curl, divergence, gradient, Gradient, Divergence, Cross +from sympy.vector.deloperator import Del +from sympy.vector.functions import (is_conservative, is_solenoidal, + scalar_potential, directional_derivative, + laplacian, scalar_potential_difference) +from sympy.testing.pytest import raises + +C = CoordSys3D('C') +i, j, k = C.base_vectors() +x, y, z = C.base_scalars() +delop = Del() +a, b, c, q = symbols('a b c q') + + +def test_del_operator(): + # Tests for curl + + assert delop ^ Vector.zero == Vector.zero + assert ((delop ^ Vector.zero).doit() == Vector.zero == + curl(Vector.zero)) + assert delop.cross(Vector.zero) == delop ^ Vector.zero + assert (delop ^ i).doit() == Vector.zero + assert delop.cross(2*y**2*j, doit=True) == Vector.zero + assert delop.cross(2*y**2*j) == delop ^ 2*y**2*j + v = x*y*z * (i + j + k) + assert ((delop ^ v).doit() == + (-x*y + x*z)*i + (x*y - y*z)*j + (-x*z + y*z)*k == + curl(v)) + assert delop ^ v == delop.cross(v) + assert (delop.cross(2*x**2*j) == + (Derivative(0, C.y) - Derivative(2*C.x**2, C.z))*C.i + + (-Derivative(0, C.x) + Derivative(0, C.z))*C.j + + (-Derivative(0, C.y) + Derivative(2*C.x**2, C.x))*C.k) + assert (delop.cross(2*x**2*j, doit=True) == 4*x*k == + curl(2*x**2*j)) + + #Tests for divergence + assert delop & Vector.zero is S.Zero == divergence(Vector.zero) + assert (delop & Vector.zero).doit() is S.Zero + assert delop.dot(Vector.zero) == delop & Vector.zero + assert (delop & i).doit() is S.Zero + assert (delop & x**2*i).doit() == 
2*x == divergence(x**2*i) + assert (delop.dot(v, doit=True) == x*y + y*z + z*x == + divergence(v)) + assert delop & v == delop.dot(v) + assert delop.dot(1/(x*y*z) * (i + j + k), doit=True) == \ + - 1 / (x*y*z**2) - 1 / (x*y**2*z) - 1 / (x**2*y*z) + v = x*i + y*j + z*k + assert (delop & v == Derivative(C.x, C.x) + + Derivative(C.y, C.y) + Derivative(C.z, C.z)) + assert delop.dot(v, doit=True) == 3 == divergence(v) + assert delop & v == delop.dot(v) + assert simplify((delop & v).doit()) == 3 + + #Tests for gradient + assert (delop.gradient(0, doit=True) == Vector.zero == + gradient(0)) + assert delop.gradient(0) == delop(0) + assert (delop(S.Zero)).doit() == Vector.zero + assert (delop(x) == (Derivative(C.x, C.x))*C.i + + (Derivative(C.x, C.y))*C.j + (Derivative(C.x, C.z))*C.k) + assert (delop(x)).doit() == i == gradient(x) + assert (delop(x*y*z) == + (Derivative(C.x*C.y*C.z, C.x))*C.i + + (Derivative(C.x*C.y*C.z, C.y))*C.j + + (Derivative(C.x*C.y*C.z, C.z))*C.k) + assert (delop.gradient(x*y*z, doit=True) == + y*z*i + z*x*j + x*y*k == + gradient(x*y*z)) + assert delop(x*y*z) == delop.gradient(x*y*z) + assert (delop(2*x**2)).doit() == 4*x*i + assert ((delop(a*sin(y) / x)).doit() == + -a*sin(y)/x**2 * i + a*cos(y)/x * j) + + #Tests for directional derivative + assert (Vector.zero & delop)(a) is S.Zero + assert ((Vector.zero & delop)(a)).doit() is S.Zero + assert ((v & delop)(Vector.zero)).doit() == Vector.zero + assert ((v & delop)(S.Zero)).doit() is S.Zero + assert ((i & delop)(x)).doit() == 1 + assert ((j & delop)(y)).doit() == 1 + assert ((k & delop)(z)).doit() == 1 + assert ((i & delop)(x*y*z)).doit() == y*z + assert ((v & delop)(x)).doit() == x + assert ((v & delop)(x*y*z)).doit() == 3*x*y*z + assert (v & delop)(x + y + z) == C.x + C.y + C.z + assert ((v & delop)(x + y + z)).doit() == x + y + z + assert ((v & delop)(v)).doit() == v + assert ((i & delop)(v)).doit() == i + assert ((j & delop)(v)).doit() == j + assert ((k & delop)(v)).doit() == k + assert ((v & delop)(Vector.zero)).doit() == Vector.zero + + # Tests for laplacian on scalar fields + assert laplacian(x*y*z) is S.Zero + assert laplacian(x**2) == S(2) + assert laplacian(x**2*y**2*z**2) == \ + 2*y**2*z**2 + 2*x**2*z**2 + 2*x**2*y**2 + A = CoordSys3D('A', transformation="spherical", variable_names=["r", "theta", "phi"]) + B = CoordSys3D('B', transformation='cylindrical', variable_names=["r", "theta", "z"]) + assert laplacian(A.r + A.theta + A.phi) == 2/A.r + cos(A.theta)/(A.r**2*sin(A.theta)) + assert laplacian(B.r + B.theta + B.z) == 1/B.r + + # Tests for laplacian on vector fields + assert laplacian(x*y*z*(i + j + k)) == Vector.zero + assert laplacian(x*y**2*z*(i + j + k)) == \ + 2*x*z*i + 2*x*z*j + 2*x*z*k + + +def test_product_rules(): + """ + Tests the six product rules defined with respect to the Del + operator + + References + ========== + + .. 
[1] https://en.wikipedia.org/wiki/Del + + """ + + #Define the scalar and vector functions + f = 2*x*y*z + g = x*y + y*z + z*x + u = x**2*i + 4*j - y**2*z*k + v = 4*i + x*y*z*k + + # First product rule + lhs = delop(f * g, doit=True) + rhs = (f * delop(g) + g * delop(f)).doit() + assert simplify(lhs) == simplify(rhs) + + # Second product rule + lhs = delop(u & v).doit() + rhs = ((u ^ (delop ^ v)) + (v ^ (delop ^ u)) + \ + ((u & delop)(v)) + ((v & delop)(u))).doit() + assert simplify(lhs) == simplify(rhs) + + # Third product rule + lhs = (delop & (f*v)).doit() + rhs = ((f * (delop & v)) + (v & (delop(f)))).doit() + assert simplify(lhs) == simplify(rhs) + + # Fourth product rule + lhs = (delop & (u ^ v)).doit() + rhs = ((v & (delop ^ u)) - (u & (delop ^ v))).doit() + assert simplify(lhs) == simplify(rhs) + + # Fifth product rule + lhs = (delop ^ (f * v)).doit() + rhs = (((delop(f)) ^ v) + (f * (delop ^ v))).doit() + assert simplify(lhs) == simplify(rhs) + + # Sixth product rule + lhs = (delop ^ (u ^ v)).doit() + rhs = (u * (delop & v) - v * (delop & u) + + (v & delop)(u) - (u & delop)(v)).doit() + assert simplify(lhs) == simplify(rhs) + + +P = C.orient_new_axis('P', q, C.k) # type: ignore +scalar_field = 2*x**2*y*z +grad_field = gradient(scalar_field) +vector_field = y**2*i + 3*x*j + 5*y*z*k +curl_field = curl(vector_field) + + +def test_conservative(): + assert is_conservative(Vector.zero) is True + assert is_conservative(i) is True + assert is_conservative(2 * i + 3 * j + 4 * k) is True + assert (is_conservative(y*z*i + x*z*j + x*y*k) is + True) + assert is_conservative(x * j) is False + assert is_conservative(grad_field) is True + assert is_conservative(curl_field) is False + assert (is_conservative(4*x*y*z*i + 2*x**2*z*j) is + False) + assert is_conservative(z*P.i + P.x*k) is True + + +def test_solenoidal(): + assert is_solenoidal(Vector.zero) is True + assert is_solenoidal(i) is True + assert is_solenoidal(2 * i + 3 * j + 4 * k) is True + assert (is_solenoidal(y*z*i + x*z*j + x*y*k) is + True) + assert is_solenoidal(y * j) is False + assert is_solenoidal(grad_field) is False + assert is_solenoidal(curl_field) is True + assert is_solenoidal((-2*y + 3)*k) is True + assert is_solenoidal(cos(q)*i + sin(q)*j + cos(q)*P.k) is True + assert is_solenoidal(z*P.i + P.x*k) is True + + +def test_directional_derivative(): + assert directional_derivative(C.x*C.y*C.z, 3*C.i + 4*C.j + C.k) == C.x*C.y + 4*C.x*C.z + 3*C.y*C.z + assert directional_derivative(5*C.x**2*C.z, 3*C.i + 4*C.j + C.k) == 5*C.x**2 + 30*C.x*C.z + assert directional_derivative(5*C.x**2*C.z, 4*C.j) is S.Zero + + D = CoordSys3D("D", "spherical", variable_names=["r", "theta", "phi"], + vector_names=["e_r", "e_theta", "e_phi"]) + r, theta, phi = D.base_scalars() + e_r, e_theta, e_phi = D.base_vectors() + assert directional_derivative(r**2*e_r, e_r) == 2*r*e_r + assert directional_derivative(5*r**2*phi, 3*e_r + 4*e_theta + e_phi) == 5*r**2 + 30*r*phi + + +def test_scalar_potential(): + assert scalar_potential(Vector.zero, C) == 0 + assert scalar_potential(i, C) == x + assert scalar_potential(j, C) == y + assert scalar_potential(k, C) == z + assert scalar_potential(y*z*i + x*z*j + x*y*k, C) == x*y*z + assert scalar_potential(grad_field, C) == scalar_field + assert scalar_potential(z*P.i + P.x*k, C) == x*z*cos(q) + y*z*sin(q) + assert scalar_potential(z*P.i + P.x*k, P) == P.x*P.z + raises(ValueError, lambda: scalar_potential(x*j, C)) + + +def test_scalar_potential_difference(): + point1 = C.origin.locate_new('P1', 1*i + 2*j + 3*k) + point2 = 
C.origin.locate_new('P2', 4*i + 5*j + 6*k) + genericpointC = C.origin.locate_new('RP', x*i + y*j + z*k) + genericpointP = P.origin.locate_new('PP', P.x*P.i + P.y*P.j + P.z*P.k) + assert scalar_potential_difference(S.Zero, C, point1, point2) == 0 + assert (scalar_potential_difference(scalar_field, C, C.origin, + genericpointC) == + scalar_field) + assert (scalar_potential_difference(grad_field, C, C.origin, + genericpointC) == + scalar_field) + assert scalar_potential_difference(grad_field, C, point1, point2) == 948 + assert (scalar_potential_difference(y*z*i + x*z*j + + x*y*k, C, point1, + genericpointC) == + x*y*z - 6) + potential_diff_P = (2*P.z*(P.x*sin(q) + P.y*cos(q))* + (P.x*cos(q) - P.y*sin(q))**2) + assert (scalar_potential_difference(grad_field, P, P.origin, + genericpointP).simplify() == + potential_diff_P.simplify()) + + +def test_differential_operators_curvilinear_system(): + A = CoordSys3D('A', transformation="spherical", variable_names=["r", "theta", "phi"]) + B = CoordSys3D('B', transformation='cylindrical', variable_names=["r", "theta", "z"]) + # Test for spherical coordinate system and gradient + assert gradient(3*A.r + 4*A.theta) == 3*A.i + 4/A.r*A.j + assert gradient(3*A.r*A.phi + 4*A.theta) == 3*A.phi*A.i + 4/A.r*A.j + (3/sin(A.theta))*A.k + assert gradient(0*A.r + 0*A.theta+0*A.phi) == Vector.zero + assert gradient(A.r*A.theta*A.phi) == A.theta*A.phi*A.i + A.phi*A.j + (A.theta/sin(A.theta))*A.k + # Test for spherical coordinate system and divergence + assert divergence(A.r * A.i + A.theta * A.j + A.phi * A.k) == \ + (sin(A.theta)*A.r + cos(A.theta)*A.r*A.theta)/(sin(A.theta)*A.r**2) + 3 + 1/(sin(A.theta)*A.r) + assert divergence(3*A.r*A.phi*A.i + A.theta*A.j + A.r*A.theta*A.phi*A.k) == \ + (sin(A.theta)*A.r + cos(A.theta)*A.r*A.theta)/(sin(A.theta)*A.r**2) + 9*A.phi + A.theta/sin(A.theta) + assert divergence(Vector.zero) == 0 + assert divergence(0*A.i + 0*A.j + 0*A.k) == 0 + # Test for spherical coordinate system and curl + assert curl(A.r*A.i + A.theta*A.j + A.phi*A.k) == \ + (cos(A.theta)*A.phi/(sin(A.theta)*A.r))*A.i + (-A.phi/A.r)*A.j + A.theta/A.r*A.k + assert curl(A.r*A.j + A.phi*A.k) == (cos(A.theta)*A.phi/(sin(A.theta)*A.r))*A.i + (-A.phi/A.r)*A.j + 2*A.k + + # Test for cylindrical coordinate system and gradient + assert gradient(0*B.r + 0*B.theta+0*B.z) == Vector.zero + assert gradient(B.r*B.theta*B.z) == B.theta*B.z*B.i + B.z*B.j + B.r*B.theta*B.k + assert gradient(3*B.r) == 3*B.i + assert gradient(2*B.theta) == 2/B.r * B.j + assert gradient(4*B.z) == 4*B.k + # Test for cylindrical coordinate system and divergence + assert divergence(B.r*B.i + B.theta*B.j + B.z*B.k) == 3 + 1/B.r + assert divergence(B.r*B.j + B.z*B.k) == 1 + # Test for cylindrical coordinate system and curl + assert curl(B.r*B.j + B.z*B.k) == 2*B.k + assert curl(3*B.i + 2/B.r*B.j + 4*B.k) == Vector.zero + +def test_mixed_coordinates(): + # gradient + a = CoordSys3D('a') + b = CoordSys3D('b') + c = CoordSys3D('c') + assert gradient(a.x*b.y) == b.y*a.i + a.x*b.j + assert gradient(3*cos(q)*a.x*b.x+a.y*(a.x+(cos(q)+b.x))) ==\ + (a.y + 3*b.x*cos(q))*a.i + (a.x + b.x + cos(q))*a.j + (3*a.x*cos(q) + a.y)*b.i + # Some tests need further work: + # assert gradient(a.x*(cos(a.x+b.x))) == (cos(a.x + b.x))*a.i + a.x*Gradient(cos(a.x + b.x)) + # assert gradient(cos(a.x + b.x)*cos(a.x + b.z)) == Gradient(cos(a.x + b.x)*cos(a.x + b.z)) + assert gradient(a.x**b.y) == Gradient(a.x**b.y) + # assert gradient(cos(a.x+b.y)*a.z) == None + assert gradient(cos(a.x*b.y)) == Gradient(cos(a.x*b.y)) + assert 
gradient(3*cos(q)*a.x*b.x*a.z*a.y+ b.y*b.z + cos(a.x+a.y)*b.z) == \ + (3*a.y*a.z*b.x*cos(q) - b.z*sin(a.x + a.y))*a.i + \ + (3*a.x*a.z*b.x*cos(q) - b.z*sin(a.x + a.y))*a.j + (3*a.x*a.y*b.x*cos(q))*a.k + \ + (3*a.x*a.y*a.z*cos(q))*b.i + b.z*b.j + (b.y + cos(a.x + a.y))*b.k + # divergence + assert divergence(a.i*a.x+a.j*a.y+a.z*a.k + b.i*b.x+b.j*b.y+b.z*b.k + c.i*c.x+c.j*c.y+c.z*c.k) == S(9) + # assert divergence(3*a.i*a.x*cos(a.x+b.z) + a.j*b.x*c.z) == None + assert divergence(3*a.i*a.x*a.z + b.j*b.x*c.z + 3*a.j*a.z*a.y) == \ + 6*a.z + b.x*Dot(b.j, c.k) + assert divergence(3*cos(q)*a.x*b.x*b.i*c.x) == \ + 3*a.x*b.x*cos(q)*Dot(b.i, c.i) + 3*a.x*c.x*cos(q) + 3*b.x*c.x*cos(q)*Dot(b.i, a.i) + assert divergence(a.x*b.x*c.x*Cross(a.x*a.i, a.y*b.j)) ==\ + a.x*b.x*c.x*Divergence(Cross(a.x*a.i, a.y*b.j)) + \ + b.x*c.x*Dot(Cross(a.x*a.i, a.y*b.j), a.i) + \ + a.x*c.x*Dot(Cross(a.x*a.i, a.y*b.j), b.i) + \ + a.x*b.x*Dot(Cross(a.x*a.i, a.y*b.j), c.i) + assert divergence(a.x*b.x*c.x*(a.x*a.i + b.x*b.i)) == \ + 4*a.x*b.x*c.x +\ + a.x**2*c.x*Dot(a.i, b.i) +\ + a.x**2*b.x*Dot(a.i, c.i) +\ + b.x**2*c.x*Dot(b.i, a.i) +\ + a.x*b.x**2*Dot(b.i, c.i) diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_functions.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdf9821b6c853755ce12d0cbdfa599bd4f312e4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_functions.py @@ -0,0 +1,184 @@ +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.functions import express, matrix_to_vector, orthogonalize +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.testing.pytest import raises + +N = CoordSys3D('N') +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +A = N.orient_new_axis('A', q1, N.k) # type: ignore +B = A.orient_new_axis('B', q2, A.i) +C = B.orient_new_axis('C', q3, B.j) + + +def test_express(): + assert express(Vector.zero, N) == Vector.zero + assert express(S.Zero, N) is S.Zero + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + assert express(A.i, N) == cos(q1)*N.i + sin(q1)*N.j + assert express(A.j, N) == -sin(q1)*N.i + cos(q1)*N.j + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == cos(q2)*B.j - sin(q2)*B.k + assert express(A.k, B) == sin(q2)*B.j + cos(q2)*B.k + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + # Check to make sure UnitVectors get converted properly + assert express(N.i, N) == N.i + assert express(N.j, N) == N.j + assert express(N.k, N) == N.k + assert express(N.i, A) == (cos(q1)*A.i - sin(q1)*A.j) + assert express(N.j, A) == (sin(q1)*A.i + cos(q1)*A.j) + assert express(N.k, A) == A.k + assert express(N.i, B) == (cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k) + assert 
express(N.j, B) == (sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k) + assert express(N.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(N.i, C) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.i - + sin(q1)*cos(q2)*C.j + + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.k) + assert express(N.j, C) == ( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.i + + cos(q1)*cos(q2)*C.j + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.k) + assert express(N.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(A.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(A.j, N) == (-sin(q1)*N.i + cos(q1)*N.j) + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == (cos(q2)*B.j - sin(q2)*B.k) + assert express(A.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(A.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(A.j, C) == (sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k) + assert express(A.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(B.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(B.j, N) == (-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(B.k, N) == (sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k) + assert express(B.i, A) == A.i + assert express(B.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(B.k, A) == (-sin(q2)*A.j + cos(q2)*A.k) + assert express(B.i, B) == B.i + assert express(B.j, B) == B.j + assert express(B.k, B) == B.k + assert express(B.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(B.j, C) == C.j + assert express(B.k, C) == (-sin(q3)*C.i + cos(q3)*C.k) + + assert express(C.i, N) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.i + + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.j - + sin(q3)*cos(q2)*N.k) + assert express(C.j, N) == ( + -sin(q1)*cos(q2)*N.i + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(C.k, N) == ( + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.i + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.j + + cos(q2)*cos(q3)*N.k) + assert express(C.i, A) == (cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k) + assert express(C.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(C.k, A) == (sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k) + assert express(C.i, B) == (cos(q3)*B.i - sin(q3)*B.k) + assert express(C.j, B) == B.j + assert express(C.k, B) == (sin(q3)*B.i + cos(q3)*B.k) + assert express(C.i, C) == C.i + assert express(C.j, C) == C.j + assert express(C.k, C) == C.k == (C.k) + + # Check to make sure Vectors get converted back to UnitVectors + assert N.i == express((cos(q1)*A.i - sin(q1)*A.j), N).simplify() + assert N.j == express((sin(q1)*A.i + cos(q1)*A.j), N).simplify() + assert N.i == express((cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k), N).simplify() + assert N.j == express((sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k), N).simplify() + assert N.k == express((sin(q2)*B.j + cos(q2)*B.k), N).simplify() + + + assert A.i == express((cos(q1)*N.i + sin(q1)*N.j), A).simplify() + assert A.j == express((-sin(q1)*N.i + cos(q1)*N.j), A).simplify() + + assert A.j == express((cos(q2)*B.j - sin(q2)*B.k), A).simplify() + assert A.k == express((sin(q2)*B.j + cos(q2)*B.k), A).simplify() + + assert A.i == express((cos(q3)*C.i + sin(q3)*C.k), A).simplify() + assert A.j == express((sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + 
sin(q2)*cos(q3)*C.k), A).simplify() + + assert A.k == express((-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k), A).simplify() + assert B.i == express((cos(q1)*N.i + sin(q1)*N.j), B).simplify() + assert B.j == express((-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k), B).simplify() + + assert B.k == express((sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k), B).simplify() + + assert B.j == express((cos(q2)*A.j + sin(q2)*A.k), B).simplify() + assert B.k == express((-sin(q2)*A.j + cos(q2)*A.k), B).simplify() + assert B.i == express((cos(q3)*C.i + sin(q3)*C.k), B).simplify() + assert B.k == express((-sin(q3)*C.i + cos(q3)*C.k), B).simplify() + assert C.i == express((cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k), C).simplify() + assert C.j == express((cos(q2)*A.j + sin(q2)*A.k), C).simplify() + assert C.k == express((sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k), C).simplify() + assert C.i == express((cos(q3)*B.i - sin(q3)*B.k), C).simplify() + assert C.k == express((sin(q3)*B.i + cos(q3)*B.k), C).simplify() + + +def test_matrix_to_vector(): + m = Matrix([[1], [2], [3]]) + assert matrix_to_vector(m, C) == C.i + 2*C.j + 3*C.k + m = Matrix([[0], [0], [0]]) + assert matrix_to_vector(m, N) == matrix_to_vector(m, C) == \ + Vector.zero + m = Matrix([[q1], [q2], [q3]]) + assert matrix_to_vector(m, N) == q1*N.i + q2*N.j + q3*N.k + + +def test_orthogonalize(): + C = CoordSys3D('C') + a, b = symbols('a b', integer=True) + i, j, k = C.base_vectors() + v1 = i + 2*j + v2 = 2*i + 3*j + v3 = 3*i + 5*j + v4 = 3*i + j + v5 = 2*i + 2*j + v6 = a*i + b*j + v7 = 4*a*i + 4*b*j + assert orthogonalize(v1, v2) == [C.i + 2*C.j, C.i*Rational(2, 5) + -C.j/5] + # from wikipedia + assert orthogonalize(v4, v5, orthonormal=True) == \ + [(3*sqrt(10))*C.i/10 + (sqrt(10))*C.j/10, (-sqrt(10))*C.i/10 + (3*sqrt(10))*C.j/10] + raises(ValueError, lambda: orthogonalize(v1, v2, v3)) + raises(ValueError, lambda: orthogonalize(v6, v7)) diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_implicitregion.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_implicitregion.py new file mode 100644 index 0000000000000000000000000000000000000000..3686d847a7f165cb5ba9aeb813e5922aaa17e1e0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_implicitregion.py @@ -0,0 +1,90 @@ +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.abc import x, y, z, s, t +from sympy.sets import FiniteSet, EmptySet +from sympy.geometry import Point +from sympy.vector import ImplicitRegion +from sympy.testing.pytest import raises + + +def test_ImplicitRegion(): + ellipse = ImplicitRegion((x, y), (x**2/4 + y**2/16 - 1)) + assert ellipse.equation == x**2/4 + y**2/16 - 1 + assert ellipse.variables == (x, y) + assert ellipse.degree == 2 + r = ImplicitRegion((x, y, z), Eq(x**4 + y**2 - x*y, 6)) + assert r.equation == x**4 + y**2 - x*y - 6 + assert r.variables == (x, y, z) + assert r.degree == 4 + + +def test_regular_point(): + r1 = ImplicitRegion((x,), x**2 - 16) + assert r1.regular_point() == (-4,) + c1 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert c1.regular_point() == (0, -2) + c2 = ImplicitRegion((x, y), (x - S(5)/2)**2 + y**2 - (S(1)/4)**2) + assert c2.regular_point() == (S(5)/2, -S(1)/4) + c3 = ImplicitRegion((x, y), (y - 5)**2 - 16*(x - 5)) + assert c3.regular_point() == (5, 5) + r2 = ImplicitRegion((x, y), x**2 - 4*x*y - 3*y**2 + 4*x + 8*y - 5) + assert r2.regular_point() == (S(4)/7, S(9)/7) + r3 = ImplicitRegion((x, y), x**2 - 2*x*y + 3*y**2 
- 2*x - 5*y + 3/2) + raises(ValueError, lambda: r3.regular_point()) + + +def test_singular_points_and_multiplicty(): + r1 = ImplicitRegion((x, y, z), Eq(x + y + z, 0)) + assert r1.singular_points() == EmptySet + r2 = ImplicitRegion((x, y, z), x*y*z + y**4 -x**2*z**2) + assert r2.singular_points() == FiniteSet((0, 0, z), (x, 0, 0)) + assert r2.multiplicity((0, 0, 0)) == 3 + assert r2.multiplicity((0, 0, 6)) == 2 + r3 = ImplicitRegion((x, y, z), z**2 - x**2 - y**2) + assert r3.singular_points() == FiniteSet((0, 0, 0)) + assert r3.multiplicity((0, 0, 0)) == 2 + r4 = ImplicitRegion((x, y), x**2 + y**2 - 2*x) + assert r4.singular_points() == EmptySet + assert r4.multiplicity(Point(1, 3)) == 0 + + +def test_rational_parametrization(): + p = ImplicitRegion((x,), x - 2) + assert p.rational_parametrization() == (x - 2,) + + line = ImplicitRegion((x, y), Eq(y, 3*x + 2)) + assert line.rational_parametrization() == (x, 3*x + 2) + + circle1 = ImplicitRegion((x, y), (x-2)**2 + (y+3)**2 - 4) + assert circle1.rational_parametrization(parameters=t) == (4*t/(t**2 + 1) + 2, 4*t**2/(t**2 + 1) - 5) + circle2 = ImplicitRegion((x, y), (x - S.Half)**2 + y**2 - (S(1)/2)**2) + + assert circle2.rational_parametrization(parameters=t) == (t/(t**2 + 1) + S(1)/2, t**2/(t**2 + 1) - S(1)/2) + circle3 = ImplicitRegion((x, y), Eq(x**2 + y**2, 2*x)) + assert circle3.rational_parametrization(parameters=(t,)) == (2*t/(t**2 + 1) + 1, 2*t**2/(t**2 + 1) - 1) + + parabola = ImplicitRegion((x, y), (y - 3)**2 - 4*(x + 6)) + assert parabola.rational_parametrization(t) == (-6 + 4/t**2, 3 + 4/t) + + rect_hyperbola = ImplicitRegion((x, y), x*y - 1) + assert rect_hyperbola.rational_parametrization(t) == (-1 + (t + 1)/t, t) + + cubic_curve = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert cubic_curve.rational_parametrization(parameters=(t)) == (t**2 - 1, t*(t**2 - 1)) + cuspidal = ImplicitRegion((x, y), (x**3 - y**2)) + assert cuspidal.rational_parametrization(t) == (t**2, t**3) + + I = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert I.rational_parametrization(t) == (t**2 - 1, t*(t**2 - 1)) + + sphere = ImplicitRegion((x, y, z), Eq(x**2 + y**2 + z**2, 2*x)) + assert sphere.rational_parametrization(parameters=(s, t)) == (2/(s**2 + t**2 + 1), 2*t/(s**2 + t**2 + 1), 2*s/(s**2 + t**2 + 1)) + + conic = ImplicitRegion((x, y), Eq(x**2 + 4*x*y + 3*y**2 + x - y + 10, 0)) + assert conic.rational_parametrization(t) == ( + S(17)/2 + 4/(3*t**2 + 4*t + 1), 4*t/(3*t**2 + 4*t + 1) - S(11)/2) + + r1 = ImplicitRegion((x, y), y**2 - x**3 + x) + raises(NotImplementedError, lambda: r1.rational_parametrization()) + r2 = ImplicitRegion((x, y), y**2 - x**3 - x**2 + 1) + raises(NotImplementedError, lambda: r2.rational_parametrization()) diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_integrals.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..84c900d038e214df1ea59a8cd8fb2929005c3674 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_integrals.py @@ -0,0 +1,106 @@ +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import raises +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.integrals import ParametricIntegral, vector_integrate +from sympy.vector.parametricregion import ParametricRegion +from sympy.vector.implicitregion import ImplicitRegion +from 
sympy.abc import x, y, z, u, v, r, t, theta, phi +from sympy.geometry import Point, Segment, Curve, Circle, Polygon, Plane + +C = CoordSys3D('C') + +def test_parametric_lineintegrals(): + halfcircle = ParametricRegion((4*cos(theta), 4*sin(theta)), (theta, -pi/2, pi/2)) + assert ParametricIntegral(C.x*C.y**4, halfcircle) == S(8192)/5 + + curve = ParametricRegion((t, t**2, t**3), (t, 0, 1)) + field1 = 8*C.x**2*C.y*C.z*C.i + 5*C.z*C.j - 4*C.x*C.y*C.k + assert ParametricIntegral(field1, curve) == 1 + line = ParametricRegion((4*t - 1, 2 - 2*t, t), (t, 0, 1)) + assert ParametricIntegral(C.x*C.z*C.i - C.y*C.z*C.k, line) == 3 + + assert ParametricIntegral(4*C.x**3, ParametricRegion((1, t), (t, 0, 2))) == 8 + + helix = ParametricRegion((cos(t), sin(t), 3*t), (t, 0, 4*pi)) + assert ParametricIntegral(C.x*C.y*C.z, helix) == -3*sqrt(10)*pi + + field2 = C.y*C.i + C.z*C.j + C.z*C.k + assert ParametricIntegral(field2, ParametricRegion((cos(t), sin(t), t**2), (t, 0, pi))) == -5*pi/2 + pi**4/2 + +def test_parametric_surfaceintegrals(): + + semisphere = ParametricRegion((2*sin(phi)*cos(theta), 2*sin(phi)*sin(theta), 2*cos(phi)),\ + (theta, 0, 2*pi), (phi, 0, pi/2)) + assert ParametricIntegral(C.z, semisphere) == 8*pi + + cylinder = ParametricRegion((sqrt(3)*cos(theta), sqrt(3)*sin(theta), z), (z, 0, 6), (theta, 0, 2*pi)) + assert ParametricIntegral(C.y, cylinder) == 0 + + cone = ParametricRegion((v*cos(u), v*sin(u), v), (u, 0, 2*pi), (v, 0, 1)) + assert ParametricIntegral(C.x*C.i + C.y*C.j + C.z**4*C.k, cone) == pi/3 + + triangle1 = ParametricRegion((x, y), (x, 0, 2), (y, 0, 10 - 5*x)) + triangle2 = ParametricRegion((x, y), (y, 0, 10 - 5*x), (x, 0, 2)) + assert ParametricIntegral(-15.6*C.y*C.k, triangle1) == ParametricIntegral(-15.6*C.y*C.k, triangle2) + assert ParametricIntegral(C.z, triangle1) == 10*C.z + +def test_parametric_volumeintegrals(): + + cube = ParametricRegion((x, y, z), (x, 0, 1), (y, 0, 1), (z, 0, 1)) + assert ParametricIntegral(1, cube) == 1 + + solidsphere1 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (theta, 0, 2*pi), (phi, 0, pi)) + solidsphere2 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (phi, 0, pi), (theta, 0, 2*pi)) + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere1) == -256*pi/15 + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere2) == 256*pi/15 + + region_under_plane1 = ParametricRegion((x, y, z), (x, 0, 3), (y, 0, -2*x/3 + 2),\ + (z, 0, 6 - 2*x - 3*y)) + region_under_plane2 = ParametricRegion((x, y, z), (x, 0, 3), (z, 0, 6 - 2*x - 3*y),\ + (y, 0, -2*x/3 + 2)) + + assert ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane1) == \ + ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane2) + assert ParametricIntegral(2*C.x, region_under_plane2) == -9 + +def test_vector_integrate(): + halfdisc = ParametricRegion((r*cos(theta), r* sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert vector_integrate(C.x**2, halfdisc) == 4*pi + assert vector_integrate(C.x, ParametricRegion((t, t**2), (t, 2, 3))) == -17*sqrt(17)/12 + 37*sqrt(37)/12 + + assert vector_integrate(C.y**3*C.z, (C.x, 0, 3), (C.y, -1, 4)) == 765*C.z/4 + + s1 = Segment(Point(0, 0), Point(0, 1)) + assert vector_integrate(-15*C.y, s1) == S(-15)/2 + s2 = Segment(Point(4, 3, 9), Point(1, 1, 7)) + assert vector_integrate(C.y*C.i, s2) == -6 + + curve = Curve((sin(t), cos(t)), (t, 0, 2)) + assert vector_integrate(5*C.z, curve) == 10*C.z + + c1 = Circle(Point(2, 3), 6) + assert vector_integrate(C.x*C.y, c1) == 72*pi 
+ c2 = Circle(Point(0, 0), Point(1, 1), Point(1, 0)) + assert vector_integrate(1, c2) == c2.circumference + + triangle = Polygon((0, 0), (1, 0), (1, 1)) + assert vector_integrate(C.x*C.i - 14*C.y*C.j, triangle) == 0 + p1, p2, p3, p4 = [(0, 0), (1, 0), (5, 1), (0, 1)] + poly = Polygon(p1, p2, p3, p4) + assert vector_integrate(-23*C.z, poly) == -161*C.z - 23*sqrt(17)*C.z + + point = Point(2, 3) + assert vector_integrate(C.i*C.y, point) == ParametricIntegral(C.y*C.i, ParametricRegion((2, 3))) + + c3 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert vector_integrate(45, c3) == 180*pi + c4 = ImplicitRegion((x, y), (x - 3)**2 + (y - 4)**2 - 9) + assert vector_integrate(1, c4) == 6*pi + + pl = Plane(Point(1, 1, 1), Point(2, 3, 4), Point(2, 2, 2)) + raises(ValueError, lambda: vector_integrate(C.x*C.z*C.i + C.k, pl)) diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_operators.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..5734edadd00547c67d6f864b50afd966ad8392a6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_operators.py @@ -0,0 +1,43 @@ +from sympy.vector import CoordSys3D, Gradient, Divergence, Curl, VectorZero, Laplacian +from sympy.printing.repr import srepr + +R = CoordSys3D('R') +s1 = R.x*R.y*R.z # type: ignore +s2 = R.x + 3*R.y**2 # type: ignore +s3 = R.x**2 + R.y**2 + R.z**2 # type: ignore +v1 = R.x*R.i + R.z*R.z*R.j # type: ignore +v2 = R.x*R.i + R.y*R.j + R.z*R.k # type: ignore +v3 = R.x**2*R.i + R.y**2*R.j + R.z**2*R.k # type: ignore + + +def test_Gradient(): + assert Gradient(s1) == Gradient(R.x*R.y*R.z) + assert Gradient(s2) == Gradient(R.x + 3*R.y**2) + assert Gradient(s1).doit() == R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + assert Gradient(s2).doit() == R.i + 6*R.y*R.j + + +def test_Divergence(): + assert Divergence(v1) == Divergence(R.x*R.i + R.z*R.z*R.j) + assert Divergence(v2) == Divergence(R.x*R.i + R.y*R.j + R.z*R.k) + assert Divergence(v1).doit() == 1 + assert Divergence(v2).doit() == 3 + # issue 22384 + Rc = CoordSys3D('R', transformation='cylindrical') + assert Divergence(Rc.i).doit() == 1/Rc.r + + +def test_Curl(): + assert Curl(v1) == Curl(R.x*R.i + R.z*R.z*R.j) + assert Curl(v2) == Curl(R.x*R.i + R.y*R.j + R.z*R.k) + assert Curl(v1).doit() == (-2*R.z)*R.i + assert Curl(v2).doit() == VectorZero() + + +def test_Laplacian(): + assert Laplacian(s3) == Laplacian(R.x**2 + R.y**2 + R.z**2) + assert Laplacian(v3) == Laplacian(R.x**2*R.i + R.y**2*R.j + R.z**2*R.k) + assert Laplacian(s3).doit() == 6 + assert Laplacian(v3).doit() == 2*R.i + 2*R.j + 2*R.k + assert srepr(Laplacian(s3)) == \ + 'Laplacian(Add(Pow(R.x, Integer(2)), Pow(R.y, Integer(2)), Pow(R.z, Integer(2))))' diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_parametricregion.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_parametricregion.py new file mode 100644 index 0000000000000000000000000000000000000000..e785b96744f9e2c39e91b997fcb70f8a921256bd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_parametricregion.py @@ -0,0 +1,97 @@ +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.parametricregion import ParametricRegion, parametric_region_list +from sympy.geometry import Point, Segment, Curve, Ellipse, Line, Parabola, Polygon +from sympy.testing.pytest import raises +from sympy.abc import a, b, r, t, x, y, z, theta, phi + + +C = 
CoordSys3D('C') + +def test_ParametricRegion(): + + point = ParametricRegion((3, 4)) + assert point.definition == (3, 4) + assert point.parameters == () + assert point.limits == {} + assert point.dimensions == 0 + + # line x = y + line_xy = ParametricRegion((y, y), (y, 1, 5)) + assert line_xy .definition == (y, y) + assert line_xy.parameters == (y,) + assert line_xy.dimensions == 1 + + # line y = z + line_yz = ParametricRegion((x,t,t), x, (t, 1, 2)) + assert line_yz.definition == (x,t,t) + assert line_yz.parameters == (x, t) + assert line_yz.limits == {t: (1, 2)} + assert line_yz.dimensions == 1 + + p1 = ParametricRegion((9*a, -16*b), (a, 0, 2), (b, -1, 5)) + assert p1.definition == (9*a, -16*b) + assert p1.parameters == (a, b) + assert p1.limits == {a: (0, 2), b: (-1, 5)} + assert p1.dimensions == 2 + + p2 = ParametricRegion((t, t**3), t) + assert p2.parameters == (t,) + assert p2.limits == {} + assert p2.dimensions == 0 + + circle = ParametricRegion((r*cos(theta), r*sin(theta)), r, (theta, 0, 2*pi)) + assert circle.definition == (r*cos(theta), r*sin(theta)) + assert circle.dimensions == 1 + + halfdisc = ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert halfdisc.definition == (r*cos(theta), r*sin(theta)) + assert halfdisc.parameters == (r, theta) + assert halfdisc.limits == {r: (-2, 2), theta: (0, pi)} + assert halfdisc.dimensions == 2 + + ellipse = ParametricRegion((a*cos(t), b*sin(t)), (t, 0, 8)) + assert ellipse.parameters == (t,) + assert ellipse.limits == {t: (0, 8)} + assert ellipse.dimensions == 1 + + cylinder = ParametricRegion((r*cos(theta), r*sin(theta), z), (r, 0, 1), (theta, 0, 2*pi), (z, 0, 4)) + assert cylinder.parameters == (r, theta, z) + assert cylinder.dimensions == 3 + + sphere = ParametricRegion((r*sin(phi)*cos(theta),r*sin(phi)*sin(theta), r*cos(phi)), + r, (theta, 0, 2*pi), (phi, 0, pi)) + assert sphere.definition == (r*sin(phi)*cos(theta),r*sin(phi)*sin(theta), r*cos(phi)) + assert sphere.parameters == (r, theta, phi) + assert sphere.dimensions == 2 + + raises(ValueError, lambda: ParametricRegion((a*t**2, 2*a*t), (a, -2))) + raises(ValueError, lambda: ParametricRegion((a, b), (a**2, sin(b)), (a, 2, 4, 6))) + + +def test_parametric_region_list(): + + point = Point(-5, 12) + assert parametric_region_list(point) == [ParametricRegion((-5, 12))] + + e = Ellipse(Point(2, 8), 2, 6) + assert parametric_region_list(e, t) == [ParametricRegion((2*cos(t) + 2, 6*sin(t) + 8), (t, 0, 2*pi))] + + c = Curve((t, t**3), (t, 5, 3)) + assert parametric_region_list(c) == [ParametricRegion((t, t**3), (t, 5, 3))] + + s = Segment(Point(2, 11, -6), Point(0, 2, 5)) + assert parametric_region_list(s, t) == [ParametricRegion((2 - 2*t, 11 - 9*t, 11*t - 6), (t, 0, 1))] + s1 = Segment(Point(0, 0), (1, 0)) + assert parametric_region_list(s1, t) == [ParametricRegion((t, 0), (t, 0, 1))] + s2 = Segment(Point(1, 2, 3), Point(1, 2, 5)) + assert parametric_region_list(s2, t) == [ParametricRegion((1, 2, 2*t + 3), (t, 0, 1))] + s3 = Segment(Point(12, 56), Point(12, 56)) + assert parametric_region_list(s3) == [ParametricRegion((12, 56))] + + poly = Polygon((1,3), (-3, 8), (2, 4)) + assert parametric_region_list(poly, t) == [ParametricRegion((1 - 4*t, 5*t + 3), (t, 0, 1)), ParametricRegion((5*t - 3, 8 - 4*t), (t, 0, 1)), ParametricRegion((2 - t, 4 - t), (t, 0, 1))] + + p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7,8))) + raises(ValueError, lambda: parametric_region_list(p1)) diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_printing.py 
b/phivenv/Lib/site-packages/sympy/vector/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..ae76905e967bdf93485f135c6a69f968e1208986 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_printing.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +from sympy.core.function import Function +from sympy.integrals.integrals import Integral +from sympy.printing.latex import latex +from sympy.printing.pretty import pretty as xpretty +from sympy.vector import CoordSys3D, Del, Vector, express +from sympy.abc import a, b, c +from sympy.testing.pytest import XFAIL + + +def pretty(expr): + """ASCII pretty-printing""" + return xpretty(expr, use_unicode=False, wrap_line=False) + + +def upretty(expr): + """Unicode pretty-printing""" + return xpretty(expr, use_unicode=True, wrap_line=False) + + +# Initialize the basic and tedious vector/dyadic expressions +# needed for testing. +# Some of the pretty forms shown denote how the expressions just +# above them should look with pretty printing. +N = CoordSys3D('N') +C = N.orient_new_axis('C', a, N.k) # type: ignore +v = [] +d = [] +v.append(Vector.zero) +v.append(N.i) # type: ignore +v.append(-N.i) # type: ignore +v.append(N.i + N.j) # type: ignore +v.append(a*N.i) # type: ignore +v.append(a*N.i - b*N.j) # type: ignore +v.append((a**2 + N.x)*N.i + N.k) # type: ignore +v.append((a**2 + b)*N.i + 3*(C.y - c)*N.k) # type: ignore +f = Function('f') +v.append(N.j - (Integral(f(b)) - C.x**2)*N.k) # type: ignore +upretty_v_8 = """\ + ⎛ 2 ⌠ ⎞ \n\ +j_N + ⎜x_C - ⎮ f(b) db⎟ k_N\n\ + ⎝ ⌡ ⎠ \ +""" +pretty_v_8 = """\ +j_N + / / \\\n\ + | 2 | |\n\ + |x_C - | f(b) db|\n\ + | | |\n\ + \\ / / \ +""" + +v.append(N.i + C.k) # type: ignore +v.append(express(N.i, C)) # type: ignore +v.append((a**2 + b)*N.i + (Integral(f(b)))*N.k) # type: ignore +upretty_v_11 = """\ +⎛ 2 ⎞ ⎛⌠ ⎞ \n\ +⎝a + b⎠ i_N + ⎜⎮ f(b) db⎟ k_N\n\ + ⎝⌡ ⎠ \ +""" +pretty_v_11 = """\ +/ 2 \\ + / / \\\n\ +\\a + b/ i_N| | |\n\ + | | f(b) db|\n\ + | | |\n\ + \\/ / \ +""" + +for x in v: + d.append(x | N.k) # type: ignore +s = 3*N.x**2*C.y # type: ignore +upretty_s = """\ + 2\n\ +3⋅y_C⋅x_N \ +""" +pretty_s = """\ + 2\n\ +3*y_C*x_N \ +""" + +# This is the pretty form for ((a**2 + b)*N.i + 3*(C.y - c)*N.k) | N.k +upretty_d_7 = """\ +⎛ 2 ⎞ \n\ +⎝a + b⎠ (i_N|k_N) + (3⋅y_C - 3⋅c) (k_N|k_N)\ +""" +pretty_d_7 = """\ +/ 2 \\ (i_N|k_N) + (3*y_C - 3*c) (k_N|k_N)\n\ +\\a + b/ \ +""" + + +def test_str_printing(): + assert str(v[0]) == '0' + assert str(v[1]) == 'N.i' + assert str(v[2]) == '(-1)*N.i' + assert str(v[3]) == 'N.i + N.j' + assert str(v[8]) == 'N.j + (C.x**2 - Integral(f(b), b))*N.k' + assert str(v[9]) == 'C.k + N.i' + assert str(s) == '3*C.y*N.x**2' + assert str(d[0]) == '0' + assert str(d[1]) == '(N.i|N.k)' + assert str(d[4]) == 'a*(N.i|N.k)' + assert str(d[5]) == 'a*(N.i|N.k) + (-b)*(N.j|N.k)' + assert str(d[8]) == ('(N.j|N.k) + (C.x**2 - ' + + 'Integral(f(b), b))*(N.k|N.k)') + + +@XFAIL +def test_pretty_printing_ascii(): + assert pretty(v[0]) == '0' + assert pretty(v[1]) == 'i_N' + assert pretty(v[5]) == '(a) i_N + (-b) j_N' + assert pretty(v[8]) == pretty_v_8 + assert pretty(v[2]) == '(-1) i_N' + assert pretty(v[11]) == pretty_v_11 + assert pretty(s) == pretty_s + assert pretty(d[0]) == '(0|0)' + assert pretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert pretty(d[7]) == pretty_d_7 + assert pretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_pretty_print_unicode_v(): + assert upretty(v[0]) == '0' + assert upretty(v[1]) == 'i_N' + assert 
upretty(v[5]) == '(a) i_N + (-b) j_N' + # Make sure the printing works in other objects + assert upretty(v[5].args) == '((a) i_N, (-b) j_N)' + assert upretty(v[8]) == upretty_v_8 + assert upretty(v[2]) == '(-1) i_N' + assert upretty(v[11]) == upretty_v_11 + assert upretty(s) == upretty_s + assert upretty(d[0]) == '(0|0)' + assert upretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert upretty(d[7]) == upretty_d_7 + assert upretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_latex_printing(): + assert latex(v[0]) == '\\mathbf{\\hat{0}}' + assert latex(v[1]) == '\\mathbf{\\hat{i}_{N}}' + assert latex(v[2]) == '- \\mathbf{\\hat{i}_{N}}' + assert latex(v[5]) == ('\\left(a\\right)\\mathbf{\\hat{i}_{N}} + ' + + '\\left(- b\\right)\\mathbf{\\hat{j}_{N}}') + assert latex(v[6]) == ('\\left(\\mathbf{{x}_{N}} + a^{2}\\right)\\mathbf{\\hat{i}_' + + '{N}} + \\mathbf{\\hat{k}_{N}}') + assert latex(v[8]) == ('\\mathbf{\\hat{j}_{N}} + \\left(\\mathbf{{x}_' + + '{C}}^{2} - \\int f{\\left(b \\right)}\\,' + + ' db\\right)\\mathbf{\\hat{k}_{N}}') + assert latex(s) == '3 \\mathbf{{y}_{C}} \\mathbf{{x}_{N}}^{2}' + assert latex(d[0]) == '(\\mathbf{\\hat{0}}|\\mathbf{\\hat{0}})' + assert latex(d[4]) == ('\\left(a\\right)\\left(\\mathbf{\\hat{i}_{N}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right)') + assert latex(d[9]) == ('\\left(\\mathbf{\\hat{k}_{C}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right) + \\left(' + + '\\mathbf{\\hat{i}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + assert latex(d[11]) == ('\\left(a^{2} + b\\right)\\left(\\mathbf{\\hat{i}_{N}}' + + '{\\middle|}\\mathbf{\\hat{k}_{N}}\\right) + ' + + '\\left(\\int f{\\left(b \\right)}\\, db\\right)\\left(' + + '\\mathbf{\\hat{k}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + +def test_issue_23058(): + from sympy import symbols, sin, cos, pi, UnevaluatedExpr + + delop = Del() + CC_ = CoordSys3D("C") + y = CC_.y + xhat = CC_.i + + t = symbols("t") + ten = symbols("10", positive=True) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + vecE = (1/eps) * Integral(delop.cross(vecB/mu).doit(), t) + vecE = vecE.doit() + + vecB_str = """\ +⎛ ⎛y_C⎞ ⎛ 5 ⎞⎞ \n\ +⎜2⋅sin⎜───⎟⋅cos⎝10 ⋅t⎠⎟ i_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────⎟ \n\ +⎜ 4 ⎟ \n\ +⎝ 10 ⎠ \ +""" + vecE_str = """\ +⎛ 4 ⎛ 5 ⎞ ⎛y_C⎞ ⎞ \n\ +⎜-10 ⋅sin⎝10 ⋅t⎠⋅cos⎜───⎟ ⎟ k_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────────⎟ \n\ +⎝ 2⋅π ⎠ \ +""" + + assert upretty(vecB) == vecB_str + assert upretty(vecE) == vecE_str + + ten = UnevaluatedExpr(10) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + + vecB_str = """\ +⎛ -4 ⎛ 5⎞ ⎛ -3⎞⎞ \n\ +⎝2⋅10 ⋅cos⎝t⋅10 ⎠⋅sin⎝y_C⋅10 ⎠⎠ i_C \ +""" + assert upretty(vecB) == vecB_str + +def test_custom_names(): + A = CoordSys3D('A', vector_names=['x', 'y', 'z'], + variable_names=['i', 'j', 'k']) + assert A.i.__str__() == 'A.i' + assert A.x.__str__() == 'A.x' + assert A.i._pretty_form == 'i_A' + assert A.x._pretty_form == 'x_A' + assert A.i._latex_form == r'\mathbf{{i}_{A}}' + assert A.x._latex_form == r"\mathbf{\hat{x}_{A}}" diff --git a/phivenv/Lib/site-packages/sympy/vector/tests/test_vector.py b/phivenv/Lib/site-packages/sympy/vector/tests/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..daba6d6a02c87b41a8bf801eee9b9045897d0003 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/tests/test_vector.py @@ -0,0 +1,342 @@ +from 
sympy.core import Rational, S, Add, Mul, I +from sympy.simplify import simplify, trigsimp +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.vector.vector import Vector, BaseVector, VectorAdd, \ + VectorMul, VectorZero +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import Cross, Dot, cross +from sympy.testing.pytest import raises +from sympy.vector.kind import VectorKind +from sympy.core.kind import NumberKind +from sympy.testing.pytest import XFAIL + + +C = CoordSys3D('C') + +i, j, k = C.base_vectors() +a, b, c = symbols('a b c') + + +def test_cross(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) == Cross(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Cross(v1, v2).doit() == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert cross(v1, v2) == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert Cross(v1, v2) == -Cross(v2, v1) + # XXX: Cannot use Cross here. See XFAIL test below: + assert cross(v1, v2) + cross(v2, v1) == Vector.zero + + +@XFAIL +def test_cross_xfail(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) + Cross(v2, v1) == Vector.zero + + +def test_dot(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Dot(v1, v2) == Dot(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Dot(v1, v2).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v2, v1).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v1, v2) == Dot(v2, v1) + + +def test_vector_sympy(): + """ + Test whether the Vector framework confirms to the hashing + and equality testing properties of SymPy. 
+ """ + v1 = 3*j + assert v1 == j*3 + assert v1.components == {j: 3} + v2 = 3*i + 4*j + 5*k + v3 = 2*i + 4*j + i + 4*k + k + assert v3 == v2 + assert v3.__hash__() == v2.__hash__() + + +def test_kind(): + assert C.i.kind is VectorKind(NumberKind) + assert C.j.kind is VectorKind(NumberKind) + assert C.k.kind is VectorKind(NumberKind) + + assert C.x.kind is NumberKind + assert C.y.kind is NumberKind + assert C.z.kind is NumberKind + + assert Mul._kind_dispatcher(NumberKind, VectorKind(NumberKind)) is VectorKind(NumberKind) + assert Mul(2, C.i).kind is VectorKind(NumberKind) + + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert v1.kind is VectorKind(NumberKind) + assert v2.kind is VectorKind(NumberKind) + + assert (v1 + v2).kind is VectorKind(NumberKind) + assert Add(v1, v2).kind is VectorKind(NumberKind) + assert Cross(v1, v2).doit().kind is VectorKind(NumberKind) + assert VectorAdd(v1, v2).kind is VectorKind(NumberKind) + assert VectorMul(2, v1).kind is VectorKind(NumberKind) + assert VectorZero().kind is VectorKind(NumberKind) + + assert v1.projection(v2).kind is VectorKind(NumberKind) + assert v2.projection(v1).kind is VectorKind(NumberKind) + + +def test_vectoradd(): + assert isinstance(Add(C.i, C.j), VectorAdd) + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert isinstance(Add(v1, v2), VectorAdd) + + # https://github.com/sympy/sympy/issues/26121 + + E = Matrix([C.i, C.j, C.k]).T + a = Matrix([1, 2, 3]) + av = E*a + + assert av[0].kind == VectorKind() + assert isinstance(av[0], VectorAdd) + + +def test_vector(): + assert isinstance(i, BaseVector) + assert i != j + assert j != k + assert k != i + assert i - i == Vector.zero + assert i + Vector.zero == i + assert i - Vector.zero == i + assert Vector.zero != 0 + assert -Vector.zero == Vector.zero + + v1 = a*i + b*j + c*k + v2 = a**2*i + b**2*j + c**2*k + v3 = v1 + v2 + v4 = 2 * v1 + v5 = a * i + + assert isinstance(v1, VectorAdd) + assert v1 - v1 == Vector.zero + assert v1 + Vector.zero == v1 + assert v1.dot(i) == a + assert v1.dot(j) == b + assert v1.dot(k) == c + assert i.dot(v2) == a**2 + assert j.dot(v2) == b**2 + assert k.dot(v2) == c**2 + assert v3.dot(i) == a**2 + a + assert v3.dot(j) == b**2 + b + assert v3.dot(k) == c**2 + c + + assert v1 + v2 == v2 + v1 + assert v1 - v2 == -1 * (v2 - v1) + assert a * v1 == v1 * a + + assert isinstance(v5, VectorMul) + assert v5.base_vector == i + assert v5.measure_number == a + assert isinstance(v4, Vector) + assert isinstance(v4, VectorAdd) + assert isinstance(v4, Vector) + assert isinstance(Vector.zero, VectorZero) + assert isinstance(Vector.zero, Vector) + assert isinstance(v1 * 0, VectorZero) + + assert v1.to_matrix(C) == Matrix([[a], [b], [c]]) + + assert i.components == {i: 1} + assert v5.components == {i: a} + assert v1.components == {i: a, j: b, k: c} + + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(a, v1) == v1*a + assert VectorMul(1, i) == i + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(0, Vector.zero) == Vector.zero + raises(TypeError, lambda: v1.outer(1)) + raises(TypeError, lambda: v1.dot(1)) + + +def test_vector_magnitude_normalize(): + assert Vector.zero.magnitude() == 0 + assert Vector.zero.normalize() == Vector.zero + + assert i.magnitude() == 1 + assert j.magnitude() == 1 + assert k.magnitude() == 1 + assert i.normalize() == i + assert j.normalize() == j + assert k.normalize() == k + + v1 = a * i + assert v1.normalize() == (a/sqrt(a**2))*i + assert v1.magnitude() == sqrt(a**2) + + v2 = a*i + b*j + c*k + 
assert v2.magnitude() == sqrt(a**2 + b**2 + c**2) + assert v2.normalize() == v2 / v2.magnitude() + + v3 = i + j + assert v3.normalize() == (sqrt(2)/2)*C.i + (sqrt(2)/2)*C.j + + +def test_vector_simplify(): + A, s, k, m = symbols('A, s, k, m') + + test1 = (1 / a + 1 / b) * i + assert (test1 & i) != (a + b) / (a * b) + test1 = simplify(test1) + assert (test1 & i) == (a + b) / (a * b) + assert test1.simplify() == simplify(test1) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * i + test2 = simplify(test2) + assert (test2 & i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * a - 2 * (2 + 2 * a)) / (2 + 2 * a)) * i + test3 = simplify(test3) + assert (test3 & i) == 0 + + test4 = ((-4 * a * b**2 - 2 * b**3 - 2 * a**2 * b) / (a + b)**2) * i + test4 = simplify(test4) + assert (test4 & i) == -2 * b + + v = (sin(a)+cos(a))**2*i - j + assert trigsimp(v) == (2*sin(a + pi/4)**2)*i + (-1)*j + assert trigsimp(v) == v.trigsimp() + + assert simplify(Vector.zero) == Vector.zero + + +def test_vector_equals(): + assert (2*i).equals(j) is False + assert i.equals(i) is True + + # https://github.com/sympy/sympy/issues/25915 + A = (sqrt(2) + sqrt(6)) / sqrt(sqrt(3) + 2) + assert (A*i).equals(2*i) is True + assert (A*i).equals(3*i) is False + + # Test comparing vectors in different coordinate systems + D = C.orient_new_axis('D', pi/2, C.k) + assert (D.i).equals(C.j) is True + assert (D.i).equals(C.i) is False + + +def test_vector_conjugate(): + # https://github.com/sympy/sympy/issues/27094 + assert (I*i + (1 + I)*j + 2*k).conjugate() == -I*i + (1 - I)*j + 2*k + + +def test_vector_dot(): + assert i.dot(Vector.zero) == 0 + assert Vector.zero.dot(i) == 0 + assert i & Vector.zero == 0 + + assert i.dot(i) == 1 + assert i.dot(j) == 0 + assert i.dot(k) == 0 + assert i & i == 1 + assert i & j == 0 + assert i & k == 0 + + assert j.dot(i) == 0 + assert j.dot(j) == 1 + assert j.dot(k) == 0 + assert j & i == 0 + assert j & j == 1 + assert j & k == 0 + + assert k.dot(i) == 0 + assert k.dot(j) == 0 + assert k.dot(k) == 1 + assert k & i == 0 + assert k & j == 0 + assert k & k == 1 + + raises(TypeError, lambda: k.dot(1)) + + +def test_vector_cross(): + assert i.cross(Vector.zero) == Vector.zero + assert Vector.zero.cross(i) == Vector.zero + + assert i.cross(i) == Vector.zero + assert i.cross(j) == k + assert i.cross(k) == -j + assert i ^ i == Vector.zero + assert i ^ j == k + assert i ^ k == -j + + assert j.cross(i) == -k + assert j.cross(j) == Vector.zero + assert j.cross(k) == i + assert j ^ i == -k + assert j ^ j == Vector.zero + assert j ^ k == i + + assert k.cross(i) == j + assert k.cross(j) == -i + assert k.cross(k) == Vector.zero + assert k ^ i == j + assert k ^ j == -i + assert k ^ k == Vector.zero + + assert k.cross(1) == Cross(k, 1) + + +def test_projection(): + v1 = i + j + k + v2 = 3*i + 4*j + v3 = 0*i + 0*j + assert v1.projection(v1) == i + j + k + assert v1.projection(v2) == Rational(7, 3)*C.i + Rational(7, 3)*C.j + Rational(7, 3)*C.k + assert v1.projection(v1, scalar=True) == S.One + assert v1.projection(v2, scalar=True) == Rational(7, 3) + assert v3.projection(v1) == Vector.zero + assert v3.projection(v1, scalar=True) == S.Zero + + +def test_vector_diff_integrate(): + f = Function('f') + v = f(a)*C.i + a**2*C.j - C.k + assert Derivative(v, a) == Derivative((f(a))*C.i + + a**2*C.j + (-1)*C.k, a) + assert (diff(v, a) == v.diff(a) == Derivative(v, a).doit() == + (Derivative(f(a), a))*C.i + 2*a*C.j) + assert (Integral(v, a) == (Integral(f(a), a))*C.i + + (Integral(a**2, a))*C.j + (Integral(-1, a))*C.k) + + 
+def test_vector_args(): + raises(ValueError, lambda: BaseVector(3, C)) + raises(TypeError, lambda: BaseVector(0, Vector.zero)) + + +def test_srepr(): + from sympy.printing.repr import srepr + res = "CoordSys3D(Str('C'), Tuple(ImmutableDenseMatrix([[Integer(1), "\ + "Integer(0), Integer(0)], [Integer(0), Integer(1), Integer(0)], "\ + "[Integer(0), Integer(0), Integer(1)]]), VectorZero())).i" + assert srepr(C.i) == res + + +def test_scalar(): + from sympy.vector import CoordSys3D + C = CoordSys3D('C') + v1 = 3*C.i + 4*C.j + 5*C.k + v2 = 3*C.i - 4*C.j + 5*C.k + assert v1.is_Vector is True + assert v1.is_scalar is False + assert (v1.dot(v2)).is_scalar is True + assert (v1.cross(v2)).is_scalar is False diff --git a/phivenv/Lib/site-packages/sympy/vector/vector.py b/phivenv/Lib/site-packages/sympy/vector/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..c035ef48d2edd511f6cdbca19e965d99a2c8c66e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/vector/vector.py @@ -0,0 +1,714 @@ +from __future__ import annotations +from itertools import product + +from sympy.core import Add, Basic +from sympy.core.assumptions import StdFactKB +from sympy.core.expr import AtomicExpr, Expr +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.vector.basisdependent import (BasisDependentZero, + BasisDependent, BasisDependentMul, BasisDependentAdd) +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.dyadic import Dyadic, BaseDyadic, DyadicAdd +from sympy.vector.kind import VectorKind + + +class Vector(BasisDependent): + """ + Super class for all Vector classes. + Ideally, neither this class nor any of its subclasses should be + instantiated by the user. + """ + + is_scalar = False + is_Vector = True + _op_priority = 12.0 + + _expr_type: type[Vector] + _mul_func: type[Vector] + _add_func: type[Vector] + _zero_func: type[Vector] + _base_func: type[Vector] + zero: VectorZero + + kind: VectorKind = VectorKind() + + @property + def components(self): + """ + Returns the components of this vector in the form of a + Python dictionary mapping BaseVector instances to the + corresponding measure numbers. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> C = CoordSys3D('C') + >>> v = 3*C.i + 4*C.j + 5*C.k + >>> v.components + {C.i: 3, C.j: 4, C.k: 5} + + """ + # The '_components' attribute is defined according to the + # subclass of Vector the instance belongs to. + return self._components + + def magnitude(self): + """ + Returns the magnitude of this vector. + """ + return sqrt(self & self) + + def normalize(self): + """ + Returns the normalized version of this vector. + """ + return self / self.magnitude() + + def equals(self, other): + """ + Check if ``self`` and ``other`` are identically equal vectors. + + Explanation + =========== + + Checks if two vector expressions are equal for all possible values of + the symbols present in the expressions. 
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.abc import x, y + >>> from sympy import pi + >>> C = CoordSys3D('C') + + Compare vectors that are equal or not: + + >>> C.i.equals(C.j) + False + >>> C.i.equals(C.i) + True + + These two vectors are equal if `x = y` but are not identically equal + as expressions since for some values of `x` and `y` they are unequal: + + >>> v1 = x*C.i + C.j + >>> v2 = y*C.i + C.j + >>> v1.equals(v1) + True + >>> v1.equals(v2) + False + + Vectors from different coordinate systems can be compared: + + >>> D = C.orient_new_axis('D', pi/2, C.i) + >>> D.j.equals(C.j) + False + >>> D.j.equals(C.k) + True + + Parameters + ========== + + other: Vector + The other vector expression to compare with. + + Returns + ======= + + ``True``, ``False`` or ``None``. A return value of ``True`` indicates + that the two vectors are identically equal. A return value of ``False`` + indicates that they are not. In some cases it is not possible to + determine if the two vectors are identically equal and ``None`` is + returned. + + See Also + ======== + + sympy.core.expr.Expr.equals + """ + diff = self - other + diff_mag2 = diff.dot(diff) + return diff_mag2.equals(0) + + def dot(self, other): + """ + Returns the dot product of this Vector, either with another + Vector, or a Dyadic, or a Del operator. + If 'other' is a Vector, returns the dot product scalar (SymPy + expression). + If 'other' is a Dyadic, the dot product is returned as a Vector. + If 'other' is an instance of Del, returns the directional + derivative operator as a Python function. If this function is + applied to a scalar expression, it returns the directional + derivative of the scalar field wrt this Vector. + + Parameters + ========== + + other: Vector/Dyadic/Del + The Vector or Dyadic we are dotting with, or a Del operator . + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> C = CoordSys3D('C') + >>> delop = Del() + >>> C.i.dot(C.j) + 0 + >>> C.i & C.i + 1 + >>> v = 3*C.i + 4*C.j + 5*C.k + >>> v.dot(C.k) + 5 + >>> (C.i & delop)(C.x*C.y*C.z) + C.y*C.z + >>> d = C.i.outer(C.i) + >>> C.i.dot(d) + C.i + + """ + + # Check special cases + if isinstance(other, Dyadic): + if isinstance(self, VectorZero): + return Vector.zero + outvec = Vector.zero + for k, v in other.components.items(): + vect_dot = k.args[0].dot(self) + outvec += vect_dot * v * k.args[1] + return outvec + from sympy.vector.deloperator import Del + if not isinstance(other, (Del, Vector)): + raise TypeError(str(other) + " is not a vector, dyadic or " + + "del operator") + + # Check if the other is a del operator + if isinstance(other, Del): + def directional_derivative(field): + from sympy.vector.functions import directional_derivative + return directional_derivative(field, self) + return directional_derivative + + return dot(self, other) + + def __and__(self, other): + return self.dot(other) + + __and__.__doc__ = dot.__doc__ + + def cross(self, other): + """ + Returns the cross product of this Vector with another Vector or + Dyadic instance. + The cross product is a Vector, if 'other' is a Vector. If 'other' + is a Dyadic, this returns a Dyadic instance. + + Parameters + ========== + + other: Vector/Dyadic + The Vector or Dyadic we are crossing with. 
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> C = CoordSys3D('C') + >>> C.i.cross(C.j) + C.k + >>> C.i ^ C.i + 0 + >>> v = 3*C.i + 4*C.j + 5*C.k + >>> v ^ C.i + 5*C.j + (-4)*C.k + >>> d = C.i.outer(C.i) + >>> C.j.cross(d) + (-1)*(C.k|C.i) + + """ + + # Check special cases + if isinstance(other, Dyadic): + if isinstance(self, VectorZero): + return Dyadic.zero + outdyad = Dyadic.zero + for k, v in other.components.items(): + cross_product = self.cross(k.args[0]) + outer = cross_product.outer(k.args[1]) + outdyad += v * outer + return outdyad + + return cross(self, other) + + def __xor__(self, other): + return self.cross(other) + + __xor__.__doc__ = cross.__doc__ + + def outer(self, other): + """ + Returns the outer product of this vector with another, in the + form of a Dyadic instance. + + Parameters + ========== + + other : Vector + The Vector with respect to which the outer product is to + be computed. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> N.i.outer(N.j) + (N.i|N.j) + + """ + + # Handle the special cases + if not isinstance(other, Vector): + raise TypeError("Invalid operand for outer product") + elif (isinstance(self, VectorZero) or + isinstance(other, VectorZero)): + return Dyadic.zero + + # Iterate over components of both the vectors to generate + # the required Dyadic instance + args = [(v1 * v2) * BaseDyadic(k1, k2) for (k1, v1), (k2, v2) + in product(self.components.items(), other.components.items())] + + return DyadicAdd(*args) + + def projection(self, other, scalar=False): + """ + Returns the vector or scalar projection of the 'other' on 'self'. + + Examples + ======== + + >>> from sympy.vector.coordsysrect import CoordSys3D + >>> C = CoordSys3D('C') + >>> i, j, k = C.base_vectors() + >>> v1 = i + j + k + >>> v2 = 3*i + 4*j + >>> v1.projection(v2) + 7/3*C.i + 7/3*C.j + 7/3*C.k + >>> v1.projection(v2, scalar=True) + 7/3 + + """ + if self.equals(Vector.zero): + return S.Zero if scalar else Vector.zero + + if scalar: + return self.dot(other) / self.dot(self) + else: + return self.dot(other) / self.dot(self) * self + + @property + def _projections(self): + """ + Returns the components of this vector but the output includes + also zero values components. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Vector + >>> C = CoordSys3D('C') + >>> v1 = 3*C.i + 4*C.j + 5*C.k + >>> v1._projections + (3, 4, 5) + >>> v2 = C.x*C.y*C.z*C.i + >>> v2._projections + (C.x*C.y*C.z, 0, 0) + >>> v3 = Vector.zero + >>> v3._projections + (0, 0, 0) + """ + + from sympy.vector.operators import _get_coord_systems + if isinstance(self, VectorZero): + return (S.Zero, S.Zero, S.Zero) + base_vec = next(iter(_get_coord_systems(self))).base_vectors() + return tuple([self.dot(i) for i in base_vec]) + + def __or__(self, other): + return self.outer(other) + + __or__.__doc__ = outer.__doc__ + + def to_matrix(self, system): + """ + Returns the matrix form of this vector with respect to the + specified coordinate system. 
+ + Parameters + ========== + + system : CoordSys3D + The system wrt which the matrix form is to be computed + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> C = CoordSys3D('C') + >>> from sympy.abc import a, b, c + >>> v = a*C.i + b*C.j + c*C.k + >>> v.to_matrix(C) + Matrix([ + [a], + [b], + [c]]) + + """ + + return Matrix([self.dot(unit_vec) for unit_vec in + system.base_vectors()]) + + def separate(self): + """ + The constituents of this vector in different coordinate systems, + as per its definition. + + Returns a dict mapping each CoordSys3D to the corresponding + constituent Vector. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> R1 = CoordSys3D('R1') + >>> R2 = CoordSys3D('R2') + >>> v = R1.i + R2.i + >>> v.separate() == {R1: R1.i, R2: R2.i} + True + + """ + + parts = {} + for vect, measure in self.components.items(): + parts[vect.system] = (parts.get(vect.system, Vector.zero) + + vect * measure) + return parts + + def _div_helper(one, other): + """ Helper for division involving vectors. """ + if isinstance(one, Vector) and isinstance(other, Vector): + raise TypeError("Cannot divide two vectors") + elif isinstance(one, Vector): + if other == S.Zero: + raise ValueError("Cannot divide a vector by zero") + return VectorMul(one, Pow(other, S.NegativeOne)) + else: + raise TypeError("Invalid division involving a vector") + +# The following is adapted from the matrices.expressions.matexpr file + +def get_postprocessor(cls): + def _postprocessor(expr): + vec_class = {Add: VectorAdd}[cls] + vectors = [] + for term in expr.args: + if isinstance(term.kind, VectorKind): + vectors.append(term) + + if vec_class == VectorAdd: + return VectorAdd(*vectors).doit(deep=False) + return _postprocessor + + +Basic._constructor_postprocessor_mapping[Vector] = { + "Add": [get_postprocessor(Add)], +} + +class BaseVector(Vector, AtomicExpr): + """ + Class to denote a base vector. + + """ + + def __new__(cls, index, system, pretty_str=None, latex_str=None): + if pretty_str is None: + pretty_str = "x{}".format(index) + if latex_str is None: + latex_str = "x_{}".format(index) + pretty_str = str(pretty_str) + latex_str = str(latex_str) + # Verify arguments + if index not in range(0, 3): + raise ValueError("index must be 0, 1 or 2") + if not isinstance(system, CoordSys3D): + raise TypeError("system should be a CoordSys3D") + name = system._vector_names[index] + # Initialize an object + obj = super().__new__(cls, S(index), system) + # Assign important attributes + obj._base_instance = obj + obj._components = {obj: S.One} + obj._measure_number = S.One + obj._name = system._name + '.' + name + obj._pretty_form = '' + pretty_str + obj._latex_form = latex_str + obj._system = system + # The _id is used for printing purposes + obj._id = (index, system) + assumptions = {'commutative': True} + obj._assumptions = StdFactKB(assumptions) + + # This attr is used for re-expression to one of the systems + # involved in the definition of the Vector. Applies to + # VectorMul and VectorAdd too. + obj._sys = system + + return obj + + @property + def system(self): + return self._system + + def _sympystr(self, printer): + return self._name + + def _sympyrepr(self, printer): + index, system = self._id + return printer._print(system) + '.' + system._vector_names[index] + + @property + def free_symbols(self): + return {self} + + def _eval_conjugate(self): + return self + + +class VectorAdd(BasisDependentAdd, Vector): + """ + Class to denote sum of Vector instances. 
+ """ + + def __new__(cls, *args, **options): + obj = BasisDependentAdd.__new__(cls, *args, **options) + return obj + + def _sympystr(self, printer): + ret_str = '' + items = list(self.separate().items()) + items.sort(key=lambda x: x[0].__str__()) + for system, vect in items: + base_vects = system.base_vectors() + for x in base_vects: + if x in vect.components: + temp_vect = self.components[x] * x + ret_str += printer._print(temp_vect) + " + " + return ret_str[:-3] + + +class VectorMul(BasisDependentMul, Vector): + """ + Class to denote products of scalars and BaseVectors. + """ + + def __new__(cls, *args, **options): + obj = BasisDependentMul.__new__(cls, *args, **options) + return obj + + @property + def base_vector(self): + """ The BaseVector involved in the product. """ + return self._base_instance + + @property + def measure_number(self): + """ The scalar expression involved in the definition of + this VectorMul. + """ + return self._measure_number + + +class VectorZero(BasisDependentZero, Vector): + """ + Class to denote a zero vector + """ + + _op_priority = 12.1 + _pretty_form = '0' + _latex_form = r'\mathbf{\hat{0}}' + + def __new__(cls): + obj = BasisDependentZero.__new__(cls) + return obj + + +class Cross(Vector): + """ + Represents unevaluated Cross product. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Cross + >>> R = CoordSys3D('R') + >>> v1 = R.i + R.j + R.k + >>> v2 = R.x * R.i + R.y * R.j + R.z * R.k + >>> Cross(v1, v2) + Cross(R.i + R.j + R.k, R.x*R.i + R.y*R.j + R.z*R.k) + >>> Cross(v1, v2).doit() + (-R.y + R.z)*R.i + (R.x - R.z)*R.j + (-R.x + R.y)*R.k + + """ + + def __new__(cls, expr1, expr2): + expr1 = sympify(expr1) + expr2 = sympify(expr2) + if default_sort_key(expr1) > default_sort_key(expr2): + return -Cross(expr2, expr1) + obj = Expr.__new__(cls, expr1, expr2) + obj._expr1 = expr1 + obj._expr2 = expr2 + return obj + + def doit(self, **hints): + return cross(self._expr1, self._expr2) + + +class Dot(Expr): + """ + Represents unevaluated Dot product. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Dot + >>> from sympy import symbols + >>> R = CoordSys3D('R') + >>> a, b, c = symbols('a b c') + >>> v1 = R.i + R.j + R.k + >>> v2 = a * R.i + b * R.j + c * R.k + >>> Dot(v1, v2) + Dot(R.i + R.j + R.k, a*R.i + b*R.j + c*R.k) + >>> Dot(v1, v2).doit() + a + b + c + + """ + + def __new__(cls, expr1, expr2): + expr1 = sympify(expr1) + expr2 = sympify(expr2) + expr1, expr2 = sorted([expr1, expr2], key=default_sort_key) + obj = Expr.__new__(cls, expr1, expr2) + obj._expr1 = expr1 + obj._expr2 = expr2 + return obj + + def doit(self, **hints): + return dot(self._expr1, self._expr2) + + +def cross(vect1, vect2): + """ + Returns cross product of two vectors. 
+ + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector.vector import cross + >>> R = CoordSys3D('R') + >>> v1 = R.i + R.j + R.k + >>> v2 = R.x * R.i + R.y * R.j + R.z * R.k + >>> cross(v1, v2) + (-R.y + R.z)*R.i + (R.x - R.z)*R.j + (-R.x + R.y)*R.k + + """ + if isinstance(vect1, Add): + return VectorAdd.fromiter(cross(i, vect2) for i in vect1.args) + if isinstance(vect2, Add): + return VectorAdd.fromiter(cross(vect1, i) for i in vect2.args) + if isinstance(vect1, BaseVector) and isinstance(vect2, BaseVector): + if vect1._sys == vect2._sys: + n1 = vect1.args[0] + n2 = vect2.args[0] + if n1 == n2: + return Vector.zero + n3 = ({0,1,2}.difference({n1, n2})).pop() + sign = 1 if ((n1 + 1) % 3 == n2) else -1 + return sign*vect1._sys.base_vectors()[n3] + from .functions import express + try: + v = express(vect1, vect2._sys) + except ValueError: + return Cross(vect1, vect2) + else: + return cross(v, vect2) + if isinstance(vect1, VectorZero) or isinstance(vect2, VectorZero): + return Vector.zero + if isinstance(vect1, VectorMul): + v1, m1 = next(iter(vect1.components.items())) + return m1*cross(v1, vect2) + if isinstance(vect2, VectorMul): + v2, m2 = next(iter(vect2.components.items())) + return m2*cross(vect1, v2) + + return Cross(vect1, vect2) + + +def dot(vect1, vect2): + """ + Returns dot product of two vectors. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector.vector import dot + >>> R = CoordSys3D('R') + >>> v1 = R.i + R.j + R.k + >>> v2 = R.x * R.i + R.y * R.j + R.z * R.k + >>> dot(v1, v2) + R.x + R.y + R.z + + """ + if isinstance(vect1, Add): + return Add.fromiter(dot(i, vect2) for i in vect1.args) + if isinstance(vect2, Add): + return Add.fromiter(dot(vect1, i) for i in vect2.args) + if isinstance(vect1, BaseVector) and isinstance(vect2, BaseVector): + if vect1._sys == vect2._sys: + return S.One if vect1 == vect2 else S.Zero + from .functions import express + try: + v = express(vect2, vect1._sys) + except ValueError: + return Dot(vect1, vect2) + else: + return dot(vect1, v) + if isinstance(vect1, VectorZero) or isinstance(vect2, VectorZero): + return S.Zero + if isinstance(vect1, VectorMul): + v1, m1 = next(iter(vect1.components.items())) + return m1*dot(v1, vect2) + if isinstance(vect2, VectorMul): + v2, m2 = next(iter(vect2.components.items())) + return m2*dot(vect1, v2) + + return Dot(vect1, vect2) + + +Vector._expr_type = Vector +Vector._mul_func = VectorMul +Vector._add_func = VectorAdd +Vector._zero_func = VectorZero +Vector._base_func = BaseVector +Vector.zero = VectorZero() diff --git a/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/INSTALLER b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/METADATA b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..8167151c0615b9efa6800e1afd1b8c1071aa4a78 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/METADATA @@ -0,0 +1,210 @@ +Metadata-Version: 2.4 +Name: tokenizers +Version: 0.22.0 +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: 
Science/Research +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Requires-Dist: huggingface-hub>=0.16.4,<1.0 +Requires-Dist: pytest ; extra == 'testing' +Requires-Dist: pytest-asyncio ; extra == 'testing' +Requires-Dist: requests ; extra == 'testing' +Requires-Dist: numpy ; extra == 'testing' +Requires-Dist: datasets ; extra == 'testing' +Requires-Dist: black==22.3 ; extra == 'testing' +Requires-Dist: ruff ; extra == 'testing' +Requires-Dist: sphinx ; extra == 'docs' +Requires-Dist: sphinx-rtd-theme ; extra == 'docs' +Requires-Dist: setuptools-rust ; extra == 'docs' +Requires-Dist: tokenizers[testing] ; extra == 'dev' +Provides-Extra: testing +Provides-Extra: docs +Provides-Extra: dev +Keywords: NLP,tokenizer,BPE,transformer,deep learning +Author-email: Nicolas Patry , Anthony Moi +Requires-Python: >=3.9 +Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM +Project-URL: Homepage, https://github.com/huggingface/tokenizers +Project-URL: Source, https://github.com/huggingface/tokenizers + +

+[tokenizers logo, Build badge and GitHub license badge (HTML badge block)]
+ +# Tokenizers + +Provides an implementation of today's most used tokenizers, with a focus on performance and +versatility. + +Bindings over the [Rust](https://github.com/huggingface/tokenizers/tree/master/tokenizers) implementation. +If you are interested in the High-level design, you can go check it there. + +Otherwise, let's dive in! + +## Main features: + + - Train new vocabularies and tokenize using 4 pre-made tokenizers (Bert WordPiece and the 3 + most common BPE versions). + - Extremely fast (both training and tokenization), thanks to the Rust implementation. Takes + less than 20 seconds to tokenize a GB of text on a server's CPU. + - Easy to use, but also extremely versatile. + - Designed for research and production. + - Normalization comes with alignments tracking. It's always possible to get the part of the + original sentence that corresponds to a given token. + - Does all the pre-processing: Truncate, Pad, add the special tokens your model needs. + +### Installation + +#### With pip: + +```bash +pip install tokenizers +``` + +#### From sources: + +To use this method, you need to have the Rust installed: + +```bash +# Install with: +curl https://sh.rustup.rs -sSf | sh -s -- -y +export PATH="$HOME/.cargo/bin:$PATH" +``` + +Once Rust is installed, you can compile doing the following + +```bash +git clone https://github.com/huggingface/tokenizers +cd tokenizers/bindings/python + +# Create a virtual env (you can use yours as well) +python -m venv .env +source .env/bin/activate + +# Install `tokenizers` in the current virtual env +pip install -e . +``` + +### Load a pretrained tokenizer from the Hub + +```python +from tokenizers import Tokenizer + +tokenizer = Tokenizer.from_pretrained("bert-base-cased") +``` + +### Using the provided Tokenizers + +We provide some pre-build tokenizers to cover the most common cases. You can easily load one of +these using some `vocab.json` and `merges.txt` files: + +```python +from tokenizers import CharBPETokenizer + +# Initialize a tokenizer +vocab = "./path/to/vocab.json" +merges = "./path/to/merges.txt" +tokenizer = CharBPETokenizer(vocab, merges) + +# And then encode: +encoded = tokenizer.encode("I can feel the magic, can you?") +print(encoded.ids) +print(encoded.tokens) +``` + +And you can train them just as simply: + +```python +from tokenizers import CharBPETokenizer + +# Initialize a tokenizer +tokenizer = CharBPETokenizer() + +# Then train it! +tokenizer.train([ "./path/to/files/1.txt", "./path/to/files/2.txt" ]) + +# Now, let's use it: +encoded = tokenizer.encode("I can feel the magic, can you?") + +# And finally save it somewhere +tokenizer.save("./path/to/directory/my-bpe.tokenizer.json") +``` + +#### Provided Tokenizers + + - `CharBPETokenizer`: The original BPE + - `ByteLevelBPETokenizer`: The byte level version of the BPE + - `SentencePieceBPETokenizer`: A BPE implementation compatible with the one used by SentencePiece + - `BertWordPieceTokenizer`: The famous Bert tokenizer, using WordPiece + +All of these can be used and trained as explained above! + +### Build your own + +Whenever these provided tokenizers don't give you enough freedom, you can build your own tokenizer, +by putting all the different parts you need together. +You can check how we implemented the [provided tokenizers](https://github.com/huggingface/tokenizers/tree/master/bindings/python/py_src/tokenizers/implementations) and adapt them easily to your own needs. 
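+The feature list above mentions alignment tracking: every encoding keeps the offsets that link each token back to the original text. A minimal sketch of reading them (assuming the `bert-base-cased` tokenizer is available from the Hub, as in the earlier example):
+
+```python
+from tokenizers import Tokenizer
+
+tokenizer = Tokenizer.from_pretrained("bert-base-cased")
+
+text = "I can feel the magic, can you?"
+encoded = tokenizer.encode(text)
+
+# encoded.offsets holds one (start, end) pair per token, indexing into `text`;
+# special tokens that do not come from the input map to an empty span.
+for token, (start, end) in zip(encoded.tokens, encoded.offsets):
+    print(token, "->", repr(text[start:end]))
+```
+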
+ +#### Building a byte-level BPE + +Here is an example showing how to build your own byte-level BPE by putting all the different pieces +together, and then saving it to a single file: + +```python +from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors + +# Initialize a tokenizer +tokenizer = Tokenizer(models.BPE()) + +# Customize pre-tokenization and decoding +tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True) +tokenizer.decoder = decoders.ByteLevel() +tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) + +# And then train +trainer = trainers.BpeTrainer( + vocab_size=20000, + min_frequency=2, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet() +) +tokenizer.train([ + "./path/to/dataset/1.txt", + "./path/to/dataset/2.txt", + "./path/to/dataset/3.txt" +], trainer=trainer) + +# And Save it +tokenizer.save("byte-level-bpe.tokenizer.json", pretty=True) +``` + +Now, when you want to use this tokenizer, this is as simple as: + +```python +from tokenizers import Tokenizer + +tokenizer = Tokenizer.from_file("byte-level-bpe.tokenizer.json") + +encoded = tokenizer.encode("I can feel the magic, can you?") +``` + diff --git a/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/RECORD b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..03b92fdebbe44de2766a8380ea208ab7f16096fc --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/RECORD @@ -0,0 +1,45 @@ +tokenizers-0.22.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +tokenizers-0.22.0.dist-info/METADATA,sha256=EKVl7mxUTvTaBDxLE6hS0QkicW2pXHT-5yqSBmpoqDY,6949 +tokenizers-0.22.0.dist-info/RECORD,, +tokenizers-0.22.0.dist-info/WHEEL,sha256=-M5O7l5EczTA8VFaBQsg2Fpg0dKz0WOuvpt3nEh86bo,94 +tokenizers/__init__.py,sha256=AJ5NdKOV_dZ_SOw9oFF1M4ln0fh1hod-voqzeYHS71U,2715 +tokenizers/__init__.pyi,sha256=Ae9ppleBCxKXzptN_UyQeTNdx_iUSCSgZxfNJ8OTBAI,46516 +tokenizers/__pycache__/__init__.cpython-39.pyc,, +tokenizers/decoders/__init__.py,sha256=Qr1jbA2fvqwjgBBYOEyaDjLBwDwkBftP5jN3tRvT1jg,387 +tokenizers/decoders/__init__.pyi,sha256=s5AWC5hQiQYhND41lGwrP3lRWSlZPV_POUa9R1eM_OI,7674 +tokenizers/decoders/__pycache__/__init__.cpython-39.pyc,, +tokenizers/implementations/__init__.py,sha256=9ZI2cJHPCUCZXP37GNmRHtBzIYiKWzyqky1rtcYdrPw,316 +tokenizers/implementations/__pycache__/__init__.cpython-39.pyc,, +tokenizers/implementations/__pycache__/base_tokenizer.cpython-39.pyc,, +tokenizers/implementations/__pycache__/bert_wordpiece.cpython-39.pyc,, +tokenizers/implementations/__pycache__/byte_level_bpe.cpython-39.pyc,, +tokenizers/implementations/__pycache__/char_level_bpe.cpython-39.pyc,, +tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-39.pyc,, +tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-39.pyc,, +tokenizers/implementations/base_tokenizer.py,sha256=crpmURcXXhlQRqeqFRY5kdAwms5qtDJfRDMATsUEVuw,16250 +tokenizers/implementations/bert_wordpiece.py,sha256=jRUMSYU1C7nVEDZMKdyG3Vp7-CUongE1kGJhvOP_EHM,5671 +tokenizers/implementations/byte_level_bpe.py,sha256=pWfbSeaew9atWKX8Nf2zIRia6qhAwXMfZgaeINInzo8,4394 +tokenizers/implementations/char_level_bpe.py,sha256=Bt15rEcZRSzJ2RfEeq2YvOXZ8w_UYGq9i6O3y-41Aj4,5599 +tokenizers/implementations/sentencepiece_bpe.py,sha256=4_dJFcNsyvC_F1wBnWDBquWssbUMV4kP-BEXeuHcxtI,3824 +tokenizers/implementations/sentencepiece_unigram.py,sha256=5SliPQsU_cAB1Pqt7ufL9w8fBrNoLixVC-rqF4MRcm0,7776 
+tokenizers/models/__init__.py,sha256=qE73qycPAKcey_pS8ov1FIz10L2Ybzy1E-1KGee27Qg,184 +tokenizers/models/__init__.pyi,sha256=vEJHu-C5LgbqBdtmSgGBIjkHh4bvDx520V0xY2gntIo,17520 +tokenizers/models/__pycache__/__init__.cpython-39.pyc,, +tokenizers/normalizers/__init__.py,sha256=Pgb_wE1QQov3t_SW4vcsXY6lVBDYyGNfYGkJMk3wI7c,870 +tokenizers/normalizers/__init__.pyi,sha256=cpDAEDNbn2HZdCmLMX27mHLQFVU1dAvCZwRqKITYUxY,21534 +tokenizers/normalizers/__pycache__/__init__.cpython-39.pyc,, +tokenizers/pre_tokenizers/__init__.py,sha256=QvNmMTtV0RGkaKSN9hTwEPimD1OLHRU6PbzVxXwM4zA,614 +tokenizers/pre_tokenizers/__init__.pyi,sha256=Nu81riq3C9_DfikxxgvJ9leJ7MU3vfbw_V4xHETN40s,27245 +tokenizers/pre_tokenizers/__pycache__/__init__.cpython-39.pyc,, +tokenizers/processors/__init__.py,sha256=X1OKKtr3IUeede85mI1cLF0MqmDRlr7LNFMYdEy5KMU,316 +tokenizers/processors/__init__.pyi,sha256=qMoapc_ufxbrL-ZVAkpYugLzxxuz1UBBNyoEiySjt-4,11699 +tokenizers/processors/__pycache__/__init__.cpython-39.pyc,, +tokenizers/tokenizers.pyd,sha256=TjT1QHF1WkmpA0V4C5Fw3cOZaymkw506H9xIZzJ_DvA,7669248 +tokenizers/tools/__init__.py,sha256=mJLVrHN1FP23Cf4_EJWoYUdh8yTb5BbJFNfa8Ec5vHM,56 +tokenizers/tools/__pycache__/__init__.cpython-39.pyc,, +tokenizers/tools/__pycache__/visualizer.cpython-39.pyc,, +tokenizers/tools/visualizer-styles.css,sha256=g98RJucUCjYnPLwXycfcV-dAbHMzqJpnZ7NoP_cCUBk,5020 +tokenizers/tools/visualizer.py,sha256=Q2S8r3mP6Xnri7DcWbmWpgzp2aDd1qlisBSV4R1LxJE,15028 +tokenizers/trainers/__init__.py,sha256=v9CdtoVauwD6b1wAA6R5goSLZslBbPE5vZHM2d8sgKM,256 +tokenizers/trainers/__init__.pyi,sha256=13m2jGMVhn4dJyxkfgzR-JKvuBGvuj7fvgwJHlbnQUs,5538 +tokenizers/trainers/__pycache__/__init__.cpython-39.pyc,, diff --git a/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/WHEEL b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..3cf4a47cfedeb31afa1186ccc41a537b8dd3bbf4 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers-0.22.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: maturin (1.9.4) +Root-Is-Purelib: false +Tag: cp39-abi3-win_amd64 diff --git a/phivenv/Lib/site-packages/tokenizers/__init__.py b/phivenv/Lib/site-packages/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d22afda0e73fee769224a24e80abc07cc51f9a97 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/__init__.py @@ -0,0 +1,100 @@ +from enum import Enum +from typing import List, Tuple, Union + + +Offsets = Tuple[int, int] + +TextInputSequence = str +"""A :obj:`str` that represents an input sequence """ + +PreTokenizedInputSequence = Union[List[str], Tuple[str]] +"""A pre-tokenized input sequence. Can be one of: + + - A :obj:`List` of :obj:`str` + - A :obj:`Tuple` of :obj:`str` +""" + +TextEncodeInput = Union[ + TextInputSequence, + Tuple[TextInputSequence, TextInputSequence], + List[TextInputSequence], +] +"""Represents a textual input for encoding. Can be either: + + - A single sequence: :data:`~tokenizers.TextInputSequence` + - A pair of sequences: + + - A :obj:`Tuple` of :data:`~tokenizers.TextInputSequence` + - Or a :obj:`List` of :data:`~tokenizers.TextInputSequence` of size 2 +""" + +PreTokenizedEncodeInput = Union[ + PreTokenizedInputSequence, + Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence], + List[PreTokenizedInputSequence], +] +"""Represents a pre-tokenized input for encoding. 
Can be either: + + - A single sequence: :data:`~tokenizers.PreTokenizedInputSequence` + - A pair of sequences: + + - A :obj:`Tuple` of :data:`~tokenizers.PreTokenizedInputSequence` + - Or a :obj:`List` of :data:`~tokenizers.PreTokenizedInputSequence` of size 2 +""" + +InputSequence = Union[TextInputSequence, PreTokenizedInputSequence] +"""Represents all the possible types of input sequences for encoding. Can be: + + - When ``is_pretokenized=False``: :data:`~TextInputSequence` + - When ``is_pretokenized=True``: :data:`~PreTokenizedInputSequence` +""" + +EncodeInput = Union[TextEncodeInput, PreTokenizedEncodeInput] +"""Represents all the possible types of input for encoding. Can be: + + - When ``is_pretokenized=False``: :data:`~TextEncodeInput` + - When ``is_pretokenized=True``: :data:`~PreTokenizedEncodeInput` +""" + + +class OffsetReferential(Enum): + ORIGINAL = "original" + NORMALIZED = "normalized" + + +class OffsetType(Enum): + BYTE = "byte" + CHAR = "char" + + +class SplitDelimiterBehavior(Enum): + REMOVED = "removed" + ISOLATED = "isolated" + MERGED_WITH_PREVIOUS = "merged_with_previous" + MERGED_WITH_NEXT = "merged_with_next" + CONTIGUOUS = "contiguous" + + +from .tokenizers import ( + AddedToken, + Encoding, + NormalizedString, + PreTokenizedString, + Regex, + Token, + Tokenizer, + decoders, + models, + normalizers, + pre_tokenizers, + processors, + trainers, + __version__, +) +from .implementations import ( + BertWordPieceTokenizer, + ByteLevelBPETokenizer, + CharBPETokenizer, + SentencePieceBPETokenizer, + SentencePieceUnigramTokenizer, +) diff --git a/phivenv/Lib/site-packages/tokenizers/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c3bfd62a40e60237d4a71fbda2121b1af49d59c6 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/__init__.pyi @@ -0,0 +1,1362 @@ +# Generated content DO NOT EDIT +class AddedToken: + """ + Represents a token that can be be added to a :class:`~tokenizers.Tokenizer`. + It can have special options that defines the way it should behave. + + Args: + content (:obj:`str`): The content of the token + + single_word (:obj:`bool`, defaults to :obj:`False`): + Defines whether this token should only match single words. If :obj:`True`, this + token will never match inside of a word. For example the token ``ing`` would match + on ``tokenizing`` if this option is :obj:`False`, but not if it is :obj:`True`. + The notion of "`inside of a word`" is defined by the word boundaries pattern in + regular expressions (ie. the token should start and end with word boundaries). + + lstrip (:obj:`bool`, defaults to :obj:`False`): + Defines whether this token should strip all potential whitespaces on its left side. + If :obj:`True`, this token will greedily match any whitespace on its left. For + example if we try to match the token ``[MASK]`` with ``lstrip=True``, in the text + ``"I saw a [MASK]"``, we would match on ``" [MASK]"``. (Note the space on the left). + + rstrip (:obj:`bool`, defaults to :obj:`False`): + Defines whether this token should strip all potential whitespaces on its right + side. If :obj:`True`, this token will greedily match any whitespace on its right. + It works just like :obj:`lstrip` but on the right. + + normalized (:obj:`bool`, defaults to :obj:`True` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`False` with :meth:`~tokenizers.Tokenizer.add_special_tokens`): + Defines whether this token should match against the normalized version of the input + text. 
For example, with the added token ``"yesterday"``, and a normalizer in charge of + lowercasing the text, the token could be extract from the input ``"I saw a lion + Yesterday"``. + special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`False` with :meth:`~tokenizers.Tokenizer.add_special_tokens`): + Defines whether this token should be skipped when decoding. + + """ + def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False): + pass + + @property + def content(self): + """ + Get the content of this :obj:`AddedToken` + """ + pass + + @property + def lstrip(self): + """ + Get the value of the :obj:`lstrip` option + """ + pass + + @property + def normalized(self): + """ + Get the value of the :obj:`normalized` option + """ + pass + + @property + def rstrip(self): + """ + Get the value of the :obj:`rstrip` option + """ + pass + + @property + def single_word(self): + """ + Get the value of the :obj:`single_word` option + """ + pass + + @property + def special(self): + """ + Get the value of the :obj:`special` option + """ + pass + +class Encoding: + """ + The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`. + """ + @property + def attention_mask(self): + """ + The attention mask + + This indicates to the LM which tokens should be attended to, and which should not. + This is especially important when batching sequences, where we need to applying + padding. + + Returns: + :obj:`List[int]`: The attention mask + """ + pass + + def char_to_token(self, char_pos, sequence_index=0): + """ + Get the token that contains the char at the given position in the input sequence. + + Args: + char_pos (:obj:`int`): + The position of a char in the input string + sequence_index (:obj:`int`, defaults to :obj:`0`): + The index of the sequence that contains the target char + + Returns: + :obj:`int`: The index of the token that contains this char in the encoded sequence + """ + pass + + def char_to_word(self, char_pos, sequence_index=0): + """ + Get the word that contains the char at the given position in the input sequence. + + Args: + char_pos (:obj:`int`): + The position of a char in the input string + sequence_index (:obj:`int`, defaults to :obj:`0`): + The index of the sequence that contains the target char + + Returns: + :obj:`int`: The index of the word that contains this char in the input sequence + """ + pass + + @property + def ids(self): + """ + The generated IDs + + The IDs are the main input to a Language Model. They are the token indices, + the numerical representations that a LM understands. 
+ + Returns: + :obj:`List[int]`: The list of IDs + """ + pass + + @staticmethod + def merge(encodings, growing_offsets=True): + """ + Merge the list of encodings into one final :class:`~tokenizers.Encoding` + + Args: + encodings (A :obj:`List` of :class:`~tokenizers.Encoding`): + The list of encodings that should be merged in one + + growing_offsets (:obj:`bool`, defaults to :obj:`True`): + Whether the offsets should accumulate while merging + + Returns: + :class:`~tokenizers.Encoding`: The resulting Encoding + """ + pass + + @property + def n_sequences(self): + """ + The number of sequences represented + + Returns: + :obj:`int`: The number of sequences in this :class:`~tokenizers.Encoding` + """ + pass + + @property + def offsets(self): + """ + The offsets associated to each token + + These offsets let's you slice the input string, and thus retrieve the original + part that led to producing the corresponding token. + + Returns: + A :obj:`List` of :obj:`Tuple[int, int]`: The list of offsets + """ + pass + + @property + def overflowing(self): + """ + A :obj:`List` of overflowing :class:`~tokenizers.Encoding` + + When using truncation, the :class:`~tokenizers.Tokenizer` takes care of splitting + the output into as many pieces as required to match the specified maximum length. + This field lets you retrieve all the subsequent pieces. + + When you use pairs of sequences, the overflowing pieces will contain enough + variations to cover all the possible combinations, while respecting the provided + maximum length. + """ + pass + + def pad(self, length, direction="right", pad_id=0, pad_type_id=0, pad_token="[PAD]"): + """ + Pad the :class:`~tokenizers.Encoding` at the given length + + Args: + length (:obj:`int`): + The desired length + + direction: (:obj:`str`, defaults to :obj:`right`): + The expected padding direction. Can be either :obj:`right` or :obj:`left` + + pad_id (:obj:`int`, defaults to :obj:`0`): + The ID corresponding to the padding token + + pad_type_id (:obj:`int`, defaults to :obj:`0`): + The type ID corresponding to the padding token + + pad_token (:obj:`str`, defaults to `[PAD]`): + The pad token to use + """ + pass + + @property + def sequence_ids(self): + """ + The generated sequence indices. + + They represent the index of the input sequence associated to each token. + The sequence id can be None if the token is not related to any input sequence, + like for example with special tokens. + + Returns: + A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence index. + """ + pass + + def set_sequence_id(self, sequence_id): + """ + Set the given sequence index + + Set the given sequence index for the whole range of tokens contained in this + :class:`~tokenizers.Encoding`. + """ + pass + + @property + def special_tokens_mask(self): + """ + The special token mask + + This indicates which tokens are special tokens, and which are not. + + Returns: + :obj:`List[int]`: The special tokens mask + """ + pass + + def token_to_chars(self, token_index): + """ + Get the offsets of the token at the given index. + + The returned offsets are related to the input sequence that contains the + token. In order to determine in which input sequence it belongs, you + must call :meth:`~tokenizers.Encoding.token_to_sequence()`. + + Args: + token_index (:obj:`int`): + The index of a token in the encoded sequence. 
+ + Returns: + :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)` + """ + pass + + def token_to_sequence(self, token_index): + """ + Get the index of the sequence represented by the given token. + + In the general use case, this method returns :obj:`0` for a single sequence or + the first sequence of a pair, and :obj:`1` for the second sequence of a pair + + Args: + token_index (:obj:`int`): + The index of a token in the encoded sequence. + + Returns: + :obj:`int`: The sequence id of the given token + """ + pass + + def token_to_word(self, token_index): + """ + Get the index of the word that contains the token in one of the input sequences. + + The returned word index is related to the input sequence that contains + the token. In order to determine in which input sequence it belongs, you + must call :meth:`~tokenizers.Encoding.token_to_sequence()`. + + Args: + token_index (:obj:`int`): + The index of a token in the encoded sequence. + + Returns: + :obj:`int`: The index of the word in the relevant input sequence. + """ + pass + + @property + def tokens(self): + """ + The generated tokens + + They are the string representation of the IDs. + + Returns: + :obj:`List[str]`: The list of tokens + """ + pass + + def truncate(self, max_length, stride=0, direction="right"): + """ + Truncate the :class:`~tokenizers.Encoding` at the given length + + If this :class:`~tokenizers.Encoding` represents multiple sequences, when truncating + this information is lost. It will be considered as representing a single sequence. + + Args: + max_length (:obj:`int`): + The desired length + + stride (:obj:`int`, defaults to :obj:`0`): + The length of previous content to be included in each overflowing piece + + direction (:obj:`str`, defaults to :obj:`right`): + Truncate direction + """ + pass + + @property + def type_ids(self): + """ + The generated type IDs + + Generally used for tasks like sequence classification or question answering, + these tokens let the LM know which input sequence corresponds to each tokens. + + Returns: + :obj:`List[int]`: The list of type ids + """ + pass + + @property + def word_ids(self): + """ + The generated word indices. + + They represent the index of the word associated to each token. + When the input is pre-tokenized, they correspond to the ID of the given input label, + otherwise they correspond to the words indices as defined by the + :class:`~tokenizers.pre_tokenizers.PreTokenizer` that was used. + + For special tokens and such (any token that was generated from something that was + not part of the input), the output is :obj:`None` + + Returns: + A :obj:`List` of :obj:`Optional[int]`: A list of optional word index. + """ + pass + + def word_to_chars(self, word_index, sequence_index=0): + """ + Get the offsets of the word at the given index in one of the input sequences. + + Args: + word_index (:obj:`int`): + The index of a word in one of the input sequences. + sequence_index (:obj:`int`, defaults to :obj:`0`): + The index of the sequence that contains the target word + + Returns: + :obj:`Tuple[int, int]`: The range of characters (span) :obj:`(first, last + 1)` + """ + pass + + def word_to_tokens(self, word_index, sequence_index=0): + """ + Get the encoded tokens corresponding to the word at the given index + in one of the input sequences. + + Args: + word_index (:obj:`int`): + The index of a word in one of the input sequences. 
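Together, these accessors let you move between tokens, words and character spans. An illustrative sketch; the ``"bert-base-uncased"`` identifier is only an example and requires access to the Hugging Face Hub::

    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased")
    enc = tok.encode("Tokenizers are fast")

    print(enc.tokens)             # string pieces, including special tokens
    print(enc.ids)                # the corresponding vocabulary ids
    print(enc.offsets)            # character spans into the original string
    print(enc.word_ids)           # word index per token (None for special tokens)
    print(enc.token_to_chars(1))  # character span of the token at index 1
    print(enc.word_to_tokens(0))  # token range (first, last + 1) covering word 0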
+ sequence_index (:obj:`int`, defaults to :obj:`0`): + The index of the sequence that contains the target word + + Returns: + :obj:`Tuple[int, int]`: The range of tokens: :obj:`(first, last + 1)` + """ + pass + + @property + def words(self): + """ + The generated word indices. + + .. warning:: + This is deprecated and will be removed in a future version. + Please use :obj:`~tokenizers.Encoding.word_ids` instead. + + They represent the index of the word associated to each token. + When the input is pre-tokenized, they correspond to the ID of the given input label, + otherwise they correspond to the words indices as defined by the + :class:`~tokenizers.pre_tokenizers.PreTokenizer` that was used. + + For special tokens and such (any token that was generated from something that was + not part of the input), the output is :obj:`None` + + Returns: + A :obj:`List` of :obj:`Optional[int]`: A list of optional word index. + """ + pass + +class NormalizedString: + """ + NormalizedString + + A NormalizedString takes care of modifying an "original" string, to obtain a "normalized" one. + While making all the requested modifications, it keeps track of the alignment information + between the two versions of the string. + + Args: + sequence: str: + The string sequence used to initialize this NormalizedString + """ + def append(self, s): + """ + Append the given sequence to the string + """ + pass + + def clear(self): + """ + Clears the string + """ + pass + + def filter(self, func): + """ + Filter each character of the string using the given func + """ + pass + + def for_each(self, func): + """ + Calls the given function for each character of the string + """ + pass + + def lowercase(self): + """ + Lowercase the string + """ + pass + + def lstrip(self): + """ + Strip the left of the string + """ + pass + + def map(self, func): + """ + Calls the given function for each character of the string + + Replaces each character of the string using the returned value. Each + returned value **must** be a str of length 1 (ie a character). + """ + pass + + def nfc(self): + """ + Runs the NFC normalization + """ + pass + + def nfd(self): + """ + Runs the NFD normalization + """ + pass + + def nfkc(self): + """ + Runs the NFKC normalization + """ + pass + + def nfkd(self): + """ + Runs the NFKD normalization + """ + pass + + @property + def normalized(self): + """ + The normalized part of the string + """ + pass + + def prepend(self, s): + """ + Prepend the given sequence to the string + """ + pass + + def replace(self, pattern, content): + """ + Replace the content of the given pattern with the provided content + + Args: + pattern: Pattern: + A pattern used to match the string. Usually a string or a Regex + + content: str: + The content to be used as replacement + """ + pass + + def rstrip(self): + """ + Strip the right of the string + """ + pass + + def slice(self, range): + """ + Slice the string using the given range + """ + pass + + def split(self, pattern, behavior): + """ + Split the NormalizedString using the given pattern and the specified behavior + + Args: + pattern: Pattern: + A pattern used to split the string. Usually a string or a regex built with `tokenizers.Regex` + + behavior: SplitDelimiterBehavior: + The behavior to use when splitting. 
+ Choices: "removed", "isolated", "merged_with_previous", "merged_with_next", + "contiguous" + + Returns: + A list of NormalizedString, representing each split + """ + pass + + def strip(self): + """ + Strip both ends of the string + """ + pass + + def uppercase(self): + """ + Uppercase the string + """ + pass + +class PreTokenizedString: + """ + PreTokenizedString + + Wrapper over a string, that provides a way to normalize, pre-tokenize, tokenize the + underlying string, while keeping track of the alignment information (offsets). + + The PreTokenizedString manages what we call `splits`. Each split represents a substring + which is a subpart of the original string, with the relevant offsets and tokens. + + When calling one of the methods used to modify the PreTokenizedString (namely one of + `split`, `normalize` or `tokenize), only the `splits` that don't have any associated + tokens will get modified. + + Args: + sequence: str: + The string sequence used to initialize this PreTokenizedString + """ + def __init__(self, sequence): + pass + + def get_splits(self, offset_referential="original", offset_type="char"): + """ + Get the splits currently managed by the PreTokenizedString + + Args: + offset_referential: :obj:`str` + Whether the returned splits should have offsets expressed relative + to the original string, or the normalized one. choices: "original", "normalized". + + offset_type: :obj:`str` + Whether the returned splits should have offsets expressed in bytes or chars. + When slicing an str, we usually want to use chars, which is the default value. + Now in some cases it might be interesting to get these offsets expressed in bytes, + so it is possible to change this here. + choices: "char", "bytes" + + Returns + A list of splits + """ + pass + + def normalize(self, func): + """ + Normalize each split of the `PreTokenizedString` using the given `func` + + Args: + func: Callable[[NormalizedString], None]: + The function used to normalize each underlying split. This function + does not need to return anything, just calling the methods on the provided + NormalizedString allow its modification. + """ + pass + + def split(self, func): + """ + Split the PreTokenizedString using the given `func` + + Args: + func: Callable[[index, NormalizedString], List[NormalizedString]]: + The function used to split each underlying split. + It is expected to return a list of `NormalizedString`, that represent the new + splits. If the given `NormalizedString` does not need any splitting, we can + just return it directly. + In order for the offsets to be tracked accurately, any returned `NormalizedString` + should come from calling either `.split` or `.slice` on the received one. + """ + pass + + def to_encoding(self, type_id=0, word_idx=None): + """ + Return an Encoding generated from this PreTokenizedString + + Args: + type_id: int = 0: + The type_id to be used on the generated Encoding. + + word_idx: Optional[int] = None: + An optional word index to be used for each token of this Encoding. If provided, + all the word indices in the generated Encoding will use this value, instead + of the one automatically tracked during pre-tokenization. + + Returns: + An Encoding + """ + pass + + def tokenize(self, func): + """ + Tokenize each split of the `PreTokenizedString` using the given `func` + + Args: + func: Callable[[str], List[Token]]: + The function used to tokenize each underlying split. This function must return + a list of Token generated from the input str. 
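Both classes can also be driven directly from Python, which is mostly useful when writing custom normalizers or pre-tokenizers. A small sketch of the mutating API described above::

    from tokenizers import NormalizedString, PreTokenizedString

    ns = NormalizedString("  Héllo World  ")
    ns.nfkc()        # Unicode normalization
    ns.lowercase()   # in-place lowercasing
    ns.strip()       # strip whitespace on both ends
    print(ns.normalized)

    pts = PreTokenizedString("Some RAW text")
    pts.normalize(lambda split: split.lowercase())  # normalize each split in place
    print(pts.get_splits(offset_referential="original", offset_type="char"))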
+ """ + pass + +class Regex: + """ + Instantiate a new Regex with the given pattern + """ + def __init__(self, pattern): + pass + +class Token: + pass + +class Tokenizer: + """ + A :obj:`Tokenizer` works as a pipeline. It processes some raw text as input + and outputs an :class:`~tokenizers.Encoding`. + + Args: + model (:class:`~tokenizers.models.Model`): + The core algorithm that this :obj:`Tokenizer` should be using. + + """ + def __init__(self, model): + pass + + def add_special_tokens(self, tokens): + """ + Add the given special tokens to the Tokenizer. + + If these tokens are already part of the vocabulary, it just let the Tokenizer know about + them. If they don't exist, the Tokenizer creates them, giving them a new id. + + These special tokens will never be processed by the model (ie won't be split into + multiple tokens), and they can be removed from the output when decoding. + + Args: + tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`): + The list of special tokens we want to add to the vocabulary. Each token can either + be a string or an instance of :class:`~tokenizers.AddedToken` for more + customization. + + Returns: + :obj:`int`: The number of tokens that were created in the vocabulary + """ + pass + + def add_tokens(self, tokens): + """ + Add the given tokens to the vocabulary + + The given tokens are added only if they don't already exist in the vocabulary. + Each token then gets a new attributed id. + + Args: + tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`): + The list of tokens we want to add to the vocabulary. Each token can be either a + string or an instance of :class:`~tokenizers.AddedToken` for more customization. + + Returns: + :obj:`int`: The number of tokens that were created in the vocabulary + """ + pass + + def async_decode_batch(self, sequences, skip_special_tokens=True): + """ + Decode a batch of ids back to their corresponding string + + Args: + sequences (:obj:`List` of :obj:`List[int]`): + The batch of sequences we want to decode + + skip_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether the special tokens should be removed from the decoded strings + + Returns: + :obj:`List[str]`: A list of decoded strings + """ + pass + + def async_encode(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True): + """ + Asynchronously encode the given input with character offsets. + + This is an async version of encode that can be awaited in async Python code. + + Example: + Here are some examples of the inputs that are accepted:: + + await async_encode("A single sequence") + + Args: + sequence (:obj:`~tokenizers.InputSequence`): + The main input sequence we want to encode. This sequence can be either raw + text or pre-tokenized, according to the ``is_pretokenized`` argument: + + - If ``is_pretokenized=False``: :class:`~tokenizers.TextInputSequence` + - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedInputSequence` + + pair (:obj:`~tokenizers.InputSequence`, `optional`): + An optional input sequence. The expected format is the same that for ``sequence``. 
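A typical way to obtain a working ``Tokenizer`` is to pick a model, a pre-tokenizer and a trainer, train on a corpus, then register the special tokens. A minimal sketch; the tiny in-memory corpus is obviously just for illustration::

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import BpeTrainer

    tok = Tokenizer(BPE(unk_token="[UNK]"))
    tok.pre_tokenizer = Whitespace()

    trainer = BpeTrainer(vocab_size=200, special_tokens=["[UNK]", "[PAD]"])
    tok.train_from_iterator(["a tiny corpus", "just for illustration"], trainer=trainer)

    tok.add_special_tokens(["[CLS]", "[SEP]"])  # never split by the model
    print(tok.get_vocab_size())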
+ + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Whether the input is already pre-tokenized + + add_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to add the special tokens + + Returns: + :class:`~tokenizers.Encoding`: The encoded result + + """ + pass + + def async_encode_batch(self, input, is_pretokenized=False, add_special_tokens=True): + """ + Asynchronously encode the given batch of inputs with character offsets. + + This is an async version of encode_batch that can be awaited in async Python code. + + Example: + Here are some examples of the inputs that are accepted:: + + await async_encode_batch([ + "A single sequence", + ("A tuple with a sequence", "And its pair"), + [ "A", "pre", "tokenized", "sequence" ], + ([ "A", "pre", "tokenized", "sequence" ], "And its pair") + ]) + + Args: + input (A :obj:`List`/:obj:`Tuple` of :obj:`~tokenizers.EncodeInput`): + A list of single sequences or pair sequences to encode. Each sequence + can be either raw text or pre-tokenized, according to the ``is_pretokenized`` + argument: + + - If ``is_pretokenized=False``: :class:`~tokenizers.TextEncodeInput` + - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedEncodeInput` + + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Whether the input is already pre-tokenized + + add_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to add the special tokens + + Returns: + A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch + + """ + pass + + def async_encode_batch_fast(self, input, is_pretokenized=False, add_special_tokens=True): + """ + Asynchronously encode the given batch of inputs without tracking character offsets. + + This is an async version of encode_batch_fast that can be awaited in async Python code. + + Example: + Here are some examples of the inputs that are accepted:: + + await async_encode_batch_fast([ + "A single sequence", + ("A tuple with a sequence", "And its pair"), + [ "A", "pre", "tokenized", "sequence" ], + ([ "A", "pre", "tokenized", "sequence" ], "And its pair") + ]) + + Args: + input (A :obj:`List`/:obj:`Tuple` of :obj:`~tokenizers.EncodeInput`): + A list of single sequences or pair sequences to encode. 
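The ``async_*`` methods are coroutines and must be awaited from async code; they are exposed by this build of the bindings, as the stubs above describe. A sketch, again using ``"bert-base-uncased"`` only as an example identifier that needs Hub access::

    import asyncio
    from tokenizers import Tokenizer

    async def main():
        tok = Tokenizer.from_pretrained("bert-base-uncased")
        enc = await tok.async_encode("A single sequence")
        batch = await tok.async_encode_batch(
            ["A single sequence", ("A tuple with a sequence", "And its pair")]
        )
        print(enc.tokens, len(batch))

    asyncio.run(main())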
Each sequence + can be either raw text or pre-tokenized, according to the ``is_pretokenized`` + argument: + + - If ``is_pretokenized=False``: :class:`~tokenizers.TextEncodeInput` + - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedEncodeInput` + + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Whether the input is already pre-tokenized + + add_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to add the special tokens + + Returns: + A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch + + """ + pass + + def decode(self, ids, skip_special_tokens=True): + """ + Decode the given list of ids back to a string + + This is used to decode anything coming back from a Language Model + + Args: + ids (A :obj:`List/Tuple` of :obj:`int`): + The list of ids that we want to decode + + skip_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether the special tokens should be removed from the decoded string + + Returns: + :obj:`str`: The decoded string + """ + pass + + def decode_batch(self, sequences, skip_special_tokens=True): + """ + Decode a batch of ids back to their corresponding string + + Args: + sequences (:obj:`List` of :obj:`List[int]`): + The batch of sequences we want to decode + + skip_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether the special tokens should be removed from the decoded strings + + Returns: + :obj:`List[str]`: A list of decoded strings + """ + pass + + @property + def decoder(self): + """ + The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer + """ + pass + + def enable_padding( + self, direction="right", pad_id=0, pad_type_id=0, pad_token="[PAD]", length=None, pad_to_multiple_of=None + ): + """ + Enable the padding + + Args: + direction (:obj:`str`, `optional`, defaults to :obj:`right`): + The direction in which to pad. Can be either ``right`` or ``left`` + + pad_to_multiple_of (:obj:`int`, `optional`): + If specified, the padding length should always snap to the next multiple of the + given value. For example if we were going to pad witha length of 250 but + ``pad_to_multiple_of=8`` then we will pad to 256. + + pad_id (:obj:`int`, defaults to 0): + The id to be used when padding + + pad_type_id (:obj:`int`, defaults to 0): + The type id to be used when padding + + pad_token (:obj:`str`, defaults to :obj:`[PAD]`): + The pad token to be used when padding + + length (:obj:`int`, `optional`): + If specified, the length at which to pad. If not specified we pad using the size of + the longest sequence in a batch. + """ + pass + + def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"): + """ + Enable truncation + + Args: + max_length (:obj:`int`): + The max length at which to truncate + + stride (:obj:`int`, `optional`): + The length of the previous first sequence to be included in the overflowing + sequence + + strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`): + The strategy used to truncation. Can be one of ``longest_first``, ``only_first`` or + ``only_second``. + + direction (:obj:`str`, defaults to :obj:`right`): + Truncate direction + """ + pass + + def encode(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True): + """ + Encode the given sequence and pair. This method can process raw text sequences + as well as already pre-tokenized sequences. 
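Padding and truncation are global settings on the tokenizer rather than per-call arguments, so they affect every subsequent ``encode``/``encode_batch``. A self-contained sketch with a toy word-level vocabulary::

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    vocab = {"[UNK]": 0, "[PAD]": 1, "short": 2, "and": 3, "longer": 4, "inputs": 5}
    tok = Tokenizer(WordLevel(vocab, unk_token="[UNK]"))
    tok.pre_tokenizer = Whitespace()

    tok.enable_truncation(max_length=4, stride=1)
    tok.enable_padding(pad_id=1, pad_token="[PAD]", pad_to_multiple_of=4)

    for enc in tok.encode_batch(["short", "and longer inputs than that"]):
        print(enc.tokens, enc.attention_mask)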
+ + Example: + Here are some examples of the inputs that are accepted:: + + encode("A single sequence")` + encode("A sequence", "And its pair")` + encode([ "A", "pre", "tokenized", "sequence" ], is_pretokenized=True)` + encode( + [ "A", "pre", "tokenized", "sequence" ], [ "And", "its", "pair" ], + is_pretokenized=True + ) + + Args: + sequence (:obj:`~tokenizers.InputSequence`): + The main input sequence we want to encode. This sequence can be either raw + text or pre-tokenized, according to the ``is_pretokenized`` argument: + + - If ``is_pretokenized=False``: :class:`~tokenizers.TextInputSequence` + - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedInputSequence` + + pair (:obj:`~tokenizers.InputSequence`, `optional`): + An optional input sequence. The expected format is the same that for ``sequence``. + + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Whether the input is already pre-tokenized + + add_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to add the special tokens + + Returns: + :class:`~tokenizers.Encoding`: The encoded result + + """ + pass + + def encode_batch(self, input, is_pretokenized=False, add_special_tokens=True): + """ + Encode the given batch of inputs. This method accept both raw text sequences + as well as already pre-tokenized sequences. The reason we use `PySequence` is + because it allows type checking with zero-cost (according to PyO3) as we don't + have to convert to check. + + Example: + Here are some examples of the inputs that are accepted:: + + encode_batch([ + "A single sequence", + ("A tuple with a sequence", "And its pair"), + [ "A", "pre", "tokenized", "sequence" ], + ([ "A", "pre", "tokenized", "sequence" ], "And its pair") + ]) + + Args: + input (A :obj:`List`/:obj:`Tuple` of :obj:`~tokenizers.EncodeInput`): + A list of single sequences or pair sequences to encode. Each sequence + can be either raw text or pre-tokenized, according to the ``is_pretokenized`` + argument: + + - If ``is_pretokenized=False``: :class:`~tokenizers.TextEncodeInput` + - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedEncodeInput` + + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Whether the input is already pre-tokenized + + add_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to add the special tokens + + Returns: + A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch + + """ + pass + + def encode_batch_fast(self, input, is_pretokenized=False, add_special_tokens=True): + """ + Encode the given batch of inputs. This method is faster than `encode_batch` + because it doesn't keep track of offsets, they will be all zeros. + + Example: + Here are some examples of the inputs that are accepted:: + + encode_batch_fast([ + "A single sequence", + ("A tuple with a sequence", "And its pair"), + [ "A", "pre", "tokenized", "sequence" ], + ([ "A", "pre", "tokenized", "sequence" ], "And its pair") + ]) + + Args: + input (A :obj:`List`/:obj:`Tuple` of :obj:`~tokenizers.EncodeInput`): + A list of single sequences or pair sequences to encode. 
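The accepted input shapes (single, pair, pre-tokenized, pre-tokenized pair) map directly onto the ``sequence``/``pair``/``is_pretokenized`` arguments. A sketch with a toy vocabulary, where unknown words simply become ``[UNK]``::

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    tok = Tokenizer(WordLevel({"[UNK]": 0, "a": 1, "pair": 2}, unk_token="[UNK]"))
    tok.pre_tokenizer = Whitespace()

    single = tok.encode("a single sequence")
    pair = tok.encode("a sequence", "and its pair")
    pretok = tok.encode(["a", "pre", "tokenized", "sequence"], is_pretokenized=True)

    print(pair.type_ids)    # the second sequence gets a different type id
    print(pretok.word_ids)  # follows the pre-tokenization that was given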
Each sequence + can be either raw text or pre-tokenized, according to the ``is_pretokenized`` + argument: + + - If ``is_pretokenized=False``: :class:`~tokenizers.TextEncodeInput` + - If ``is_pretokenized=True``: :class:`~tokenizers.PreTokenizedEncodeInput` + + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Whether the input is already pre-tokenized + + add_special_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to add the special tokens + + Returns: + A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch + + """ + pass + + @property + def encode_special_tokens(self): + """ + Modifies the tokenizer in order to use or not the special tokens + during encoding. + + Args: + value (:obj:`bool`): + Whether to use the special tokens or not + + """ + pass + + @staticmethod + def from_buffer(buffer): + """ + Instantiate a new :class:`~tokenizers.Tokenizer` from the given buffer. + + Args: + buffer (:obj:`bytes`): + A buffer containing a previously serialized :class:`~tokenizers.Tokenizer` + + Returns: + :class:`~tokenizers.Tokenizer`: The new tokenizer + """ + pass + + @staticmethod + def from_file(path): + """ + Instantiate a new :class:`~tokenizers.Tokenizer` from the file at the given path. + + Args: + path (:obj:`str`): + A path to a local JSON file representing a previously serialized + :class:`~tokenizers.Tokenizer` + + Returns: + :class:`~tokenizers.Tokenizer`: The new tokenizer + """ + pass + + @staticmethod + def from_pretrained(identifier, revision="main", token=None): + """ + Instantiate a new :class:`~tokenizers.Tokenizer` from an existing file on the + Hugging Face Hub. + + Args: + identifier (:obj:`str`): + The identifier of a Model on the Hugging Face Hub, that contains + a tokenizer.json file + revision (:obj:`str`, defaults to `main`): + A branch or commit id + token (:obj:`str`, `optional`, defaults to `None`): + An optional auth token used to access private repositories on the + Hugging Face Hub + + Returns: + :class:`~tokenizers.Tokenizer`: The new tokenizer + """ + pass + + @staticmethod + def from_str(json): + """ + Instantiate a new :class:`~tokenizers.Tokenizer` from the given JSON string. 
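All of these constructors round-trip through the same JSON serialization. A sketch; the Hub identifier is only an example and needs network access::

    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased", revision="main")

    as_json = tok.to_str()                              # serialize to a JSON string
    same = Tokenizer.from_str(as_json)                  # ...and back
    also_same = Tokenizer.from_buffer(as_json.encode("utf-8"))
    print(same.get_vocab_size() == also_same.get_vocab_size())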
+ + Args: + json (:obj:`str`): + A valid JSON string representing a previously serialized + :class:`~tokenizers.Tokenizer` + + Returns: + :class:`~tokenizers.Tokenizer`: The new tokenizer + """ + pass + + def get_added_tokens_decoder(self): + """ + Get the underlying vocabulary + + Returns: + :obj:`Dict[int, AddedToken]`: The vocabulary + """ + pass + + def get_vocab(self, with_added_tokens=True): + """ + Get the underlying vocabulary + + Args: + with_added_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to include the added tokens + + Returns: + :obj:`Dict[str, int]`: The vocabulary + """ + pass + + def get_vocab_size(self, with_added_tokens=True): + """ + Get the size of the underlying vocabulary + + Args: + with_added_tokens (:obj:`bool`, defaults to :obj:`True`): + Whether to include the added tokens + + Returns: + :obj:`int`: The size of the vocabulary + """ + pass + + def id_to_token(self, id): + """ + Convert the given id to its corresponding token if it exists + + Args: + id (:obj:`int`): + The id to convert + + Returns: + :obj:`Optional[str]`: An optional token, :obj:`None` if out of vocabulary + """ + pass + + @property + def model(self): + """ + The :class:`~tokenizers.models.Model` in use by the Tokenizer + """ + pass + + def no_padding(self): + """ + Disable padding + """ + pass + + def no_truncation(self): + """ + Disable truncation + """ + pass + + @property + def normalizer(self): + """ + The `optional` :class:`~tokenizers.normalizers.Normalizer` in use by the Tokenizer + """ + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + :param is_pair: Boolean indicating if the input would be a single sentence or a pair + :return: + """ + pass + + @property + def padding(self): + """ + Get the current padding parameters + + `Cannot be set, use` :meth:`~tokenizers.Tokenizer.enable_padding` `instead` + + Returns: + (:obj:`dict`, `optional`): + A dict with the current padding parameters if padding is enabled + """ + pass + + def post_process(self, encoding, pair=None, add_special_tokens=True): + """ + Apply all the post-processing steps to the given encodings. + + The various steps are: + + 1. Truncate according to the set truncation params (provided with + :meth:`~tokenizers.Tokenizer.enable_truncation`) + 2. Apply the :class:`~tokenizers.processors.PostProcessor` + 3. Pad according to the set padding params (provided with + :meth:`~tokenizers.Tokenizer.enable_padding`) + + Args: + encoding (:class:`~tokenizers.Encoding`): + The :class:`~tokenizers.Encoding` corresponding to the main sequence. + + pair (:class:`~tokenizers.Encoding`, `optional`): + An optional :class:`~tokenizers.Encoding` corresponding to the pair sequence. + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Returns: + :class:`~tokenizers.Encoding`: The final post-processed encoding + """ + pass + + @property + def post_processor(self): + """ + The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer + """ + pass + + @property + def pre_tokenizer(self): + """ + The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer + """ + pass + + def save(self, path, pretty=True): + """ + Save the :class:`~tokenizers.Tokenizer` to the file at the given path. + + Args: + path (:obj:`str`): + A path to a file in which to save the serialized tokenizer. 
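Saving to disk and reloading with ``from_file`` preserves the whole pipeline, and the id/token lookup helpers work on the reloaded instance. A self-contained sketch with a toy vocabulary::

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    tok = Tokenizer(WordLevel({"[UNK]": 0, "hello": 1, "world": 2}, unk_token="[UNK]"))
    tok.pre_tokenizer = Whitespace()

    tok.save("toy-tokenizer.json", pretty=True)
    same = Tokenizer.from_file("toy-tokenizer.json")
    print(same.get_vocab_size(), same.token_to_id("hello"), same.id_to_token(2))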
+ + pretty (:obj:`bool`, defaults to :obj:`True`): + Whether the JSON file should be pretty formatted. + """ + pass + + def to_str(self, pretty=False): + """ + Gets a serialized string representing this :class:`~tokenizers.Tokenizer`. + + Args: + pretty (:obj:`bool`, defaults to :obj:`False`): + Whether the JSON string should be pretty formatted. + + Returns: + :obj:`str`: A string representing the serialized Tokenizer + """ + pass + + def token_to_id(self, token): + """ + Convert the given token to its corresponding id if it exists + + Args: + token (:obj:`str`): + The token to convert + + Returns: + :obj:`Optional[int]`: An optional id, :obj:`None` if out of vocabulary + """ + pass + + def train(self, files, trainer=None): + """ + Train the Tokenizer using the given files. + + Reads the files line by line, while keeping all the whitespace, even new lines. + If you want to train from data store in-memory, you can check + :meth:`~tokenizers.Tokenizer.train_from_iterator` + + Args: + files (:obj:`List[str]`): + A list of path to the files that we should use for training + + trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`): + An optional trainer that should be used to train our Model + """ + pass + + def train_from_iterator(self, iterator, trainer=None, length=None): + """ + Train the Tokenizer using the provided iterator. + + You can provide anything that is a Python Iterator + + * A list of sequences :obj:`List[str]` + * A generator that yields :obj:`str` or :obj:`List[str]` + * A Numpy array of strings + * ... + + Args: + iterator (:obj:`Iterator`): + Any iterator over strings or list of strings + + trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`): + An optional trainer that should be used to train our Model + + length (:obj:`int`, `optional`): + The total number of sequences in the iterator. This is used to + provide meaningful progress tracking + """ + pass + + @property + def truncation(self): + """ + Get the currently set truncation parameters + + `Cannot set, use` :meth:`~tokenizers.Tokenizer.enable_truncation` `instead` + + Returns: + (:obj:`dict`, `optional`): + A dict with the current truncation parameters if truncation is enabled + """ + pass diff --git a/phivenv/Lib/site-packages/tokenizers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22f3d2d44fc03b1ebedb248ef238c3068b32f605 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/decoders/__init__.py b/phivenv/Lib/site-packages/tokenizers/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf42c8857f49955d727da350aab2599a12ac2873 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/decoders/__init__.py @@ -0,0 +1,15 @@ +from .. 
import decoders + + +Decoder = decoders.Decoder +ByteLevel = decoders.ByteLevel +Replace = decoders.Replace +WordPiece = decoders.WordPiece +ByteFallback = decoders.ByteFallback +Fuse = decoders.Fuse +Strip = decoders.Strip +Metaspace = decoders.Metaspace +BPEDecoder = decoders.BPEDecoder +CTC = decoders.CTC +Sequence = decoders.Sequence +DecodeStream = decoders.DecodeStream diff --git a/phivenv/Lib/site-packages/tokenizers/decoders/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/decoders/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..36897bf7596db4a5b34d93752e8aefa62a7bc159 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/decoders/__init__.pyi @@ -0,0 +1,279 @@ +# Generated content DO NOT EDIT +class DecodeStream: + """ + Class needed for streaming decode + + """ + def __init__(self, ids=None, skip_special_tokens=False): + pass + +class Decoder: + """ + Base class for all decoders + + This class is not supposed to be instantiated directly. Instead, any implementation of + a Decoder will return an instance of this class when instantiated. + """ + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class BPEDecoder(Decoder): + """ + BPEDecoder Decoder + + Args: + suffix (:obj:`str`, `optional`, defaults to :obj:``): + The suffix that was used to characterize an end-of-word. This suffix will + be replaced by whitespaces during the decoding + """ + def __init__(self, suffix=""): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class ByteFallback(Decoder): + """ + ByteFallback Decoder + ByteFallback is a simple trick which converts tokens looking like `<0x61>` + to pure bytes, and attempts to make them into a string. If the tokens + cannot be decoded you will get � instead for each inconvertible byte token + + """ + def __init__(self): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class ByteLevel(Decoder): + """ + ByteLevel Decoder + + This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.ByteLevel` + :class:`~tokenizers.pre_tokenizers.PreTokenizer`. + """ + def __init__(self): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class CTC(Decoder): + """ + CTC Decoder + + Args: + pad_token (:obj:`str`, `optional`, defaults to :obj:``): + The pad token used by CTC to delimit a new token. + word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`): + The word delimiter token. It will be replaced by a + cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to cleanup some tokenization artifacts. + Mainly spaces before punctuation, and some abbreviated english forms. 
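Decoders can also be used on their own, turning a list of token strings back into text. A sketch of the two byte-oriented decoders described above; the expected outputs in the comments are assumptions, not guarantees::

    from tokenizers import decoders

    # ByteFallback turns tokens such as "<0x61>" back into raw bytes where possible
    print(decoders.ByteFallback().decode(["<0x61>", "<0x62>", "c"]))  # "abc"

    # ByteLevel undoes byte-level pre-tokenization ("Ġ" marks a leading space)
    print(decoders.ByteLevel().decode(["Hello", "Ġworld"]))           # "Hello world"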
+ """ + def __init__(self, pad_token="", word_delimiter_token="|", cleanup=True): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Fuse(Decoder): + """ + Fuse Decoder + Fuse simply fuses every token into a single string. + This is the last step of decoding, this decoder exists only if + there is need to add other decoders *after* the fusion + """ + def __init__(self): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Metaspace(Decoder): + """ + Metaspace Decoder + + Args: + replacement (:obj:`str`, `optional`, defaults to :obj:`▁`): + The replacement character. Must be exactly one character. By default we + use the `▁` (U+2581) meta symbol (Same as in SentencePiece). + + prepend_scheme (:obj:`str`, `optional`, defaults to :obj:`"always"`): + Whether to add a space to the first word if there isn't already one. This + lets us treat `hello` exactly like `say hello`. + Choices: "always", "never", "first". First means the space is only added on the first + token (relevant when special tokens are used or other pre_tokenizer are used). + """ + def __init__(self, replacement="▁", prepend_scheme="always", split=True): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Replace(Decoder): + """ + Replace Decoder + + This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace` + :class:`~tokenizers.pre_tokenizers.PreTokenizer`. + """ + def __init__(self, pattern, content): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Sequence(Decoder): + """ + Sequence Decoder + + Args: + decoders (:obj:`List[Decoder]`) + The decoders that need to be chained + """ + def __init__(self, decoders): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Strip(Decoder): + """ + Strip normalizer + Strips n left characters of each token, or n right characters of each token + """ + def __init__(self, content, left=0, right=0): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class WordPiece(Decoder): + """ + WordPiece Decoder + + Args: + prefix (:obj:`str`, `optional`, defaults to :obj:`##`): + The prefix to use for subwords that are not a beginning-of-word + + cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to cleanup some tokenization artifacts. Mainly spaces before punctuation, + and some abbreviated english forms. 
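The subword-oriented decoders follow the same ``decode(tokens) -> str`` contract. A sketch for the Metaspace and WordPiece decoders described above::

    from tokenizers import decoders

    # Metaspace maps the "▁" marker back to spaces (SentencePiece-style vocabularies)
    metaspace = decoders.Metaspace(replacement="▁", prepend_scheme="always")
    print(metaspace.decode(["▁Hello", "▁world"]))          # "Hello world"

    # WordPiece merges "##" continuation pieces back into full words
    wordpiece = decoders.WordPiece(prefix="##", cleanup=True)
    print(wordpiece.decode(["un", "##believ", "##able"]))  # "unbelievable"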
+ """ + def __init__(self, prefix="##", cleanup=True): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass diff --git a/phivenv/Lib/site-packages/tokenizers/decoders/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/decoders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7a41099fe655c5ab46dd3869bfe542c4ae0408f Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/decoders/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__init__.py b/phivenv/Lib/site-packages/tokenizers/implementations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..710505ed6627e3322ba06ad8547769b0934ff537 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/__init__.py @@ -0,0 +1,6 @@ +from .base_tokenizer import BaseTokenizer +from .bert_wordpiece import BertWordPieceTokenizer +from .byte_level_bpe import ByteLevelBPETokenizer +from .char_level_bpe import CharBPETokenizer +from .sentencepiece_bpe import SentencePieceBPETokenizer +from .sentencepiece_unigram import SentencePieceUnigramTokenizer diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09a4a96916b11b944b36720a3d28e31ded14e0ca Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/base_tokenizer.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/base_tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..750077ab5e7854ab8cd9a826d5b4f8b0d5138792 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/base_tokenizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/bert_wordpiece.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/bert_wordpiece.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c17203f78968ba123a594fa80d62edc1b7492e3c Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/bert_wordpiece.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/byte_level_bpe.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/byte_level_bpe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..befb84f3b26809e41e28fde7edc1aff86493a757 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/byte_level_bpe.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/char_level_bpe.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/char_level_bpe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f0c19a50e9cb0d4993f75b2bc66a47d4894210 Binary files /dev/null and 
b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/char_level_bpe.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..139e082625bddee2312111340efc76ebb3fcd287 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..799798dbf974f2e24b9147842bff31408906d9a4 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/base_tokenizer.py b/phivenv/Lib/site-packages/tokenizers/implementations/base_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e249b4a735417f27626b9c892d50cb103c5e57f1 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/base_tokenizer.py @@ -0,0 +1,459 @@ +from typing import Dict, List, Optional, Tuple, Union + +from tokenizers import AddedToken, EncodeInput, Encoding, InputSequence, Tokenizer +from tokenizers.decoders import Decoder +from tokenizers.models import Model +from tokenizers.normalizers import Normalizer +from tokenizers.pre_tokenizers import PreTokenizer +from tokenizers.processors import PostProcessor + + +Offsets = Tuple[int, int] + + +class BaseTokenizer: + def __init__(self, tokenizer: Tokenizer, parameters=None): + self._tokenizer = tokenizer + self._parameters = parameters if parameters is not None else {} + + def __repr__(self): + return "Tokenizer(vocabulary_size={}, {})".format( + self._tokenizer.get_vocab_size(), + ", ".join(k + "=" + str(v) for k, v in self._parameters.items()), + ) + + def num_special_tokens_to_add(self, is_pair: bool) -> int: + """ + Return the number of special tokens that would be added for single/pair sentences. + :param is_pair: Boolean indicating if the input would be a single sentence or a pair + :return: + """ + return self._tokenizer.num_special_tokens_to_add(is_pair) + + def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]: + """Returns the vocabulary + + Args: + with_added_tokens: boolean: + Whether to include the added tokens in the vocabulary + + Returns: + The vocabulary + """ + return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens) + + def get_added_tokens_decoder(self) -> Dict[int, AddedToken]: + """Returns the added reverse vocabulary + + Returns: + The added vocabulary mapping ints to AddedTokens + """ + return self._tokenizer.get_added_tokens_decoder() + + def get_vocab_size(self, with_added_tokens: bool = True) -> int: + """Return the size of vocabulary, with or without added tokens. 
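``BaseTokenizer`` is a thin convenience wrapper: it stores the wrapped ``Tokenizer`` plus a parameters dict used only by ``__repr__``, and every method simply delegates. A sketch; the parameters shown are arbitrary::

    from tokenizers import Tokenizer
    from tokenizers.implementations import BaseTokenizer
    from tokenizers.models import BPE

    wrapped = BaseTokenizer(Tokenizer(BPE()), parameters={"model": "BPE", "lowercase": False})
    print(wrapped)                   # roughly: Tokenizer(vocabulary_size=0, model=BPE, lowercase=False)
    print(wrapped.get_vocab_size())  # delegates to the wrapped Tokenizer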
+ + Args: + with_added_tokens: (`optional`) bool: + Whether to count in added special tokens or not + + Returns: + Size of vocabulary + """ + return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens) + + def enable_padding( + self, + direction: Optional[str] = "right", + pad_to_multiple_of: Optional[int] = None, + pad_id: Optional[int] = 0, + pad_type_id: Optional[int] = 0, + pad_token: Optional[str] = "[PAD]", + length: Optional[int] = None, + ): + """Change the padding strategy + + Args: + direction: (`optional`) str: + Can be one of: `right` or `left` + + pad_to_multiple_of: (`optional`) unsigned int: + If specified, the padding length should always snap to the next multiple of + the given value. For example if we were going to pad with a length of 250 but + `pad_to_multiple_of=8` then we will pad to 256. + + pad_id: (`optional`) unsigned int: + The indice to be used when padding + + pad_type_id: (`optional`) unsigned int: + The type indice to be used when padding + + pad_token: (`optional`) str: + The pad token to be used when padding + + length: (`optional`) unsigned int: + If specified, the length at which to pad. If not specified + we pad using the size of the longest sequence in a batch + """ + return self._tokenizer.enable_padding( + direction=direction, + pad_to_multiple_of=pad_to_multiple_of, + pad_id=pad_id, + pad_type_id=pad_type_id, + pad_token=pad_token, + length=length, + ) + + def no_padding(self): + """Disable padding""" + return self._tokenizer.no_padding() + + @property + def padding(self) -> Optional[dict]: + """Get the current padding parameters + + Returns: + None if padding is disabled, a dict with the currently set parameters + if the padding is enabled. + """ + return self._tokenizer.padding + + def enable_truncation(self, max_length: int, stride: Optional[int] = 0, strategy: Optional[str] = "longest_first"): + """Change the truncation options + + Args: + max_length: unsigned int: + The maximum length at which to truncate + + stride: (`optional`) unsigned int: + The length of the previous first sequence to be included + in the overflowing sequence + + strategy: (`optional`) str: + Can be one of `longest_first`, `only_first` or `only_second` + """ + return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy) + + def no_truncation(self): + """Disable truncation""" + return self._tokenizer.no_truncation() + + @property + def truncation(self) -> Optional[dict]: + """Get the current truncation parameters + + Returns: + None if truncation is disabled, a dict with the current truncation parameters if + truncation is enabled + """ + return self._tokenizer.truncation + + def add_tokens(self, tokens: List[Union[str, AddedToken]]) -> int: + """Add the given tokens to the vocabulary + + Args: + tokens: List[Union[str, AddedToken]]: + A list of tokens to add to the vocabulary. Each token can either be + a string, or an instance of AddedToken + + Returns: + The number of tokens that were added to the vocabulary + """ + return self._tokenizer.add_tokens(tokens) + + def add_special_tokens(self, special_tokens: List[Union[str, AddedToken]]) -> int: + """Add the given special tokens to the vocabulary, and treat them as special tokens. + + The special tokens will never be processed by the model, and will be + removed while decoding. + + Args: + tokens: List[Union[str, AddedToken]]: + A list of special tokens to add to the vocabulary. 
Each token can either be + a string, or an instance of AddedToken + + Returns: + The number of tokens that were added to the vocabulary + """ + return self._tokenizer.add_special_tokens(special_tokens) + + def normalize(self, sequence: str) -> str: + """Normalize the given sequence + + Args: + sequence: str: + The sequence to normalize + + Returns: + The normalized string + """ + return self._tokenizer.normalize(sequence) + + def encode( + self, + sequence: InputSequence, + pair: Optional[InputSequence] = None, + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> Encoding: + """Encode the given sequence and pair. This method can process raw text sequences as well + as already pre-tokenized sequences. + + Args: + sequence: InputSequence: + The sequence we want to encode. This sequence can be either raw text or + pre-tokenized, according to the `is_pretokenized` argument: + + - If `is_pretokenized=False`: `InputSequence` is expected to be `str` + - If `is_pretokenized=True`: `InputSequence` is expected to be + `Union[List[str], Tuple[str]]` + + is_pretokenized: bool: + Whether the input is already pre-tokenized. + + add_special_tokens: bool: + Whether to add the special tokens while encoding. + + Returns: + An Encoding + """ + if sequence is None: + raise ValueError("encode: `sequence` can't be `None`") + + return self._tokenizer.encode(sequence, pair, is_pretokenized, add_special_tokens) + + def encode_batch( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Encode the given inputs. This method accept both raw text sequences as well as already + pre-tokenized sequences. + + Args: + inputs: List[EncodeInput]: + A list of single sequences or pair sequences to encode. Each `EncodeInput` is + expected to be of the following form: + `Union[InputSequence, Tuple[InputSequence, InputSequence]]` + + Each `InputSequence` can either be raw text or pre-tokenized, + according to the `is_pretokenized` argument: + + - If `is_pretokenized=False`: `InputSequence` is expected to be `str` + - If `is_pretokenized=True`: `InputSequence` is expected to be + `Union[List[str], Tuple[str]]` + + is_pretokenized: bool: + Whether the input is already pre-tokenized. + + add_special_tokens: bool: + Whether to add the special tokens while encoding. + + Returns: + A list of Encoding + """ + + if inputs is None: + raise ValueError("encode_batch: `inputs` can't be `None`") + + return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens) + + async def async_encode_batch( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Asynchronously encode a batch (tracks character offsets). + + Args: + inputs: A list of single or pair sequences to encode. + is_pretokenized: Whether inputs are already pre-tokenized. + add_special_tokens: Whether to add special tokens. + + Returns: + A list of Encoding. + """ + if inputs is None: + raise ValueError("async_encode_batch: `inputs` can't be `None`") + # Exposed by the Rust bindings via pyo3_async_runtimes::tokio::future_into_py + return await self._tokenizer.async_encode_batch(inputs, is_pretokenized, add_special_tokens) + + async def async_encode_batch_fast( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Asynchronously encode a batch (no character offsets, faster). 
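The wrapper exposes both the synchronous encode path and the awaitable batch variants defined above; as with the ``Tokenizer`` stubs, the async methods are specific to this build of the bindings. A self-contained sketch::

    import asyncio
    from tokenizers import Tokenizer
    from tokenizers.implementations import BaseTokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    inner = Tokenizer(WordLevel({"[UNK]": 0, "hello": 1, "world": 2}, unk_token="[UNK]"))
    inner.pre_tokenizer = Whitespace()
    tok = BaseTokenizer(inner)

    print(tok.encode("hello world").ids)                    # synchronous path
    batch = asyncio.run(tok.async_encode_batch(["hello world", "hello hello"]))
    print([enc.ids for enc in batch])                       # awaited wrapper from above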
+ + Args: + inputs: A list of single or pair sequences to encode. + is_pretokenized: Whether inputs are already pre-tokenized. + add_special_tokens: Whether to add special tokens. + + Returns: + A list of Encoding. + """ + if inputs is None: + raise ValueError("async_encode_batch_fast: `inputs` can't be `None`") + return await self._tokenizer.async_encode_batch_fast(inputs, is_pretokenized, add_special_tokens) + + def decode(self, ids: List[int], skip_special_tokens: Optional[bool] = True) -> str: + """Decode the given list of ids to a string sequence + + Args: + ids: List[unsigned int]: + A list of ids to be decoded + + skip_special_tokens: (`optional`) boolean: + Whether to remove all the special tokens from the output string + + Returns: + The decoded string + """ + if ids is None: + raise ValueError("None input is not valid. Should be a list of integers.") + + return self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens) + + def decode_batch(self, sequences: List[List[int]], skip_special_tokens: Optional[bool] = True) -> str: + """Decode the list of sequences to a list of string sequences + + Args: + sequences: List[List[unsigned int]]: + A list of sequence of ids to be decoded + + skip_special_tokens: (`optional`) boolean: + Whether to remove all the special tokens from the output strings + + Returns: + A list of decoded strings + """ + if sequences is None: + raise ValueError("None input is not valid. Should be list of list of integers.") + + return self._tokenizer.decode_batch(sequences, skip_special_tokens=skip_special_tokens) + + def token_to_id(self, token: str) -> Optional[int]: + """Convert the given token to its corresponding id + + Args: + token: str: + The token to convert + + Returns: + The corresponding id if it exists, None otherwise + """ + return self._tokenizer.token_to_id(token) + + def id_to_token(self, id: int) -> Optional[str]: + """Convert the given token id to its corresponding string + + Args: + token: id: + The token id to convert + + Returns: + The corresponding string if it exists, None otherwise + """ + return self._tokenizer.id_to_token(id) + + def save_model(self, directory: str, prefix: Optional[str] = None): + """Save the current model to the given directory + + Args: + directory: str: + A path to the destination directory + + prefix: (Optional) str: + An optional prefix, used to prefix each file name + """ + return self._tokenizer.model.save(directory, prefix=prefix) + + def save(self, path: str, pretty: bool = True): + """Save the current Tokenizer at the given path + + Args: + path: str: + A path to the destination Tokenizer file + """ + return self._tokenizer.save(path, pretty) + + def to_str(self, pretty: bool = False): + """Get a serialized JSON version of the Tokenizer as a str + + Args: + pretty: bool: + Whether the JSON string should be prettified + + Returns: + str + """ + return self._tokenizer.to_str(pretty) + + def post_process( + self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True + ) -> Encoding: + """Apply all the post-processing steps to the given encodings. + + The various steps are: + 1. Truncate according to global params (provided to `enable_truncation`) + 2. Apply the PostProcessor + 3. Pad according to global params. 
(provided to `enable_padding`) + + Args: + encoding: Encoding: + The main Encoding to post process + + pair: Optional[Encoding]: + An optional pair Encoding + + add_special_tokens: bool: + Whether to add special tokens + + Returns: + The resulting Encoding + """ + return self._tokenizer.post_process(encoding, pair, add_special_tokens) + + @property + def model(self) -> Model: + return self._tokenizer.model + + @model.setter + def model(self, model: Model): + self._tokenizer.model = model + + @property + def normalizer(self) -> Normalizer: + return self._tokenizer.normalizer + + @normalizer.setter + def normalizer(self, normalizer: Normalizer): + self._tokenizer.normalizer = normalizer + + @property + def pre_tokenizer(self) -> PreTokenizer: + return self._tokenizer.pre_tokenizer + + @pre_tokenizer.setter + def pre_tokenizer(self, pre_tokenizer: PreTokenizer): + self._tokenizer.pre_tokenizer = pre_tokenizer + + @property + def post_processor(self) -> PostProcessor: + return self._tokenizer.post_processor + + @post_processor.setter + def post_processor(self, post_processor: PostProcessor): + self._tokenizer.post_processor = post_processor + + @property + def decoder(self) -> Decoder: + return self._tokenizer.decoder + + @decoder.setter + def decoder(self, decoder: Decoder): + self._tokenizer.decoder = decoder diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/bert_wordpiece.py b/phivenv/Lib/site-packages/tokenizers/implementations/bert_wordpiece.py new file mode 100644 index 0000000000000000000000000000000000000000..ed98c626537d2ad7ec881dc237d91fddaffda9ae --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/bert_wordpiece.py @@ -0,0 +1,151 @@ +from typing import Dict, Iterator, List, Optional, Union + +from tokenizers import AddedToken, Tokenizer, decoders, trainers +from tokenizers.models import WordPiece +from tokenizers.normalizers import BertNormalizer +from tokenizers.pre_tokenizers import BertPreTokenizer +from tokenizers.processors import BertProcessing + +from .base_tokenizer import BaseTokenizer + + +class BertWordPieceTokenizer(BaseTokenizer): + """Bert WordPiece Tokenizer""" + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + unk_token: Union[str, AddedToken] = "[UNK]", + sep_token: Union[str, AddedToken] = "[SEP]", + cls_token: Union[str, AddedToken] = "[CLS]", + pad_token: Union[str, AddedToken] = "[PAD]", + mask_token: Union[str, AddedToken] = "[MASK]", + clean_text: bool = True, + handle_chinese_chars: bool = True, + strip_accents: Optional[bool] = None, + lowercase: bool = True, + wordpieces_prefix: str = "##", + ): + if vocab is not None: + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token))) + else: + tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token))) + + # Let the tokenizer know about special tokens if they are part of the vocab + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + if tokenizer.token_to_id(str(sep_token)) is not None: + tokenizer.add_special_tokens([str(sep_token)]) + if tokenizer.token_to_id(str(cls_token)) is not None: + tokenizer.add_special_tokens([str(cls_token)]) + if tokenizer.token_to_id(str(pad_token)) is not None: + tokenizer.add_special_tokens([str(pad_token)]) + if tokenizer.token_to_id(str(mask_token)) is not None: + tokenizer.add_special_tokens([str(mask_token)]) + + tokenizer.normalizer = BertNormalizer( + clean_text=clean_text, + handle_chinese_chars=handle_chinese_chars, + strip_accents=strip_accents, 
+ lowercase=lowercase, + ) + tokenizer.pre_tokenizer = BertPreTokenizer() + + if vocab is not None: + sep_token_id = tokenizer.token_to_id(str(sep_token)) + if sep_token_id is None: + raise TypeError("sep_token not found in the vocabulary") + cls_token_id = tokenizer.token_to_id(str(cls_token)) + if cls_token_id is None: + raise TypeError("cls_token not found in the vocabulary") + + tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id)) + tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix) + + parameters = { + "model": "BertWordPiece", + "unk_token": unk_token, + "sep_token": sep_token, + "cls_token": cls_token, + "pad_token": pad_token, + "mask_token": mask_token, + "clean_text": clean_text, + "handle_chinese_chars": handle_chinese_chars, + "strip_accents": strip_accents, + "lowercase": lowercase, + "wordpieces_prefix": wordpieces_prefix, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab: str, **kwargs): + vocab = WordPiece.read_file(vocab) + return BertWordPieceTokenizer(vocab, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + ): + """Train the model using the given files""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/byte_level_bpe.py b/phivenv/Lib/site-packages/tokenizers/implementations/byte_level_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..f573c15c5a7e4344f712deea45322733819be120 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/byte_level_bpe.py @@ -0,0 +1,122 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, processors, trainers +from tokenizers.models import BPE +from tokenizers.normalizers import Lowercase, Sequence, unicode_normalizer_from_str + +from .base_tokenizer import BaseTokenizer + + +class ByteLevelBPETokenizer(BaseTokenizer): + """ByteLevelBPETokenizer + + Represents a Byte-level BPE as 
introduced by OpenAI with their GPT-2 model + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, List[Tuple[str, str]]]] = None, + add_prefix_space: bool = False, + lowercase: bool = False, + dropout: Optional[float] = None, + unicode_normalizer: Optional[str] = None, + continuing_subword_prefix: Optional[str] = None, + end_of_word_suffix: Optional[str] = None, + trim_offsets: bool = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=dropout, + continuing_subword_prefix=continuing_subword_prefix or "", + end_of_word_suffix=end_of_word_suffix or "", + ) + ) + else: + tokenizer = Tokenizer(BPE()) + + # Check for Unicode normalization first (before everything else) + normalizers = [] + + if unicode_normalizer: + normalizers += [unicode_normalizer_from_str(unicode_normalizer)] + + if lowercase: + normalizers += [Lowercase()] + + # Create the normalizer structure + if len(normalizers) > 0: + if len(normalizers) > 1: + tokenizer.normalizer = Sequence(normalizers) + else: + tokenizer.normalizer = normalizers[0] + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.ByteLevel(trim_offsets=trim_offsets) + + parameters = { + "model": "ByteLevelBPE", + "add_prefix_space": add_prefix_space, + "lowercase": lowercase, + "dropout": dropout, + "unicode_normalizer": unicode_normalizer, + "continuing_subword_prefix": continuing_subword_prefix, + "end_of_word_suffix": end_of_word_suffix, + "trim_offsets": trim_offsets, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return ByteLevelBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + show_progress: bool = True, + special_tokens: List[Union[str, AddedToken]] = [], + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + show_progress=show_progress, + special_tokens=special_tokens, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + show_progress: bool = True, + special_tokens: List[Union[str, AddedToken]] = [], + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + show_progress=show_progress, + special_tokens=special_tokens, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/char_level_bpe.py b/phivenv/Lib/site-packages/tokenizers/implementations/char_level_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..1891d0435ed429613c1e09486e86165cf19479be --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/char_level_bpe.py @@ -0,0 +1,150 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + 
+from .. import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers +from ..models import BPE +from ..normalizers import BertNormalizer, Lowercase, Sequence, unicode_normalizer_from_str +from .base_tokenizer import BaseTokenizer + + +class CharBPETokenizer(BaseTokenizer): + """Original BPE Tokenizer + + Represents the BPE algorithm, as introduced by Rico Sennrich + (https://arxiv.org/abs/1508.07909) + + The defaults settings corresponds to OpenAI GPT BPE tokenizers and differs from the original + Sennrich subword-nmt implementation by the following options that you can deactivate: + - adding a normalizer to clean up the text (deactivate with `bert_normalizer=False`) by: + * removing any control characters and replacing all whitespaces by the classic one. + * handle chinese chars by putting spaces around them. + * strip all accents. + - spitting on punctuation in addition to whitespaces (deactivate it with + `split_on_whitespace_only=True`) + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, List[Tuple[str, str]]]] = None, + unk_token: Union[str, AddedToken] = "", + suffix: str = "", + dropout: Optional[float] = None, + lowercase: bool = False, + unicode_normalizer: Optional[str] = None, + bert_normalizer: bool = True, + split_on_whitespace_only: bool = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=dropout, + unk_token=str(unk_token), + end_of_word_suffix=suffix, + ) + ) + else: + tokenizer = Tokenizer(BPE(unk_token=str(unk_token), dropout=dropout, end_of_word_suffix=suffix)) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + # Check for Unicode normalization first (before everything else) + normalizers = [] + + if unicode_normalizer: + normalizers += [unicode_normalizer_from_str(unicode_normalizer)] + + if bert_normalizer: + normalizers += [BertNormalizer(lowercase=False)] + + if lowercase: + normalizers += [Lowercase()] + + # Create the normalizer structure + if len(normalizers) > 0: + if len(normalizers) > 1: + tokenizer.normalizer = Sequence(normalizers) + else: + tokenizer.normalizer = normalizers[0] + + if split_on_whitespace_only: + tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit() + else: + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + tokenizer.decoder = decoders.BPEDecoder(suffix=suffix) + + parameters = { + "model": "BPE", + "unk_token": unk_token, + "suffix": suffix, + "dropout": dropout, + "lowercase": lowercase, + "unicode_normalizer": unicode_normalizer, + "bert_normalizer": bert_normalizer, + "split_on_whitespace_only": split_on_whitespace_only, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return CharBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + suffix: Optional[str] = "", + show_progress: bool = True, + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + end_of_word_suffix=suffix, + 
show_progress=show_progress, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + suffix: Optional[str] = "", + show_progress: bool = True, + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + end_of_word_suffix=suffix, + show_progress=show_progress, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/sentencepiece_bpe.py b/phivenv/Lib/site-packages/tokenizers/implementations/sentencepiece_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..912fa4676930004ebc0a620b5bc7c1cfb61c3be7 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/sentencepiece_bpe.py @@ -0,0 +1,103 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers +from tokenizers.models import BPE +from tokenizers.normalizers import NFKC + +from .base_tokenizer import BaseTokenizer + + +class SentencePieceBPETokenizer(BaseTokenizer): + """SentencePiece BPE Tokenizer + + Represents the BPE algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, List[Tuple[str, str]]]] = None, + unk_token: Union[str, AddedToken] = "", + replacement: str = "▁", + add_prefix_space: bool = True, + dropout: Optional[float] = None, + fuse_unk: Optional[bool] = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer(BPE(vocab, merges, dropout=dropout, unk_token=unk_token, fuse_unk=fuse_unk)) + else: + tokenizer = Tokenizer(BPE(dropout=dropout, unk_token=unk_token, fuse_unk=fuse_unk)) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + tokenizer.normalizer = NFKC() + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceBPE", + "unk_token": unk_token, + "replacement": replacement, + "add_prefix_space": add_prefix_space, + "dropout": dropout, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return SentencePieceBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + show_progress: bool = True, + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + 
limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + show_progress=show_progress, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + show_progress: bool = True, + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + show_progress=show_progress, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/phivenv/Lib/site-packages/tokenizers/implementations/sentencepiece_unigram.py b/phivenv/Lib/site-packages/tokenizers/implementations/sentencepiece_unigram.py new file mode 100644 index 0000000000000000000000000000000000000000..da3fe32e88fe5521a6615ea86998f349471e8646 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/implementations/sentencepiece_unigram.py @@ -0,0 +1,196 @@ +import json +import os +from typing import Iterator, List, Optional, Union, Tuple + +from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers +from tokenizers.models import Unigram + +from .base_tokenizer import BaseTokenizer + + +class SentencePieceUnigramTokenizer(BaseTokenizer): + """SentencePiece Unigram Tokenizer + + Represents the Unigram algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: Optional[List[Tuple[str, float]]] = None, + replacement: str = "▁", + add_prefix_space: bool = True, + ): + if vocab is not None: + # Let Unigram(..) fail if only one of them is None + tokenizer = Tokenizer(Unigram(vocab)) + else: + tokenizer = Tokenizer(Unigram()) + + tokenizer.normalizer = normalizers.Sequence( + [normalizers.Nmt(), normalizers.NFKC(), normalizers.Replace(Regex(" {2,}"), " ")] + ) + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceUnigram", + "replacement": replacement, + "add_prefix_space": add_prefix_space, + } + + super().__init__(tokenizer, parameters) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 8000, + show_progress: bool = True, + special_tokens: Optional[List[Union[str, AddedToken]]] = None, + initial_alphabet: Optional[List[str]] = None, + unk_token: Optional[str] = None, + ): + """ + Train the model using the given files + + Args: + files (:obj:`List[str]`): + A list of path to the files that we should use for training + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + show_progress (:obj:`bool`): + Whether to show progress bars while training. + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. 
+ If the strings contain more than one character, only the first one + is kept. + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + """ + + if special_tokens is None: + special_tokens = [] + + if initial_alphabet is None: + initial_alphabet = [] + + trainer = trainers.UnigramTrainer( + vocab_size=vocab_size, + special_tokens=special_tokens, + show_progress=show_progress, + initial_alphabet=initial_alphabet, + unk_token=unk_token, + ) + + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 8000, + show_progress: bool = True, + special_tokens: Optional[List[Union[str, AddedToken]]] = None, + initial_alphabet: Optional[List[str]] = None, + unk_token: Optional[str] = None, + length: Optional[int] = None, + ): + """ + Train the model using the given iterator + + Args: + iterator (:obj:`Union[Iterator[str], Iterator[Iterator[str]]]`): + Any iterator over strings or list of strings + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + show_progress (:obj:`bool`): + Whether to show progress bars while training. + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + length (:obj:`int`, `optional`): + The total number of sequences in the iterator. This is used to + provide meaningful progress tracking + """ + + if special_tokens is None: + special_tokens = [] + + if initial_alphabet is None: + initial_alphabet = [] + + trainer = trainers.UnigramTrainer( + vocab_size=vocab_size, + special_tokens=special_tokens, + show_progress=show_progress, + initial_alphabet=initial_alphabet, + unk_token=unk_token, + ) + + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) + + @staticmethod + def from_spm(filename: str): + try: + import sys + + sys.path.append(".") + + import sentencepiece_model_pb2 as model + except Exception: + raise Exception( + "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/src/sentencepiece/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required." 
+ ) + + m = model.ModelProto() + m.ParseFromString(open(filename, "rb").read()) + + precompiled_charsmap = m.normalizer_spec.precompiled_charsmap + vocab = [(piece.piece, piece.score) for piece in m.pieces] + unk_id = m.trainer_spec.unk_id + model_type = m.trainer_spec.model_type + byte_fallback = m.trainer_spec.byte_fallback + if model_type != 1: + raise Exception( + "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" + ) + + replacement = "▁" + add_prefix_space = True + + tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback)) + + if precompiled_charsmap: + tokenizer.normalizer = normalizers.Sequence( + [ + normalizers.Precompiled(precompiled_charsmap), + normalizers.Replace(Regex(" {2,}"), " "), + ] + ) + else: + tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")]) + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceUnigram", + } + + obj = BaseTokenizer.__new__(SentencePieceUnigramTokenizer, tokenizer, parameters) + BaseTokenizer.__init__(obj, tokenizer, parameters) + return obj diff --git a/phivenv/Lib/site-packages/tokenizers/models/__init__.py b/phivenv/Lib/site-packages/tokenizers/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0094328283214245bae20d1f18b2ba3db716ed2c --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/models/__init__.py @@ -0,0 +1,8 @@ +# Generated content DO NOT EDIT +from .. import models + +Model = models.Model +BPE = models.BPE +Unigram = models.Unigram +WordLevel = models.WordLevel +WordPiece = models.WordPiece diff --git a/phivenv/Lib/site-packages/tokenizers/models/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/models/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e5bff5b080ff41e145752fbebe1e4deb08e10ade --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/models/__init__.pyi @@ -0,0 +1,591 @@ +# Generated content DO NOT EDIT +class Model: + """ + Base class for all models + + The model represents the actual tokenization algorithm. This is the part that + will contain and manage the learned vocabulary. + + This class cannot be constructed directly. Please use one of the concrete models. + """ + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class BPE(Model): + """ + An implementation of the BPE (Byte-Pair Encoding) algorithm + + Args: + vocab (:obj:`Dict[str, int]`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + merges (:obj:`List[Tuple[str, str]]`, `optional`): + A list of pairs of tokens (:obj:`Tuple[str, str]`) :obj:`[("a", "b"),...]` + + cache_capacity (:obj:`int`, `optional`): + The number of words that the BPE cache can contain. The cache allows + to speed-up the process by keeping the result of the merge operations + for a number of words. + + dropout (:obj:`float`, `optional`): + A float between 0 and 1 that represents the BPE dropout to use. + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + + continuing_subword_prefix (:obj:`str`, `optional`): + The prefix to attach to subword units that don't represent a beginning of word. + + end_of_word_suffix (:obj:`str`, `optional`): + The suffix to attach to subword units that represent an end of word. + + fuse_unk (:obj:`bool`, `optional`): + Whether to fuse any subsequent unknown tokens into a single one + + byte_fallback (:obj:`bool`, `optional`): + Whether to use spm byte-fallback trick (defaults to False) + + ignore_merges (:obj:`bool`, `optional`): + Whether or not to match tokens with the vocab before using merges. + """ + def __init__( + self, + vocab=None, + merges=None, + cache_capacity=None, + dropout=None, + unk_token=None, + continuing_subword_prefix=None, + end_of_word_suffix=None, + fuse_unk=None, + byte_fallback=False, + ignore_merges=False, + ): + pass + + @staticmethod + def from_file(cls, vocab, merge, **kwargs): + """ + Instantiate a BPE model from the given files. + + This method is roughly equivalent to doing:: + + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + bpe = BPE(vocab, merges) + + If you don't need to keep the :obj:`vocab, merges` values lying around, + this method is more optimized than manually calling + :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + merges (:obj:`str`): + The path to a :obj:`merges.txt` file + + Returns: + :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. 
+ + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(self, vocab, merges): + """ + Read a :obj:`vocab.json` and a :obj:`merges.txt` files + + This method provides a way to read and parse the content of these files, + returning the relevant data structures. If you want to instantiate some BPE models + from memory, this method gives you the expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + merges (:obj:`str`): + The path to a :obj:`merges.txt` file + + Returns: + A :obj:`Tuple` with the vocab and the merges: + The vocabulary and merges loaded into memory + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class Unigram(Model): + """ + An implementation of the Unigram algorithm + + Args: + vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`): + A list of vocabulary items and their relative score [("am", -0.2442),...] + """ + def __init__(self, vocab, unk_id, byte_fallback): + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class WordLevel(Model): + """ + An implementation of the WordLevel algorithm + + Most simple tokenizer model based on mapping tokens to their corresponding id. + + Args: + vocab (:obj:`str`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + """ + def __init__(self, vocab, unk_token): + pass + + @staticmethod + def from_file(vocab, unk_token): + """ + Instantiate a WordLevel model from the given file + + This method is roughly equivalent to doing:: + + vocab = WordLevel.read_file(vocab_filename) + wordlevel = WordLevel(vocab) + + If you don't need to keep the :obj:`vocab` values lying around, this method is + more optimized than manually calling :meth:`~tokenizers.models.WordLevel.read_file` to + initialize a :class:`~tokenizers.models.WordLevel` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + Returns: + :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(vocab): + """ + Read a :obj:`vocab.json` + + This method provides a way to read and parse the content of a vocabulary file, + returning the relevant data structures. If you want to instantiate some WordLevel models + from memory, this method gives you the expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + Returns: + :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class WordPiece(Model): + """ + An implementation of the WordPiece algorithm + + Args: + vocab (:obj:`Dict[str, int]`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + + max_input_chars_per_word (:obj:`int`, `optional`): + The maximum number of characters to authorize in a single word. + """ + def __init__(self, vocab, unk_token, max_input_chars_per_word): + pass + + @staticmethod + def from_file(vocab, **kwargs): + """ + Instantiate a WordPiece model from the given file + + This method is roughly equivalent to doing:: + + vocab = WordPiece.read_file(vocab_filename) + wordpiece = WordPiece(vocab) + + If you don't need to keep the :obj:`vocab` values lying around, this method is + more optimized than manually calling :meth:`~tokenizers.models.WordPiece.read_file` to + initialize a :class:`~tokenizers.models.WordPiece` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.txt` file + + Returns: + :class:`~tokenizers.models.WordPiece`: An instance of WordPiece loaded from file + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(vocab): + """ + Read a :obj:`vocab.txt` file + + This method provides a way to read and parse the content of a standard `vocab.txt` + file as used by the WordPiece Model, returning the relevant data structures. If you + want to instantiate some WordPiece models from memory, this method gives you the + expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.txt` file + + Returns: + :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass diff --git a/phivenv/Lib/site-packages/tokenizers/models/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bca55614e6d41e3d83155f1b3e69ef551900919 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/normalizers/__init__.py b/phivenv/Lib/site-packages/tokenizers/normalizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7b7476ef32ee94036ad57c7f88956552018ca7 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/normalizers/__init__.py @@ -0,0 +1,29 @@ +from .. import normalizers + + +Normalizer = normalizers.Normalizer +BertNormalizer = normalizers.BertNormalizer +NFD = normalizers.NFD +NFKD = normalizers.NFKD +NFC = normalizers.NFC +NFKC = normalizers.NFKC +Sequence = normalizers.Sequence +Lowercase = normalizers.Lowercase +Prepend = normalizers.Prepend +Strip = normalizers.Strip +StripAccents = normalizers.StripAccents +Nmt = normalizers.Nmt +Precompiled = normalizers.Precompiled +Replace = normalizers.Replace +ByteLevel = normalizers.ByteLevel + +NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD} + + +def unicode_normalizer_from_str(normalizer: str) -> Normalizer: + if normalizer not in NORMALIZERS: + raise ValueError( + "{} is not a known unicode normalizer. Available are {}".format(normalizer, NORMALIZERS.keys()) + ) + + return NORMALIZERS[normalizer]() diff --git a/phivenv/Lib/site-packages/tokenizers/normalizers/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/normalizers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..62b26b9fe4306b58bd2648a57bc731a1383bb78e --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/normalizers/__init__.pyi @@ -0,0 +1,636 @@ +# Generated content DO NOT EDIT +class Normalizer: + """ + Base class for all normalizers + + This class is not supposed to be instantiated directly. Instead, any implementation of a + Normalizer will return an instance of this class when instantiated. + """ + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class BertNormalizer(Normalizer): + """ + BertNormalizer + + Takes care of normalizing raw text before giving it to a Bert model. + This includes cleaning the text, handling accents, chinese chars and lowercasing + + Args: + clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to clean the text, by removing any control characters + and replacing all whitespaces by the classic one. + + handle_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to handle chinese chars by putting spaces around them. + + strip_accents (:obj:`bool`, `optional`): + Whether to strip all accents. If this option is not specified (ie == None), + then it will be determined by the value for `lowercase` (as in the original Bert). + + lowercase (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to lowercase. + """ + def __init__(self, clean_text=True, handle_chinese_chars=True, strip_accents=None, lowercase=True): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class ByteLevel(Normalizer): + """ + Bytelevel Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Lowercase(Normalizer): + """ + Lowercase Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFC(Normalizer): + """ + NFC Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFD(Normalizer): + """ + NFD Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFKC(Normalizer): + """ + NFKC Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFKD(Normalizer): + """ + NFKD Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Nmt(Normalizer): + """ + Nmt normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Precompiled(Normalizer): + """ + Precompiled normalizer + Don't use manually it is used for compatibility for SentencePiece. + """ + def __init__(self, precompiled_charsmap): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Prepend(Normalizer): + """ + Prepend normalizer + """ + def __init__(self, prepend): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Replace(Normalizer): + """ + Replace normalizer + """ + def __init__(self, pattern, content): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Sequence(Normalizer): + """ + Allows concatenating multiple other Normalizer as a Sequence. + All the normalizers run in sequence in the given order + + Args: + normalizers (:obj:`List[Normalizer]`): + A list of Normalizer to be run as a sequence + """ + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Strip(Normalizer): + """ + Strip normalizer + """ + def __init__(self, left=True, right=True): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class StripAccents(Normalizer): + """ + StripAccents normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass diff --git a/phivenv/Lib/site-packages/tokenizers/normalizers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/normalizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..795c259fbdb97f0c902c3fbd383a0371090dee5a Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/normalizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__init__.py b/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..156a980cae8685c0b2ad2466a47087c5a95d75a5 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__init__.py @@ -0,0 +1,16 @@ +# Generated content DO NOT EDIT +from .. import pre_tokenizers + +PreTokenizer = pre_tokenizers.PreTokenizer +BertPreTokenizer = pre_tokenizers.BertPreTokenizer +ByteLevel = pre_tokenizers.ByteLevel +CharDelimiterSplit = pre_tokenizers.CharDelimiterSplit +Digits = pre_tokenizers.Digits +FixedLength = pre_tokenizers.FixedLength +Metaspace = pre_tokenizers.Metaspace +Punctuation = pre_tokenizers.Punctuation +Sequence = pre_tokenizers.Sequence +Split = pre_tokenizers.Split +UnicodeScripts = pre_tokenizers.UnicodeScripts +Whitespace = pre_tokenizers.Whitespace +WhitespaceSplit = pre_tokenizers.WhitespaceSplit diff --git a/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0bd82bbd1398eb495fe73c4c9fadb9968958c90c --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__init__.pyi @@ -0,0 +1,689 @@ +# Generated content DO NOT EDIT +class PreTokenizer: + """ + Base class for all pre-tokenizers + + This class is not supposed to be instantiated directly. Instead, any implementation of a + PreTokenizer will return an instance of this class when instantiated. + """ + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class BertPreTokenizer(PreTokenizer): + """ + BertPreTokenizer + + This pre-tokenizer splits tokens on spaces, and also on punctuation. + Each occurrence of a punctuation character will be treated separately. + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class ByteLevel(PreTokenizer): + """ + ByteLevel PreTokenizer + + This pre-tokenizer takes care of replacing all bytes of the given string + with a corresponding representation, as well as splitting into words. + + Args: + add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to add a space to the first word if there isn't already one. This + lets us treat `hello` exactly like `say hello`. + use_regex (:obj:`bool`, `optional`, defaults to :obj:`True`): + Set this to :obj:`False` to prevent this `pre_tokenizer` from using + the GPT2 specific regexp for spliting on whitespace. + """ + def __init__(self, add_prefix_space=True, use_regex=True): + pass + + @staticmethod + def alphabet(): + """ + Returns the alphabet used by this PreTokenizer. + + Since the ByteLevel works as its name suggests, at the byte level, it + encodes each byte value to a unique visible character. This means that there is a + total of 256 different characters composing this alphabet. 
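A short sketch of the two pre-tokenizers just described; the expected outputs follow from the docstrings above (punctuation isolated by `BertPreTokenizer`, a 256-character byte-level alphabet):

```python
from tokenizers.pre_tokenizers import BertPreTokenizer, ByteLevel

# BertPreTokenizer splits on whitespace and isolates each punctuation character.
print(BertPreTokenizer().pre_tokenize_str("Hello, world!"))
# [('Hello', (0, 5)), (',', (5, 6)), ('world', (7, 12)), ('!', (12, 13))]

# ByteLevel maps every byte to a visible character, so the alphabet has 256 entries.
print(len(ByteLevel.alphabet()))  # 256
```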
+ + Returns: + :obj:`List[str]`: A list of characters that compose the alphabet + """ + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class CharDelimiterSplit(PreTokenizer): + """ + This pre-tokenizer simply splits on the provided char. Works like `.split(delimiter)` + + Args: + delimiter: str: + The delimiter char that will be used to split input + """ + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Digits(PreTokenizer): + """ + This pre-tokenizer simply splits using the digits in separate tokens + + Args: + individual_digits (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set to True, digits will each be separated as follows:: + + "Call 123 please" -> "Call ", "1", "2", "3", " please" + + If set to False, digits will grouped as follows:: + + "Call 123 please" -> "Call ", "123", " please" + """ + def __init__(self, individual_digits=False): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class FixedLength(PreTokenizer): + """ + This pre-tokenizer splits the text into fixed length chunks as used + [here](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1.full) + + Args: + length (:obj:`int`, `optional`, defaults to :obj:`5`): + The length of the chunks to split the text into. + + Strings are split on the character level rather than the byte level to avoid + splitting unicode characters consisting of multiple bytes. + """ + def __init__(self, length=5): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
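The behaviour documented for `Digits` can be checked directly with `pre_tokenize_str`; a small sketch, with offsets omitted from the comments for brevity:

```python
from tokenizers.pre_tokenizers import Digits

# individual_digits=True: every digit becomes its own piece.
print(Digits(individual_digits=True).pre_tokenize_str("Call 123 please"))
# pieces: 'Call ', '1', '2', '3', ' please'

# individual_digits=False: consecutive digits stay grouped.
print(Digits(individual_digits=False).pre_tokenize_str("Call 123 please"))
# pieces: 'Call ', '123', ' please'
```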
If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Metaspace(PreTokenizer): + """ + Metaspace pre-tokenizer + + This pre-tokenizer replaces any whitespace by the provided replacement character. + It then tries to split on these spaces. + + Args: + replacement (:obj:`str`, `optional`, defaults to :obj:`▁`): + The replacement character. Must be exactly one character. By default we + use the `▁` (U+2581) meta symbol (Same as in SentencePiece). + + prepend_scheme (:obj:`str`, `optional`, defaults to :obj:`"always"`): + Whether to add a space to the first word if there isn't already one. This + lets us treat `hello` exactly like `say hello`. + Choices: "always", "never", "first". First means the space is only added on the first + token (relevant when special tokens are used or other pre_tokenizer are used). + + """ + def __init__(self, replacement="_", prepend_scheme="always", split=True): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Punctuation(PreTokenizer): + """ + This pre-tokenizer simply splits on punctuation as individual characters. + + Args: + behavior (:class:`~tokenizers.SplitDelimiterBehavior`): + The behavior to use when splitting. + Choices: "removed", "isolated" (default), "merged_with_previous", "merged_with_next", + "contiguous" + """ + def __init__(self, behavior="isolated"): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
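A hedged sketch of `Metaspace`; the replacement character is passed explicitly as the ▁ (U+2581) meta symbol mentioned in the docstring, and the exact offsets may differ between versions:

```python
from tokenizers.pre_tokenizers import Metaspace

# Whitespace is replaced by the meta symbol, then the string is split on it.
pre_tok = Metaspace(replacement="▁", prepend_scheme="always")
print(pre_tok.pre_tokenize_str("Hello there"))
# pieces: '▁Hello', '▁there'
```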
If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Sequence(PreTokenizer): + """ + This pre-tokenizer composes other pre_tokenizers and applies them in sequence + """ + def __init__(self, pretokenizers): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Split(PreTokenizer): + """ + Split PreTokenizer + + This versatile pre-tokenizer splits using the provided pattern and + according to the provided behavior. The pattern can be inverted by + making use of the invert flag. + + Args: + pattern (:obj:`str` or :class:`~tokenizers.Regex`): + A pattern used to split the string. Usually a string or a regex built with `tokenizers.Regex`. + If you want to use a regex pattern, it has to be wrapped around a `tokenizers.Regex`, + otherwise we consider is as a string pattern. For example `pattern="|"` + means you want to split on `|` (imagine a csv file for example), while + `pattern=tokenizers.Regex("1|2")` means you split on either '1' or '2'. + behavior (:class:`~tokenizers.SplitDelimiterBehavior`): + The behavior to use when splitting. + Choices: "removed", "isolated", "merged_with_previous", "merged_with_next", + "contiguous" + + invert (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to invert the pattern. 
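The string-versus-`Regex` distinction for `Split` can be illustrated as follows; a minimal sketch with made-up inputs, assuming `tokenizers.Regex` is importable as described above:

```python
from tokenizers import Regex
from tokenizers.pre_tokenizers import Split

# A plain string pattern is taken literally: split on '|' and drop it.
print(Split(pattern="|", behavior="removed").pre_tokenize_str("a|b|c"))
# pieces: 'a', 'b', 'c'

# Wrapping the pattern in Regex turns it into a regular expression.
print(Split(pattern=Regex(r"\d+"), behavior="isolated").pre_tokenize_str("ab12cd"))
# pieces: 'ab', '12', 'cd'
```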
+ """ + def __init__(self, pattern, behavior, invert=False): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class UnicodeScripts(PreTokenizer): + """ + This pre-tokenizer splits on characters that belong to different language family + It roughly follows https://github.com/google/sentencepiece/blob/master/data/Scripts.txt + Actually Hiragana and Katakana are fused with Han, and 0x30FC is Han too. + This mimicks SentencePiece Unigram implementation. + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Whitespace(PreTokenizer): + """ + This pre-tokenizer splits on word boundaries according to the `\w+|[^\w\s]+` + regex pattern. It splits on word characters or characters that aren't words or + whitespaces (punctuation such as hyphens, apostrophes, commas, etc.). + + Example: + Use the `Whitespace` function as shown below:: + + ```python + from tokenizers.pre_tokenizers import Whitespace + + pre_tokenizer = Whitespace() + text = "Hello, world! Let's try the Whitespace pre-tokenizer." 
+ pre_tokenizer.pre_tokenize_str(text) + [('Hello', (0, 5)), + (',', (5, 6)), + ('world', (7, 12)), + ('!', (12, 13)), + ('Let', (14, 17)), + ("'", (17, 18)), + ('s', (18, 19)), + ('try', (20, 23)), + ('the', (24, 27)), + ('Whitespace', (28, 38)), + ('pre', (39, 42)), + ('-', (42, 43)), + ('tokenizer', (43, 52)), + ('.', (52, 53))] + ``` + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class WhitespaceSplit(PreTokenizer): + """ + This pre-tokenizer simply splits on the whitespace. Works like `.split()` + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
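For contrast with the `Whitespace` example above, a short sketch of `WhitespaceSplit`, which only splits on whitespace and leaves punctuation attached to the surrounding word:

```python
from tokenizers.pre_tokenizers import WhitespaceSplit

print(WhitespaceSplit().pre_tokenize_str("Hello, world!"))
# [('Hello,', (0, 6)), ('world!', (7, 13))]
```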
If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass diff --git a/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bb387c01c81994a793dde2e17a1afe605aabcd8 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/pre_tokenizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/processors/__init__.py b/phivenv/Lib/site-packages/tokenizers/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8feef829cd34819b7b872a79004990d3abf2fc --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/processors/__init__.py @@ -0,0 +1,9 @@ +# Generated content DO NOT EDIT +from .. import processors + +PostProcessor = processors.PostProcessor +BertProcessing = processors.BertProcessing +ByteLevel = processors.ByteLevel +RobertaProcessing = processors.RobertaProcessing +Sequence = processors.Sequence +TemplateProcessing = processors.TemplateProcessing diff --git a/phivenv/Lib/site-packages/tokenizers/processors/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/processors/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..4dffe6ede8a372388efc6726a34a8f941806d4d1 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/processors/__init__.pyi @@ -0,0 +1,342 @@ +# Generated content DO NOT EDIT +class PostProcessor: + """ + Base class for all post-processors + + This class is not supposed to be instantiated directly. Instead, any implementation of + a PostProcessor will return an instance of this class when instantiated. + """ + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class BertProcessing(PostProcessor): + """ + This post-processor takes care of adding the special tokens needed by + a Bert model: + + - a SEP token + - a CLS token + + Args: + sep (:obj:`Tuple[str, int]`): + A tuple with the string representation of the SEP token, and its id + + cls (:obj:`Tuple[str, int]`): + A tuple with the string representation of the CLS token, and its id + """ + def __init__(self, sep, cls): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. 
+ + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class ByteLevel(PostProcessor): + """ + This post-processor takes care of trimming the offsets. + + By default, the ByteLevel BPE might include whitespaces in the produced tokens. If you don't + want the offsets to include these whitespaces, then this PostProcessor must be used. + + Args: + trim_offsets (:obj:`bool`): + Whether to trim the whitespaces from the produced offsets. + """ + def __init__(self, trim_offsets=True): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class RobertaProcessing(PostProcessor): + """ + This post-processor takes care of adding the special tokens needed by + a Roberta model: + + - a SEP token + - a CLS token + + It also takes care of trimming the offsets. + By default, the ByteLevel BPE might include whitespaces in the produced tokens. If you don't + want the offsets to include these whitespaces, then this PostProcessor should be initialized + with :obj:`trim_offsets=True` + + Args: + sep (:obj:`Tuple[str, int]`): + A tuple with the string representation of the SEP token, and its id + + cls (:obj:`Tuple[str, int]`): + A tuple with the string representation of the CLS token, and its id + + trim_offsets (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to trim the whitespaces from the produced offsets. + + add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether the add_prefix_space option was enabled during pre-tokenization. This + is relevant because it defines the way the offsets are trimmed out. + """ + def __init__(self, sep, cls, trim_offsets=True, add_prefix_space=True): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. 
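A small sketch of the post-processors above; the [CLS]/[SEP] ids used here are placeholders and must be taken from the actual vocabulary:

```python
from tokenizers.processors import BertProcessing, ByteLevel

bert_post = BertProcessing(("[SEP]", 102), ("[CLS]", 101))  # (sep, cls)
print(bert_post.num_special_tokens_to_add(is_pair=False))  # 2: [CLS] ... [SEP]
print(bert_post.num_special_tokens_to_add(is_pair=True))   # 3: [CLS] ... [SEP] ... [SEP]

# The ByteLevel post-processor only trims offsets; it adds no special tokens.
print(ByteLevel(trim_offsets=True).num_special_tokens_to_add(is_pair=False))  # 0
```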
+ + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class Sequence(PostProcessor): + """ + Sequence Processor + + Args: + processors (:obj:`List[PostProcessor]`) + The processors that need to be chained + """ + def __init__(self, processors): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class TemplateProcessing(PostProcessor): + """ + Provides a way to specify templates in order to add the special tokens to each + input sequence as relevant. + + Let's take :obj:`BERT` tokenizer as an example. It uses two special tokens, used to + delimitate each sequence. :obj:`[CLS]` is always used at the beginning of the first + sequence, and :obj:`[SEP]` is added at the end of both the first, and the pair + sequences. The final result looks like this: + + - Single sequence: :obj:`[CLS] Hello there [SEP]` + - Pair sequences: :obj:`[CLS] My name is Anthony [SEP] What is my name? [SEP]` + + With the type ids as following:: + + [CLS] ... [SEP] ... [SEP] + 0 0 0 1 1 + + You can achieve such behavior using a TemplateProcessing:: + + TemplateProcessing( + single="[CLS] $0 [SEP]", + pair="[CLS] $A [SEP] $B:1 [SEP]:1", + special_tokens=[("[CLS]", 1), ("[SEP]", 0)], + ) + + In this example, each input sequence is identified using a ``$`` construct. This identifier + lets us specify each input sequence, and the type_id to use. When nothing is specified, + it uses the default values. Here are the different ways to specify it: + + - Specifying the sequence, with default ``type_id == 0``: ``$A`` or ``$B`` + - Specifying the `type_id` with default ``sequence == A``: ``$0``, ``$1``, ``$2``, ... + - Specifying both: ``$A:0``, ``$B:1``, ... + + The same construct is used for special tokens: ``(:)?``. + + **Warning**: You must ensure that you are giving the correct tokens/ids as these + will be added to the Encoding without any further check. If the given ids correspond + to something totally different in a `Tokenizer` using this `PostProcessor`, it + might lead to unexpected results. 
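The BERT-style template from the docstring can be built directly; the ids given for the special tokens are illustrative and must match the tokenizer's vocabulary:

```python
from tokenizers.processors import TemplateProcessing

post = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
print(post.num_special_tokens_to_add(is_pair=True))  # 3
```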
+ + Args: + single (:obj:`Template`): + The template used for single sequences + + pair (:obj:`Template`): + The template used when both sequences are specified + + special_tokens (:obj:`Tokens`): + The list of special tokens used in each sequences + + Types: + + Template (:obj:`str` or :obj:`List`): + - If a :obj:`str` is provided, the whitespace is used as delimiter between tokens + - If a :obj:`List[str]` is provided, a list of tokens + + Tokens (:obj:`List[Union[Tuple[int, str], Tuple[str, int], dict]]`): + - A :obj:`Tuple` with both a token and its associated ID, in any order + - A :obj:`dict` with the following keys: + - "id": :obj:`str` => The special token id, as specified in the Template + - "ids": :obj:`List[int]` => The associated IDs + - "tokens": :obj:`List[str]` => The associated tokens + + The given dict expects the provided :obj:`ids` and :obj:`tokens` lists to have + the same length. + """ + def __init__(self, single, pair, special_tokens): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass diff --git a/phivenv/Lib/site-packages/tokenizers/processors/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/processors/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2ebf185fb6bec521681742d7ada400a108678e Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/processors/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/tools/__init__.py b/phivenv/Lib/site-packages/tokenizers/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d166cb9015bd967669e50341c3966297c447735 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/tools/__init__.py @@ -0,0 +1 @@ +from .visualizer import Annotation, EncodingVisualizer diff --git a/phivenv/Lib/site-packages/tokenizers/tools/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c431ba0779799ed55103da28e11c0c7a1db71c4b Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/tools/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/tools/__pycache__/visualizer.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/tools/__pycache__/visualizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c90368ca113af6802155c60d31a9d34262e6fc Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/tools/__pycache__/visualizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/tokenizers/tools/visualizer-styles.css b/phivenv/Lib/site-packages/tokenizers/tools/visualizer-styles.css new file mode 100644 index 
0000000000000000000000000000000000000000..c108d667b391694d94c15fd583b2491e438ec058 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/tools/visualizer-styles.css @@ -0,0 +1,170 @@ +.tokenized-text { + width:100%; + padding:2rem; + max-height: 400px; + overflow-y: auto; + box-sizing:border-box; + line-height:4rem; /* Lots of space between lines */ + font-family: "Roboto Light", "Ubuntu Light", "Ubuntu", monospace; + box-shadow: 2px 2px 2px rgba(0,0,0,0.2); + background-color: rgba(0,0,0,0.01); + letter-spacing:2px; /* Give some extra separation between chars */ +} +.non-token{ + /* White space and other things the tokenizer ignores*/ + white-space: pre; + letter-spacing:4px; + border-top:1px solid #A0A0A0; /* A gentle border on top and bottom makes tabs more ovious*/ + border-bottom:1px solid #A0A0A0; + line-height: 1rem; + height: calc(100% - 2px); +} + +.token { + white-space: pre; + position:relative; + color:black; + letter-spacing:2px; +} + +.annotation{ + white-space:nowrap; /* Important - ensures that annotations appears even if the annotated text wraps a line */ + border-radius:4px; + position:relative; + width:fit-content; +} +.annotation:before { + /*The before holds the text and the after holds the background*/ + z-index:1000; /* Make sure this is above the background */ + content:attr(data-label); /* The annotations label is on a data attribute */ + color:white; + position:absolute; + font-size:1rem; + text-align:center; + font-weight:bold; + + top:1.75rem; + line-height:0; + left:0; + width:100%; + padding:0.5rem 0; + /* These make it so an annotation doesn't stretch beyond the annotated text if the label is longer*/ + overflow: hidden; + white-space: nowrap; + text-overflow:ellipsis; +} + +.annotation:after { + content:attr(data-label); /* The content defines the width of the annotation*/ + position:absolute; + font-size:0.75rem; + text-align:center; + font-weight:bold; + text-overflow:ellipsis; + top:1.75rem; + line-height:0; + overflow: hidden; + white-space: nowrap; + + left:0; + width:100%; /* 100% of the parent, which is the annotation whose width is the tokens inside it*/ + + padding:0.5rem 0; + /* Nast hack below: + We set the annotations color in code because we don't know the colors at css time. + But you can't pass a color as a data attribute to get it into the pseudo element (this thing) + So to get around that, annotations have the color set on them with a style attribute and then we + can get the color with currentColor. 
+ Annotations wrap tokens and tokens set the color back to black + */ + background-color: currentColor; +} +.annotation:hover::after, .annotation:hover::before{ + /* When the user hovers over an annotation expand the label to display in full + */ + min-width: fit-content; +} + +.annotation:hover{ + /* Emphasize the annotation start end with a border on hover*/ + border-color: currentColor; + border: 2px solid; +} +.special-token:not(:empty){ + /* + A none empty special token is like UNK (as opposed to CLS which has no representation in the text ) + */ + position:relative; +} +.special-token:empty::before{ + /* Special tokens that don't have text are displayed as pseudo elements so we dont select them with the mouse*/ + content:attr(data-stok); + background:#202020; + font-size:0.75rem; + color:white; + margin: 0 0.25rem; + padding: 0.25rem; + border-radius:4px +} + +.special-token:not(:empty):before { + /* Special tokens that have text (UNK) are displayed above the actual text*/ + content:attr(data-stok); + position:absolute; + bottom:1.75rem; + min-width:100%; + width:100%; + height:1rem; + line-height:1rem; + font-size:1rem; + text-align:center; + color:white; + font-weight:bold; + background:#202020; + border-radius:10%; +} +/* +We want to alternate the color of tokens, but we can't use nth child because tokens might be broken up by annotations +instead we apply even and odd class at generation time and color them that way + */ +.even-token{ + background:#DCDCDC ; + border: 1px solid #DCDCDC; +} +.odd-token{ + background:#A0A0A0; + border: 1px solid #A0A0A0; +} +.even-token.multi-token,.odd-token.multi-token{ + background: repeating-linear-gradient( + 45deg, + transparent, + transparent 1px, + #ccc 1px, + #ccc 1px + ), + /* on "bottom" */ + linear-gradient( + to bottom, + #FFB6C1, + #999 + ); +} + +.multi-token:hover::after { + content:"This char has more than 1 token"; /* The content defines the width of the annotation*/ + color:white; + background-color: black; + position:absolute; + font-size:0.75rem; + text-align:center; + font-weight:bold; + text-overflow:ellipsis; + top:1.75rem; + line-height:0; + overflow: hidden; + white-space: nowrap; + left:0; + width:fit-content; /* 100% of the parent, which is the annotation whose width is the tokens inside it*/ + padding:0.5rem 0; +} diff --git a/phivenv/Lib/site-packages/tokenizers/tools/visualizer.py b/phivenv/Lib/site-packages/tokenizers/tools/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..87f27bf1f867a8c19eb0096ac4eedc41ec2c7f4c --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/tools/visualizer.py @@ -0,0 +1,403 @@ +import itertools +import os +import re +from string import Template +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple + +from tokenizers import Encoding, Tokenizer + + +dirname = os.path.dirname(__file__) +css_filename = os.path.join(dirname, "visualizer-styles.css") +with open(css_filename) as f: + css = f.read() + + +class Annotation: + start: int + end: int + label: int + + def __init__(self, start: int, end: int, label: str): + self.start = start + self.end = end + self.label = label + + +AnnotationList = List[Annotation] +PartialIntList = List[Optional[int]] + + +class CharStateKey(NamedTuple): + token_ix: Optional[int] + anno_ix: Optional[int] + + +class CharState: + char_ix: Optional[int] + + def __init__(self, char_ix): + self.char_ix = char_ix + + self.anno_ix: Optional[int] = None + self.tokens: List[int] = [] + + @property + def token_ix(self): + 
return self.tokens[0] if len(self.tokens) > 0 else None + + @property + def is_multitoken(self): + """ + BPE tokenizers can output more than one token for a char + """ + return len(self.tokens) > 1 + + def partition_key(self) -> CharStateKey: + return CharStateKey( + token_ix=self.token_ix, + anno_ix=self.anno_ix, + ) + + +class Aligned: + pass + + +class EncodingVisualizer: + """ + Build an EncodingVisualizer + + Args: + + tokenizer (:class:`~tokenizers.Tokenizer`): + A tokenizer instance + + default_to_notebook (:obj:`bool`): + Whether to render html output in a notebook by default + + annotation_converter (:obj:`Callable`, `optional`): + An optional (lambda) function that takes an annotation in any format and returns + an Annotation object + """ + + unk_token_regex = re.compile("(.{1}\b)?(unk|oov)(\b.{1})?", flags=re.IGNORECASE) + + def __init__( + self, + tokenizer: Tokenizer, + default_to_notebook: bool = True, + annotation_converter: Optional[Callable[[Any], Annotation]] = None, + ): + if default_to_notebook: + try: + from IPython.core.display import HTML, display + except ImportError: + raise Exception( + """We couldn't import IPython utils for html display. + Are you running in a notebook? + You can also pass `default_to_notebook=False` to get back raw HTML + """ + ) + + self.tokenizer = tokenizer + self.default_to_notebook = default_to_notebook + self.annotation_coverter = annotation_converter + pass + + def __call__( + self, + text: str, + annotations: AnnotationList = [], + default_to_notebook: Optional[bool] = None, + ) -> Optional[str]: + """ + Build a visualization of the given text + + Args: + text (:obj:`str`): + The text to tokenize + + annotations (:obj:`List[Annotation]`, `optional`): + An optional list of annotations of the text. The can either be an annotation class + or anything else if you instantiated the visualizer with a converter function + + default_to_notebook (:obj:`bool`, `optional`, defaults to `False`): + If True, will render the html in a notebook. Otherwise returns an html string. + + Returns: + The HTML string if default_to_notebook is False, otherwise (default) returns None and + renders the HTML in the notebook + + """ + final_default_to_notebook = self.default_to_notebook + if default_to_notebook is not None: + final_default_to_notebook = default_to_notebook + if final_default_to_notebook: + try: + from IPython.core.display import HTML, display + except ImportError: + raise Exception( + """We couldn't import IPython utils for html display. 
+ Are you running in a notebook?""" + ) + if self.annotation_coverter is not None: + annotations = list(map(self.annotation_coverter, annotations)) + encoding = self.tokenizer.encode(text) + html = EncodingVisualizer.__make_html(text, encoding, annotations) + if final_default_to_notebook: + display(HTML(html)) + else: + return html + + @staticmethod + def calculate_label_colors(annotations: AnnotationList) -> Dict[str, str]: + """ + Generates a color palette for all the labels in a given set of annotations + + Args: + annotations (:obj:`Annotation`): + A list of annotations + + Returns: + :obj:`dict`: A dictionary mapping labels to colors in HSL format + """ + if len(annotations) == 0: + return {} + labels = set(map(lambda x: x.label, annotations)) + num_labels = len(labels) + h_step = int(255 / num_labels) + if h_step < 20: + h_step = 20 + s = 32 + l = 64 # noqa: E741 + h = 10 + colors = {} + + for label in sorted(labels): # sort so we always get the same colors for a given set of labels + colors[label] = f"hsl({h},{s}%,{l}%)" + h += h_step + return colors + + @staticmethod + def consecutive_chars_to_html( + consecutive_chars_list: List[CharState], + text: str, + encoding: Encoding, + ): + """ + Converts a list of "consecutive chars" into a single HTML element. + Chars are consecutive if they fall under the same word, token and annotation. + The CharState class is a named tuple with a "partition_key" method that makes it easy to + compare if two chars are consecutive. + + Args: + consecutive_chars_list (:obj:`List[CharState]`): + A list of CharStates that have been grouped together + + text (:obj:`str`): + The original text being processed + + encoding (:class:`~tokenizers.Encoding`): + The encoding returned from the tokenizer + + Returns: + :obj:`str`: The HTML span for a set of consecutive chars + """ + first = consecutive_chars_list[0] + if first.char_ix is None: + # its a special token + stoken = encoding.tokens[first.token_ix] + # special tokens are represented as empty spans. We use the data attribute and css + # magic to display it + return f'' + # We're not in a special token so this group has a start and end. + last = consecutive_chars_list[-1] + start = first.char_ix + end = last.char_ix + 1 + span_text = text[start:end] + css_classes = [] # What css classes will we apply on the resulting span + data_items = {} # What data attributes will we apply on the result span + if first.token_ix is not None: + # We can either be in a token or not (e.g. in white space) + css_classes.append("token") + if first.is_multitoken: + css_classes.append("multi-token") + if first.token_ix % 2: + # We use this to color alternating tokens. + # A token might be split by an annotation that ends in the middle of it, so this + # lets us visually indicate a consecutive token despite its possible splitting in + # the html markup + css_classes.append("odd-token") + else: + # Like above, but a different color so we can see the tokens alternate + css_classes.append("even-token") + if EncodingVisualizer.unk_token_regex.search(encoding.tokens[first.token_ix]) is not None: + # This is a special token that is in the text. probably UNK + css_classes.append("special-token") + # TODO is this the right name for the data attribute ? + data_items["stok"] = encoding.tokens[first.token_ix] + else: + # In this case we are looking at a group/single char that is not tokenized. + # e.g. 
white space + css_classes.append("non-token") + css = f'''class="{" ".join(css_classes)}"''' + data = "" + for key, val in data_items.items(): + data += f' data-{key}="{val}"' + return f"{span_text}" + + @staticmethod + def __make_html(text: str, encoding: Encoding, annotations: AnnotationList) -> str: + char_states = EncodingVisualizer.__make_char_states(text, encoding, annotations) + current_consecutive_chars = [char_states[0]] + prev_anno_ix = char_states[0].anno_ix + spans = [] + label_colors_dict = EncodingVisualizer.calculate_label_colors(annotations) + cur_anno_ix = char_states[0].anno_ix + if cur_anno_ix is not None: + # If we started in an annotation make a span for it + anno = annotations[cur_anno_ix] + label = anno.label + color = label_colors_dict[label] + spans.append(f'') + + for cs in char_states[1:]: + cur_anno_ix = cs.anno_ix + if cur_anno_ix != prev_anno_ix: + # If we've transitioned in or out of an annotation + spans.append( + # Create a span from the current consecutive characters + EncodingVisualizer.consecutive_chars_to_html( + current_consecutive_chars, + text=text, + encoding=encoding, + ) + ) + current_consecutive_chars = [cs] + + if prev_anno_ix is not None: + # if we transitioned out of an annotation close it's span + spans.append("") + if cur_anno_ix is not None: + # If we entered a new annotation make a span for it + anno = annotations[cur_anno_ix] + label = anno.label + color = label_colors_dict[label] + spans.append(f'') + prev_anno_ix = cur_anno_ix + + if cs.partition_key() == current_consecutive_chars[0].partition_key(): + # If the current charchter is in the same "group" as the previous one + current_consecutive_chars.append(cs) + else: + # Otherwise we make a span for the previous group + spans.append( + EncodingVisualizer.consecutive_chars_to_html( + current_consecutive_chars, + text=text, + encoding=encoding, + ) + ) + # An reset the consecutive_char_list to form a new group + current_consecutive_chars = [cs] + # All that's left is to fill out the final span + # TODO I think there is an edge case here where an annotation's span might not close + spans.append( + EncodingVisualizer.consecutive_chars_to_html( + current_consecutive_chars, + text=text, + encoding=encoding, + ) + ) + res = HTMLBody(spans) # Send the list of spans to the body of our html + return res + + @staticmethod + def __make_anno_map(text: str, annotations: AnnotationList) -> PartialIntList: + """ + Args: + text (:obj:`str`): + The raw text we want to align to + + annotations (:obj:`AnnotationList`): + A (possibly empty) list of annotations + + Returns: + A list of length len(text) whose entry at index i is None if there is no annotation on + character i or k, the index of the annotation that covers index i where k is with + respect to the list of annotations + """ + annotation_map = [None] * len(text) + for anno_ix, a in enumerate(annotations): + for i in range(a.start, a.end): + annotation_map[i] = anno_ix + return annotation_map + + @staticmethod + def __make_char_states(text: str, encoding: Encoding, annotations: AnnotationList) -> List[CharState]: + """ + For each character in the original text, we emit a tuple representing it's "state": + + * which token_ix it corresponds to + * which word_ix it corresponds to + * which annotation_ix it corresponds to + + Args: + text (:obj:`str`): + The raw text we want to align to + + annotations (:obj:`List[Annotation]`): + A (possibly empty) list of annotations + + encoding: (:class:`~tokenizers.Encoding`): + The encoding returned from the 
tokenizer + + Returns: + :obj:`List[CharState]`: A list of CharStates, indicating for each char in the text what + it's state is + """ + annotation_map = EncodingVisualizer.__make_anno_map(text, annotations) + # Todo make this a dataclass or named tuple + char_states: List[CharState] = [CharState(char_ix) for char_ix in range(len(text))] + for token_ix, token in enumerate(encoding.tokens): + offsets = encoding.token_to_chars(token_ix) + if offsets is not None: + start, end = offsets + for i in range(start, end): + char_states[i].tokens.append(token_ix) + for char_ix, anno_ix in enumerate(annotation_map): + char_states[char_ix].anno_ix = anno_ix + + return char_states + + +def HTMLBody(children: List[str], css_styles=css) -> str: + """ + Generates the full html with css from a list of html spans + + Args: + children (:obj:`List[str]`): + A list of strings, assumed to be html elements + + css_styles (:obj:`str`, `optional`): + Optional alternative implementation of the css + + Returns: + :obj:`str`: An HTML string with style markup + """ + children_text = "".join(children) + return f""" + + + + + +
+ {children_text} +
+ + + """ diff --git a/phivenv/Lib/site-packages/tokenizers/trainers/__init__.py b/phivenv/Lib/site-packages/tokenizers/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92d58beee3d4502a7c7df02d74a2cd3c1fa67714 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/trainers/__init__.py @@ -0,0 +1,8 @@ +# Generated content DO NOT EDIT +from .. import trainers + +Trainer = trainers.Trainer +BpeTrainer = trainers.BpeTrainer +UnigramTrainer = trainers.UnigramTrainer +WordLevelTrainer = trainers.WordLevelTrainer +WordPieceTrainer = trainers.WordPieceTrainer diff --git a/phivenv/Lib/site-packages/tokenizers/trainers/__init__.pyi b/phivenv/Lib/site-packages/tokenizers/trainers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2337fac99104cc14261111deccf05bad806181b3 --- /dev/null +++ b/phivenv/Lib/site-packages/tokenizers/trainers/__init__.pyi @@ -0,0 +1,156 @@ +# Generated content DO NOT EDIT +class Trainer: + """ + Base class for all trainers + + This class is not supposed to be instantiated directly. Instead, any implementation of a + Trainer will return an instance of this class when instantiated. + """ + +class BpeTrainer(Trainer): + """ + Trainer capable of training a BPE model + + Args: + vocab_size (:obj:`int`, `optional`): + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency (:obj:`int`, `optional`): + The minimum frequency a pair should have in order to be merged. + + show_progress (:obj:`bool`, `optional`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + + limit_alphabet (:obj:`int`, `optional`): + The maximum different characters to keep in the alphabet. + + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + + continuing_subword_prefix (:obj:`str`, `optional`): + A prefix to be used for every subword that is not a beginning-of-word. + + end_of_word_suffix (:obj:`str`, `optional`): + A suffix to be used for every subword that is a end-of-word. + + max_token_length (:obj:`int`, `optional`): + Prevents creating tokens longer than the specified size. + This can help with reducing polluting your vocabulary with + highly repetitive tokens like `======` for wikipedia + + """ + +class UnigramTrainer(Trainer): + """ + Trainer capable of training a Unigram model + + Args: + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + + show_progress (:obj:`bool`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`): + A list of special tokens the model should know of. + + initial_alphabet (:obj:`List[str]`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + + shrinking_factor (:obj:`float`): + The shrinking factor used at each step of the training to prune the + vocabulary. + + unk_token (:obj:`str`): + The token used for out-of-vocabulary tokens. + + max_piece_length (:obj:`int`): + The maximum length of a given token. + + n_sub_iterations (:obj:`int`): + The number of iterations of the EM algorithm to perform before + pruning the vocabulary. 
+ """ + def __init__( + self, + vocab_size=8000, + show_progress=True, + special_tokens=[], + shrinking_factor=0.75, + unk_token=None, + max_piece_length=16, + n_sub_iterations=2, + ): + pass + +class WordLevelTrainer(Trainer): + """ + Trainer capable of training a WorldLevel model + + Args: + vocab_size (:obj:`int`, `optional`): + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency (:obj:`int`, `optional`): + The minimum frequency a pair should have in order to be merged. + + show_progress (:obj:`bool`, `optional`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`): + A list of special tokens the model should know of. + """ + +class WordPieceTrainer(Trainer): + """ + Trainer capable of training a WordPiece model + + Args: + vocab_size (:obj:`int`, `optional`): + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency (:obj:`int`, `optional`): + The minimum frequency a pair should have in order to be merged. + + show_progress (:obj:`bool`, `optional`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + + limit_alphabet (:obj:`int`, `optional`): + The maximum different characters to keep in the alphabet. + + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + + continuing_subword_prefix (:obj:`str`, `optional`): + A prefix to be used for every subword that is not a beginning-of-word. + + end_of_word_suffix (:obj:`str`, `optional`): + A suffix to be used for every subword that is a end-of-word. + """ + def __init__( + self, + vocab_size=30000, + min_frequency=0, + show_progress=True, + special_tokens=[], + limit_alphabet=None, + initial_alphabet=[], + continuing_subword_prefix="##", + end_of_word_suffix=None, + ): + pass diff --git a/phivenv/Lib/site-packages/tokenizers/trainers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tokenizers/trainers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d097888936f69a188e62ee4e3a9b088fbe33c5 Binary files /dev/null and b/phivenv/Lib/site-packages/tokenizers/trainers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_C.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/torch/_C.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..78ffd85b9d7651dd32666ab50b52af5f9dcf52b8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_C.cp39-win_amd64.pyd differ diff --git a/phivenv/Lib/site-packages/torch/_VF.py b/phivenv/Lib/site-packages/torch/_VF.py new file mode 100644 index 0000000000000000000000000000000000000000..af55aff9a215f2b39f57d2cf48ae471343bda66c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_VF.py @@ -0,0 +1,31 @@ +""" +This makes the functions in torch._C._VariableFunctions available as + torch._VF. +without mypy being able to find them. 
+ +A subset of those functions are mapped to ATen functions in +torch/jit/_builtins.py + +See https://github.com/pytorch/pytorch/issues/21478 for the reason for +introducing torch._VF + +""" + +import sys +import types + +import torch + + +class VFModule(types.ModuleType): + vf: types.ModuleType + + def __init__(self, name: str): + super().__init__(name) + self.vf = torch._C._VariableFunctions + + def __getattr__(self, name: str) -> object: + return getattr(self.vf, name) + + +sys.modules[__name__] = VFModule(__name__) diff --git a/phivenv/Lib/site-packages/torch/_VF.pyi b/phivenv/Lib/site-packages/torch/_VF.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f7429c5df2fede845357a319b6a03ff038ed3a7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_VF.pyi @@ -0,0 +1,32949 @@ +# @generated by tools/pyi/gen_pyi.py from torch/_C/_VariableFunctions.pyi.in +# mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs +# ruff: noqa: F401,PYI054 + +from collections.abc import Sequence +from types import EllipsisType +from typing import Any, Callable, Literal, overload, TypeVar + +import torch +from torch import ( + contiguous_format, + Generator, + inf, + memory_format, + strided, + SymInt, + Tensor, +) +from torch._prims_common import DeviceLikeType +from torch.types import ( + _bool, + _complex, + _device, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Device, + Number, +) + +__all__ = [ + "__and__", + "__lshift__", + "__or__", + "__rshift__", + "__xor__", + "_adaptive_avg_pool2d", + "_adaptive_avg_pool3d", + "_add_batch_dim", + "_add_relu", + "_add_relu_", + "_addmm_activation", + "_aminmax", + "_amp_foreach_non_finite_check_and_unscale_", + "_amp_update_scale_", + "_assert_async", + "_assert_scalar", + "_assert_tensor_metadata", + "_batch_norm_impl_index", + "_cast_Byte", + "_cast_Char", + "_cast_Double", + "_cast_Float", + "_cast_Half", + "_cast_Int", + "_cast_Long", + "_cast_Short", + "_choose_qparams_per_tensor", + "_chunk_cat", + "_coalesce", + "_compute_linear_combination", + "_conj", + "_conj_copy", + "_conj_physical", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + "_convert_weight_to_int4pack", + "_convert_weight_to_int4pack_for_cpu", + "_convolution", + "_convolution_mode", + "_copy_from", + "_copy_from_and_resize", + "_cslt_compress", + "_cslt_sparse_mm", + "_cslt_sparse_mm_search", + "_ctc_loss", + "_cudnn_ctc_loss", + "_cudnn_init_dropout_state", + "_cudnn_rnn", + "_cudnn_rnn_flatten_weight", + "_cufft_clear_plan_cache", + "_cufft_get_plan_cache_max_size", + "_cufft_get_plan_cache_size", + "_cufft_set_plan_cache_max_size", + "_cummax_helper", + "_cummin_helper", + "_debug_has_internal_overlap", + "_dim_arange", + "_dirichlet_grad", + "_disable_functionalization", + "_dyn_quant_matmul_4bit", + "_dyn_quant_pack_4bit_weight", + "_efficientzerotensor", + "_embedding_bag", + "_embedding_bag_forward_only", + "_empty_affine_quantized", + "_empty_per_channel_affine_quantized", + "_enable_functionalization", + "_euclidean_dist", + "_fake_quantize_learnable_per_channel_affine", + "_fake_quantize_learnable_per_tensor_affine", + "_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "_fft_c2c", + "_fft_c2r", + "_fft_r2c", + "_fill_mem_eff_dropout_mask_", + "_foobar", + "_foreach_abs", + "_foreach_abs_", + "_foreach_acos", + "_foreach_acos_", + "_foreach_add", + "_foreach_add_", + "_foreach_addcdiv", + "_foreach_addcdiv_", + "_foreach_addcmul", + 
"_foreach_addcmul_", + "_foreach_asin", + "_foreach_asin_", + "_foreach_atan", + "_foreach_atan_", + "_foreach_ceil", + "_foreach_ceil_", + "_foreach_clamp_max", + "_foreach_clamp_max_", + "_foreach_clamp_min", + "_foreach_clamp_min_", + "_foreach_copy_", + "_foreach_cos", + "_foreach_cos_", + "_foreach_cosh", + "_foreach_cosh_", + "_foreach_div", + "_foreach_div_", + "_foreach_erf", + "_foreach_erf_", + "_foreach_erfc", + "_foreach_erfc_", + "_foreach_exp", + "_foreach_exp_", + "_foreach_expm1", + "_foreach_expm1_", + "_foreach_floor", + "_foreach_floor_", + "_foreach_frac", + "_foreach_frac_", + "_foreach_lerp", + "_foreach_lerp_", + "_foreach_lgamma", + "_foreach_lgamma_", + "_foreach_log", + "_foreach_log10", + "_foreach_log10_", + "_foreach_log1p", + "_foreach_log1p_", + "_foreach_log2", + "_foreach_log2_", + "_foreach_log_", + "_foreach_max", + "_foreach_maximum", + "_foreach_maximum_", + "_foreach_minimum", + "_foreach_minimum_", + "_foreach_mul", + "_foreach_mul_", + "_foreach_neg", + "_foreach_neg_", + "_foreach_norm", + "_foreach_pow", + "_foreach_pow_", + "_foreach_reciprocal", + "_foreach_reciprocal_", + "_foreach_round", + "_foreach_round_", + "_foreach_rsqrt", + "_foreach_rsqrt_", + "_foreach_sigmoid", + "_foreach_sigmoid_", + "_foreach_sign", + "_foreach_sign_", + "_foreach_sin", + "_foreach_sin_", + "_foreach_sinh", + "_foreach_sinh_", + "_foreach_sqrt", + "_foreach_sqrt_", + "_foreach_sub", + "_foreach_sub_", + "_foreach_tan", + "_foreach_tan_", + "_foreach_tanh", + "_foreach_tanh_", + "_foreach_trunc", + "_foreach_trunc_", + "_foreach_zero_", + "_from_functional_tensor", + "_functional_assert_async", + "_functional_assert_scalar", + "_functional_sym_constrain_range", + "_functional_sym_constrain_range_for_size", + "_functionalize_apply_view_metas", + "_functionalize_are_all_mutations_hidden_from_autograd", + "_functionalize_are_all_mutations_under_no_grad_or_inference_mode", + "_functionalize_commit_update", + "_functionalize_has_metadata_mutation", + "_functionalize_is_symbolic", + "_functionalize_mark_mutation_hidden_from_autograd", + "_functionalize_replace", + "_functionalize_set_storage_changed", + "_functionalize_sync", + "_functionalize_unsafe_set", + "_functionalize_was_inductor_storage_resized", + "_functionalize_was_storage_changed", + "_fused_adagrad_", + "_fused_adam_", + "_fused_adamw_", + "_fused_dropout", + "_fused_moving_avg_obs_fq_helper", + "_fused_moving_avg_obs_fq_helper", + "_fused_rms_norm", + "_fused_sdp_choice", + "_fused_sgd_", + "_fw_primal_copy", + "_grid_sampler_2d_cpu_fallback", + "_grouped_mm", + "_has_compatible_shallow_copy_type", + "_histogramdd_bin_edges", + "_histogramdd_from_bin_cts", + "_histogramdd_from_bin_tensors", + "_index_put_impl_", + "_indices_copy", + "_int_mm", + "_is_all_true", + "_is_any_true", + "_is_functional_tensor", + "_is_functional_tensor_base", + "_is_zerotensor", + "_lazy_clone", + "_linalg_check_errors", + "_linalg_det", + "_linalg_det", + "_linalg_eigh", + "_linalg_eigh", + "_linalg_slogdet", + "_linalg_slogdet", + "_linalg_solve_ex", + "_linalg_solve_ex", + "_linalg_svd", + "_linalg_svd", + "_log_softmax", + "_log_softmax_backward_data", + "_logcumsumexp", + "_lstm_mps", + "_lu_with_info", + "_lu_with_info", + "_make_dep_token", + "_make_dual", + "_make_dual_copy", + "_make_per_channel_quantized_tensor", + "_make_per_tensor_quantized_tensor", + "_masked_scale", + "_masked_softmax", + "_mixed_dtypes_linear", + "_mkldnn_reshape", + "_mkldnn_transpose", + "_mkldnn_transpose_", + "_mps_convolution", + 
"_mps_convolution_transpose", + "_native_batch_norm_legit", + "_native_batch_norm_legit_no_training", + "_native_multi_head_attention", + "_neg_view", + "_neg_view_copy", + "_nested_compute_contiguous_strides_offsets", + "_nested_from_padded", + "_nested_from_padded_and_nested_example", + "_nested_from_padded_tensor", + "_nested_get_jagged_dummy", + "_nested_get_lengths", + "_nested_get_max_seqlen", + "_nested_get_min_seqlen", + "_nested_get_offsets", + "_nested_get_ragged_idx", + "_nested_get_values", + "_nested_get_values_copy", + "_nested_tensor_from_mask", + "_nested_tensor_from_mask_left_aligned", + "_nested_tensor_from_tensor_list", + "_nested_tensor_softmax_with_shape", + "_nested_view_from_buffer", + "_nested_view_from_buffer_copy", + "_nested_view_from_jagged", + "_nested_view_from_jagged_copy", + "_nnpack_available", + "_nnpack_spatial_convolution", + "_pack_padded_sequence", + "_pad_packed_sequence", + "_pin_memory", + "_prelu_kernel", + "_print", + "_propagate_xla_data", + "_remove_batch_dim", + "_reshape_alias_copy", + "_reshape_from_tensor", + "_resize_output_", + "_rowwise_prune", + "_safe_softmax", + "_sample_dirichlet", + "_saturate_weight_to_fp16", + "_scaled_dot_product_attention_math", + "_scaled_dot_product_attention_math_for_mps", + "_scaled_dot_product_cudnn_attention", + "_scaled_dot_product_cudnn_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_flash_attention_for_cpu", + "_scaled_dot_product_flash_attention_for_cpu", + "_scaled_grouped_mm", + "_scaled_mm", + "_shape_as_tensor", + "_sobol_engine_draw", + "_sobol_engine_ff_", + "_sobol_engine_initialize_state_", + "_sobol_engine_scramble_", + "_softmax", + "_softmax_backward_data", + "_sparse_broadcast_to", + "_sparse_broadcast_to_copy", + "_sparse_csr_prod", + "_sparse_csr_sum", + "_sparse_log_softmax_backward_data", + "_sparse_semi_structured_addmm", + "_sparse_semi_structured_apply", + "_sparse_semi_structured_apply_dense", + "_sparse_semi_structured_linear", + "_sparse_semi_structured_mm", + "_sparse_semi_structured_tile", + "_sparse_softmax_backward_data", + "_sparse_sparse_matmul", + "_sparse_sum", + "_stack", + "_standard_gamma", + "_standard_gamma_grad", + "_sync", + "_test_autograd_multiple_dispatch", + "_test_autograd_multiple_dispatch_view", + "_test_autograd_multiple_dispatch_view_copy", + "_test_check_tensor", + "_test_functorch_fallback", + "_test_parallel_materialize", + "_test_serialization_subcmul", + "_to_cpu", + "_to_functional_tensor", + "_to_sparse_semi_structured", + "_transform_bias_rescale_qkv", + "_transformer_encoder_layer_fwd", + "_trilinear", + "_triton_multi_head_attention", + "_triton_scaled_dot_attention", + "_unique", + "_unique2", + "_unpack_dual", + "_unpack_dual", + "_unsafe_index", + "_unsafe_index_put", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", + "_use_cudnn_ctc_loss", + "_use_cudnn_rnn_flatten_weight", + "_validate_compressed_sparse_indices", + "_validate_sparse_bsc_tensor_args", + "_validate_sparse_bsr_tensor_args", + "_validate_sparse_compressed_tensor_args", + "_validate_sparse_coo_tensor_args", + "_validate_sparse_csc_tensor_args", + "_validate_sparse_csr_tensor_args", + "_values_copy", + "_weight_int4pack_mm", + "_weight_int4pack_mm_for_cpu", + "_weight_int4pack_mm_with_scales_and_zeros", + "_weight_int8pack_mm", + "_weight_norm", + "_weight_norm_interface", + "_wrapped_linear_prepack", + 
"_wrapped_quantized_linear_prepacked", + "abs", + "abs_", + "absolute", + "acos", + "acos_", + "acosh", + "acosh_", + "adaptive_avg_pool1d", + "adaptive_max_pool1d", + "add", + "addbmm", + "addcdiv", + "addcmul", + "addmm", + "addmv", + "addmv_", + "addr", + "adjoint", + "affine_grid_generator", + "alias_copy", + "all", + "allclose", + "alpha_dropout", + "alpha_dropout_", + "amax", + "amin", + "aminmax", + "aminmax", + "angle", + "any", + "arange", + "arccos", + "arccos_", + "arccosh", + "arccosh_", + "arcsin", + "arcsin_", + "arcsinh", + "arcsinh_", + "arctan", + "arctan2", + "arctan_", + "arctanh", + "arctanh_", + "argmax", + "argmin", + "argsort", + "argwhere", + "as_strided", + "as_strided_", + "as_strided_copy", + "as_strided_scatter", + "as_tensor", + "asarray", + "asin", + "asin_", + "asinh", + "asinh_", + "atan", + "atan2", + "atan_", + "atanh", + "atanh_", + "avg_pool1d", + "baddbmm", + "bartlett_window", + "batch_norm", + "batch_norm_backward_elemt", + "batch_norm_backward_reduce", + "batch_norm_elemt", + "batch_norm_gather_stats", + "batch_norm_gather_stats_with_counts", + "batch_norm_stats", + "batch_norm_update_stats", + "bernoulli", + "bilinear", + "binary_cross_entropy_with_logits", + "bincount", + "binomial", + "bitwise_and", + "bitwise_left_shift", + "bitwise_not", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "blackman_window", + "bmm", + "broadcast_to", + "bucketize", + "can_cast", + "cat", + "ccol_indices_copy", + "ceil", + "ceil_", + "celu", + "celu_", + "channel_shuffle", + "cholesky", + "cholesky_inverse", + "cholesky_solve", + "choose_qparams_optimized", + "chunk", + "clamp", + "clamp_", + "clamp_max", + "clamp_max_", + "clamp_min", + "clamp_min_", + "clip", + "clip_", + "clone", + "col_indices_copy", + "column_stack", + "combinations", + "complex", + "concat", + "concatenate", + "conj", + "conj_physical", + "conj_physical_", + "constant_pad_nd", + "conv1d", + "conv2d", + "conv3d", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "convolution", + "copysign", + "corrcoef", + "cos", + "cos_", + "cosh", + "cosh_", + "cosine_embedding_loss", + "cosine_similarity", + "count_nonzero", + "cov", + "cross", + "crow_indices_copy", + "ctc_loss", + "cudnn_affine_grid_generator", + "cudnn_batch_norm", + "cudnn_convolution", + "cudnn_convolution_add_relu", + "cudnn_convolution_relu", + "cudnn_convolution_transpose", + "cudnn_grid_sampler", + "cudnn_is_acceptable", + "cummax", + "cummax", + "cummin", + "cummin", + "cumprod", + "cumsum", + "cumulative_trapezoid", + "deg2rad", + "deg2rad_", + "dequantize", + "det", + "detach", + "detach_", + "detach_copy", + "diag", + "diag_embed", + "diagflat", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "diff", + "digamma", + "dist", + "div", + "divide", + "dot", + "dropout", + "dropout_", + "dsmm", + "dsplit", + "dstack", + "embedding", + "embedding_bag", + "embedding_renorm_", + "empty", + "empty_like", + "empty_permuted", + "empty_quantized", + "empty_strided", + "eq", + "equal", + "erf", + "erf_", + "erfc", + "erfc_", + "erfinv", + "exp", + "exp2", + "exp2_", + "exp_", + "expand_copy", + "expm1", + "expm1_", + "eye", + "fake_quantize_per_channel_affine", + "fake_quantize_per_tensor_affine", + "fbgemm_linear_fp16_weight", + "fbgemm_linear_fp16_weight_fp32_activation", + "fbgemm_linear_int8_weight", + "fbgemm_linear_int8_weight_fp32_activation", + "fbgemm_linear_quantize_weight", + "fbgemm_pack_gemm_matrix_fp16", + "fbgemm_pack_quantized_matrix", + "feature_alpha_dropout", + 
"feature_alpha_dropout_", + "feature_dropout", + "feature_dropout_", + "fill", + "fill_", + "fix", + "fix_", + "flatten", + "flip", + "fliplr", + "flipud", + "float_power", + "floor", + "floor_", + "floor_divide", + "fmax", + "fmin", + "fmod", + "frac", + "frac_", + "frexp", + "frexp", + "frobenius_norm", + "from_file", + "from_numpy", + "frombuffer", + "full", + "full_like", + "fused_moving_avg_obs_fake_quant", + "gather", + "gcd", + "gcd_", + "ge", + "geqrf", + "geqrf", + "ger", + "get_default_dtype", + "get_num_interop_threads", + "get_num_threads", + "gradient", + "greater", + "greater_equal", + "grid_sampler", + "grid_sampler_2d", + "grid_sampler_3d", + "group_norm", + "gru", + "gru_cell", + "gt", + "hamming_window", + "hann_window", + "hardshrink", + "heaviside", + "hinge_embedding_loss", + "histc", + "histogram", + "histogram", + "histogramdd", + "histogramdd", + "hsmm", + "hsplit", + "hspmm", + "hstack", + "hypot", + "i0", + "i0_", + "igamma", + "igammac", + "imag", + "index_add", + "index_copy", + "index_fill", + "index_put", + "index_put_", + "index_reduce", + "index_select", + "indices_copy", + "init_num_threads", + "inner", + "instance_norm", + "int_repr", + "inverse", + "is_complex", + "is_conj", + "is_distributed", + "is_floating_point", + "is_grad_enabled", + "is_inference", + "is_inference_mode_enabled", + "is_neg", + "is_nonzero", + "is_same_size", + "is_signed", + "is_vulkan_available", + "isclose", + "isfinite", + "isin", + "isinf", + "isnan", + "isneginf", + "isposinf", + "isreal", + "istft", + "kaiser_window", + "kl_div", + "kron", + "kthvalue", + "kthvalue", + "layer_norm", + "lcm", + "lcm_", + "ldexp", + "ldexp_", + "le", + "lerp", + "less", + "less_equal", + "lgamma", + "linspace", + "log", + "log10", + "log10_", + "log1p", + "log1p_", + "log2", + "log2_", + "log_", + "log_softmax", + "logaddexp", + "logaddexp2", + "logcumsumexp", + "logdet", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "logit_", + "logspace", + "logsumexp", + "lstm", + "lstm_cell", + "lt", + "lu_solve", + "lu_unpack", + "lu_unpack", + "margin_ranking_loss", + "masked_fill", + "masked_scatter", + "masked_select", + "matmul", + "matrix_exp", + "matrix_power", + "max", + "max", + "max_pool1d", + "max_pool1d_with_indices", + "max_pool2d", + "max_pool3d", + "maximum", + "mean", + "median", + "median", + "min", + "min", + "minimum", + "miopen_batch_norm", + "miopen_convolution", + "miopen_convolution_add_relu", + "miopen_convolution_relu", + "miopen_convolution_transpose", + "miopen_depthwise_convolution", + "miopen_rnn", + "mkldnn_adaptive_avg_pool2d", + "mkldnn_convolution", + "mkldnn_linear_backward_weights", + "mkldnn_max_pool2d", + "mkldnn_max_pool3d", + "mkldnn_rnn_layer", + "mm", + "mode", + "mode", + "moveaxis", + "movedim", + "msort", + "mul", + "multinomial", + "multiply", + "mv", + "mvlgamma", + "nan_to_num", + "nan_to_num_", + "nanmean", + "nanmedian", + "nanmedian", + "nanquantile", + "nansum", + "narrow", + "narrow_copy", + "native_batch_norm", + "native_channel_shuffle", + "native_dropout", + "native_group_norm", + "native_layer_norm", + "native_norm", + "ne", + "neg", + "neg_", + "negative", + "negative_", + "nextafter", + "nonzero", + "nonzero_static", + "norm_except_dim", + "normal", + "not_equal", + "nuclear_norm", + "numel", + "ones", + "ones_like", + "orgqr", + "ormqr", + "outer", + "pairwise_distance", + "pdist", + "permute", + "permute_copy", + "pinverse", + "pixel_shuffle", + "pixel_unshuffle", + "poisson", + "poisson_nll_loss", + "polar", + 
"polygamma", + "positive", + "pow", + "prelu", + "prod", + "promote_types", + "put", + "q_per_channel_axis", + "q_per_channel_scales", + "q_per_channel_zero_points", + "q_scale", + "q_zero_point", + "qr", + "qr", + "quantile", + "quantize_per_channel", + "quantize_per_tensor", + "quantize_per_tensor_dynamic", + "quantized_batch_norm", + "quantized_gru_cell", + "quantized_lstm_cell", + "quantized_max_pool1d", + "quantized_max_pool2d", + "quantized_max_pool3d", + "quantized_rnn_relu_cell", + "quantized_rnn_tanh_cell", + "rad2deg", + "rad2deg_", + "rand", + "rand_like", + "randint", + "randint_like", + "randn", + "randn_like", + "randperm", + "range", + "ravel", + "real", + "reciprocal", + "reciprocal_", + "relu", + "relu_", + "remainder", + "renorm", + "repeat_interleave", + "reshape", + "resize_as_", + "resize_as_sparse_", + "resolve_conj", + "resolve_neg", + "result_type", + "rms_norm", + "rnn_relu", + "rnn_relu_cell", + "rnn_tanh", + "rnn_tanh_cell", + "roll", + "rot90", + "round", + "round_", + "row_indices_copy", + "row_stack", + "rrelu", + "rrelu_", + "rsqrt", + "rsqrt_", + "rsub", + "saddmm", + "scalar_tensor", + "scatter", + "scatter_add", + "scatter_reduce", + "searchsorted", + "segment_reduce", + "select", + "select_copy", + "select_scatter", + "selu", + "selu_", + "set_flush_denormal", + "set_num_interop_threads", + "set_num_threads", + "sgn", + "sigmoid", + "sigmoid_", + "sign", + "signbit", + "sin", + "sin_", + "sinc", + "sinc_", + "sinh", + "sinh_", + "slice_copy", + "slice_inverse", + "slice_scatter", + "slogdet", + "slogdet", + "smm", + "softmax", + "sort", + "sort", + "sparse_bsc_tensor", + "sparse_bsr_tensor", + "sparse_compressed_tensor", + "sparse_coo_tensor", + "sparse_csc_tensor", + "sparse_csr_tensor", + "split_copy", + "split_with_sizes", + "split_with_sizes_copy", + "spmm", + "sqrt", + "sqrt_", + "square", + "square_", + "squeeze", + "squeeze_copy", + "sspaddmm", + "stack", + "std", + "std_mean", + "sub", + "subtract", + "sum", + "svd", + "svd", + "swapaxes", + "swapdims", + "sym_constrain_range", + "sym_constrain_range_for_size", + "t", + "t_copy", + "take", + "take_along_dim", + "tan", + "tan_", + "tanh", + "tanh_", + "tensor", + "tensor_split", + "threshold", + "threshold_", + "tile", + "topk", + "topk", + "trace", + "transpose", + "transpose_copy", + "trapezoid", + "trapz", + "triangular_solve", + "triangular_solve", + "tril", + "tril_indices", + "triplet_margin_loss", + "triu", + "triu_indices", + "true_divide", + "trunc", + "trunc_", + "unbind", + "unbind_copy", + "unflatten", + "unfold_copy", + "unique_dim", + "unsafe_chunk", + "unsafe_split", + "unsafe_split_with_sizes", + "unsqueeze", + "unsqueeze_copy", + "values_copy", + "vander", + "var", + "var_mean", + "vdot", + "view_as_complex", + "view_as_complex_copy", + "view_as_real", + "view_as_real_copy", + "view_copy", + "vsplit", + "vstack", + "where", + "xlogy", + "xlogy_", + "zero_", + "zeros", + "zeros_like", +] + +@overload +def __and__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __and__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __rshift__(input: Tensor, other: Tensor) -> Tensor: ... 
+@overload +def __rshift__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Number | _complex) -> Tensor: ... +def _adaptive_avg_pool2d( + input: Tensor, + output_size: _int | SymInt | Sequence[_int | SymInt], +) -> Tensor: ... +def _adaptive_avg_pool3d( + input: Tensor, + output_size: _int | SymInt | Sequence[_int | SymInt], +) -> Tensor: ... +def _add_batch_dim(input: Tensor, batch_dim: _int, level: _int) -> Tensor: ... +@overload +def _add_relu( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def _add_relu( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def _add_relu_( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def _add_relu_( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: ... +def _addmm_activation( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + use_gelu: _bool = False, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def _aminmax(input: Tensor) -> tuple[Tensor, Tensor]: ... +@overload +def _aminmax( + input: Tensor, + dim: _int, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def _amp_foreach_non_finite_check_and_unscale_( + self: tuple[Tensor, ...] | list[Tensor] | None, + found_inf: Tensor, + inv_scale: Tensor, +) -> None: ... +def _amp_update_scale_( + input: Tensor, + growth_tracker: Tensor, + found_inf: Tensor, + scale_growth_factor: _float, + scale_backoff_factor: _float, + growth_interval: _int, +) -> Tensor: ... +@overload +def _assert_async(input: Tensor) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + +@overload +def _assert_async(input: Tensor, assert_msg: str) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. 
+ """ + +def _assert_scalar(self: Number | _complex, assert_msg: str) -> None: ... +def _assert_tensor_metadata( + a: Tensor, + size: Sequence[_int | SymInt] | None = None, + stride: Sequence[_int | SymInt] | None = None, + dtype: _dtype | None = None, + *, + device: DeviceLikeType | None = None, + layout: _layout | None = None, +) -> None: ... +def _batch_norm_impl_index( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + cudnn_enabled: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor, _int]: ... +def _cast_Byte(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Char(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Double(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Float(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Half(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Int(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Long(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Short(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _choose_qparams_per_tensor( + input: Tensor, + reduce_range: _bool = False, +) -> tuple[_float, _int]: ... +def _chunk_cat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int, + num_chunks: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _coalesce(input: Tensor) -> Tensor: ... +def _compute_linear_combination( + input: Tensor, + coefficients: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _conj(input: Tensor) -> Tensor: ... +def _conj_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _conj_physical(input: Tensor) -> Tensor: ... +def _convert_indices_from_coo_to_csr( + input: Tensor, + size: _int, + *, + out_int32: _bool = False, + out: Tensor | None = None, +) -> Tensor: ... +def _convert_indices_from_csr_to_coo( + crow_indices: Tensor, + col_indices: Tensor, + *, + out_int32: _bool = False, + transpose: _bool = False, + out: Tensor | None = None, +) -> Tensor: ... +def _convert_weight_to_int4pack(input: Tensor, innerKTiles: _int) -> Tensor: ... +def _convert_weight_to_int4pack_for_cpu( + input: Tensor, + innerKTiles: _int, +) -> Tensor: ... +@overload +def _convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + transposed: _bool, + output_padding: _size, + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + cudnn_enabled: _bool, +) -> Tensor: ... +@overload +def _convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + transposed: _bool, + output_padding: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + cudnn_enabled: _bool, + allow_tf32: _bool, +) -> Tensor: ... +def _convolution_mode( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: str, + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def _copy_from( + input: Tensor, + dst: Tensor, + non_blocking: _bool = False, +) -> Tensor: ... +def _copy_from_and_resize(input: Tensor, dst: Tensor) -> Tensor: ... +def _cslt_compress(input: Tensor) -> Tensor: ... 
+def _cslt_sparse_mm( + compressed_A: Tensor, + dense_B: Tensor, + bias: Tensor | None = None, + alpha: Tensor | None = None, + out_dtype: _dtype | None = None, + transpose_result: _bool = False, + alg_id: _int = 0, + split_k: _int = 1, + split_k_mode: _int = -1, +) -> Tensor: ... +def _cslt_sparse_mm_search( + compressed_A: Tensor, + dense_B: Tensor, + bias: Tensor | None = None, + alpha: Tensor | None = None, + out_dtype: _dtype | None = None, + transpose_result: _bool = False, +) -> _int: ... +@overload +def _ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int = 0, + zero_infinity: _bool = False, +) -> tuple[Tensor, Tensor]: ... +@overload +def _ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int = 0, + zero_infinity: _bool = False, +) -> tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int, + deterministic: _bool, + zero_infinity: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int, + deterministic: _bool, + zero_infinity: _bool, +) -> tuple[Tensor, Tensor]: ... +def _cudnn_init_dropout_state( + dropout: _float, + train: _bool, + dropout_seed: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _cudnn_rnn( + input: Tensor, + weight: tuple[Tensor, ...] | list[Tensor] | None, + weight_stride0: _int, + weight_buf: Tensor | None, + hx: Tensor, + cx: Tensor | None, + mode: _int, + hidden_size: _int | SymInt, + proj_size: _int | SymInt, + num_layers: _int, + batch_first: _bool, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_sizes: Sequence[_int | SymInt], + dropout_state: Tensor | None, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _cudnn_rnn_flatten_weight( + weight_arr: tuple[Tensor, ...] | list[Tensor] | None, + weight_stride0: _int, + input_size: _int | SymInt, + mode: _int, + hidden_size: _int | SymInt, + proj_size: _int | SymInt, + num_layers: _int, + batch_first: _bool, + bidirectional: _bool, +) -> Tensor: ... +def _cufft_clear_plan_cache(device_index: _int) -> None: ... +def _cufft_get_plan_cache_max_size(device_index: _int) -> _int: ... +def _cufft_get_plan_cache_size(device_index: _int) -> _int: ... +def _cufft_set_plan_cache_max_size( + device_index: _int, + max_size: _int, +) -> None: ... +def _cummax_helper( + input: Tensor, + values: Tensor, + indices: Tensor, + dim: _int, +) -> None: ... +def _cummin_helper( + input: Tensor, + values: Tensor, + indices: Tensor, + dim: _int, +) -> None: ... +def _debug_has_internal_overlap(input: Tensor) -> _int: ... +def _dim_arange(like: Tensor, dim: _int) -> Tensor: ... +def _dirichlet_grad(x: Tensor, alpha: Tensor, total: Tensor) -> Tensor: ... +def _disable_functionalization(): ... +def _dyn_quant_matmul_4bit( + inp: Tensor, + packed_weights: Tensor, + block_size: _int, + in_features: _int, + out_features: _int, +) -> Tensor: ... +def _dyn_quant_pack_4bit_weight( + weights: Tensor, + scales_zeros: Tensor, + bias: Tensor | None, + block_size: _int, + in_features: _int, + out_features: _int, +) -> Tensor: ... 
+@overload +def _efficientzerotensor( + size: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _efficientzerotensor( + *size: _int | SymInt, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _embedding_bag( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool = False, + mode: _int = 0, + sparse: _bool = False, + per_sample_weights: Tensor | None = None, + include_last_offset: _bool = False, + padding_idx: _int = -1, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def _embedding_bag_forward_only( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool = False, + mode: _int = 0, + sparse: _bool = False, + per_sample_weights: Tensor | None = None, + include_last_offset: _bool = False, + padding_idx: _int = -1, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def _empty_affine_quantized( + size: Sequence[_int | SymInt], + *, + scale: _float = 1, + zero_point: _int = 0, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _empty_affine_quantized( + *size: _int | SymInt, + scale: _float = 1, + zero_point: _int = 0, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized( + size: Sequence[_int | SymInt], + *, + scales: Tensor, + zero_points: Tensor, + axis: _int, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized( + *size: _int | SymInt, + scales: Tensor, + zero_points: Tensor, + axis: _int, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _enable_functionalization(*, reapply_views: _bool = False) -> None: ... +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: ... +def _fake_quantize_learnable_per_channel_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + axis: _int, + quant_min: _int, + quant_max: _int, + grad_factor: _float = 1.0, +) -> Tensor: ... +def _fake_quantize_learnable_per_tensor_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + quant_min: _int, + quant_max: _int, + grad_factor: _float = 1.0, +) -> Tensor: ... +def _fake_quantize_per_tensor_affine_cachemask_tensor_qparams( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + fake_quant_enabled: Tensor, + quant_min: _int, + quant_max: _int, +) -> torch.return_types._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: # fmt: skip + ... 
+def _fft_c2c( + input: Tensor, + dim: Sequence[_int | SymInt], + normalization: _int, + forward: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _fft_c2r( + input: Tensor, + dim: _size, + normalization: _int, + last_dim_size: _int | SymInt, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _fft_r2c( + input: Tensor, + dim: _size, + normalization: _int, + onesided: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _fill_mem_eff_dropout_mask_( + input: Tensor, + dropout_p: _float, + seed: _int, + offset: _int, +) -> Tensor: ... +def _foobar( + input: Tensor, + arg1: _bool = True, + arg2: _bool = True, + *, + arg3: _bool = True, +) -> Tensor: ... +def _foreach_abs( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_abs(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + +def _foreach_abs_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_abs_(self: List[Tensor]) -> None + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + +def _foreach_acos( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_acos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + +def _foreach_acos_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_acos_(self: List[Tensor]) -> None + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_addcdiv( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... 
+@overload +def _foreach_addcdiv_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_addcdiv_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> None: ... +@overload +def _foreach_addcdiv_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_addcmul( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_addcmul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> None: ... +@overload +def _foreach_addcmul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> None: ... +def _foreach_asin( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_asin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + +def _foreach_asin_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_asin_(self: List[Tensor]) -> None + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + +def _foreach_atan( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_atan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + +def _foreach_atan_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_atan_(self: List[Tensor]) -> None + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + +def _foreach_ceil( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_ceil(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + +def _foreach_ceil_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_ceil_(self: List[Tensor]) -> None + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + +@overload +def _foreach_clamp_max( + self: tuple[Tensor, ...] 
| list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_clamp_max_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_clamp_max_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +@overload +def _foreach_clamp_min( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_clamp_min_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_clamp_min_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_copy_( + self: tuple[Tensor, ...] | list[Tensor] | None, + src: tuple[Tensor, ...] | list[Tensor] | None, + non_blocking: _bool = False, +) -> None: ... +def _foreach_cos( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_cos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + +def _foreach_cos_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_cos_(self: List[Tensor]) -> None + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + +def _foreach_cosh( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_cosh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + +def _foreach_cosh_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_cosh_(self: List[Tensor]) -> None + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] 
| list[Tensor] | None, + other: Tensor, +) -> None: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_erf( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_erf(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + +def _foreach_erf_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_erf_(self: List[Tensor]) -> None + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + +def _foreach_erfc( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_erfc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erfc` to each Tensor of the input list. + """ + +def _foreach_erfc_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_erfc_(self: List[Tensor]) -> None + + Apply :func:`torch.erfc` to each Tensor of the input list. + """ + +def _foreach_exp( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_exp(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + +def _foreach_exp_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_exp_(self: List[Tensor]) -> None + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + +def _foreach_expm1( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_expm1(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + +def _foreach_expm1_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_expm1_(self: List[Tensor]) -> None + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + +def _foreach_floor( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_floor(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + +def _foreach_floor_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_floor_(self: List[Tensor]) -> None + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + +def _foreach_frac( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_frac(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + +def _foreach_frac_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_frac_(self: List[Tensor]) -> None + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + +@overload +def _foreach_lerp( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_lerp( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_lerp( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weights: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... 
+@overload +def _foreach_lerp_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Number | _complex, +) -> None: ... +@overload +def _foreach_lerp_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_lerp_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weights: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_lgamma( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_lgamma(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + +def _foreach_lgamma_( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: + r""" + _foreach_lgamma_(self: List[Tensor]) -> None + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + +def _foreach_log( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log` to each Tensor of the input list. + """ + +def _foreach_log10( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log10(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + +def _foreach_log10_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log10_(self: List[Tensor]) -> None + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + +def _foreach_log1p( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log1p(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + +def _foreach_log1p_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log1p_(self: List[Tensor]) -> None + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + +def _foreach_log2( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log2(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + +def _foreach_log2_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log2_(self: List[Tensor]) -> None + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + +def _foreach_log_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log_(self: List[Tensor]) -> None + + Apply :func:`torch.log` to each Tensor of the input list. + """ + +def _foreach_max( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_maximum_( + self: tuple[Tensor, ...] 
| list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_maximum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +@overload +def _foreach_minimum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_minimum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_minimum( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_minimum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_minimum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_minimum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> None: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_neg( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_neg(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + +def _foreach_neg_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_neg_(self: List[Tensor]) -> None + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + +def _foreach_norm( + self: tuple[Tensor, ...] | list[Tensor] | None, + ord: Number | _complex = 2, + dtype: _dtype | None = None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: Number | _complex, + exponent: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow_( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Sequence[Number | _complex], +) -> None: ... 
+@overload +def _foreach_pow_( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Number | _complex, +) -> None: ... +@overload +def _foreach_pow_( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_reciprocal( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_reciprocal(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + +def _foreach_reciprocal_( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: + r""" + _foreach_reciprocal_(self: List[Tensor]) -> None + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + +def _foreach_round( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_round(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.round` to each Tensor of the input list. + """ + +def _foreach_round_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_round_(self: List[Tensor]) -> None + + Apply :func:`torch.round` to each Tensor of the input list. + """ + +def _foreach_rsqrt( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +def _foreach_rsqrt_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: ... +def _foreach_sigmoid( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sigmoid(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + +def _foreach_sigmoid_( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: + r""" + _foreach_sigmoid_(self: List[Tensor]) -> None + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + +def _foreach_sign( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +def _foreach_sign_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: ... +def _foreach_sin( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + +def _foreach_sin_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_sin_(self: List[Tensor]) -> None + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + +def _foreach_sinh( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sinh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + +def _foreach_sinh_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_sinh_(self: List[Tensor]) -> None + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + +def _foreach_sqrt( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sqrt(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + +def _foreach_sqrt_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_sqrt_(self: List[Tensor]) -> None + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + +@overload +def _foreach_sub( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_sub( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] 
| list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_sub( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_sub_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_sub_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_sub_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +def _foreach_tan( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_tan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + +def _foreach_tan_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_tan_(self: List[Tensor]) -> None + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + +def _foreach_tanh( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_tanh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + +def _foreach_tanh_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_tanh_(self: List[Tensor]) -> None + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + +def _foreach_trunc( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_trunc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + +def _foreach_trunc_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_trunc_(self: List[Tensor]) -> None + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + +def _foreach_zero_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_zero_(self: List[Tensor]) -> None + + Apply :func:`torch.zero` to each Tensor of the input list. + """ + +def _from_functional_tensor(t: Tensor) -> Tensor: ... +def _functional_assert_async( + input: Tensor, + assert_msg: str, + dep_token: Tensor, +) -> Tensor: ... +def _functional_assert_scalar( + self: Number | _complex, + assert_msg: str, + dep_token: Tensor, +) -> Tensor: ... +def _functional_sym_constrain_range( + size: Number | _complex, + min: _int | None, + max: _int | None, + dep_token: Tensor, +) -> Tensor: ... +def _functional_sym_constrain_range_for_size( + size: Number | _complex, + min: _int | None, + max: _int | None, + dep_token: Tensor, +) -> Tensor: ... +def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ... +def _functionalize_are_all_mutations_hidden_from_autograd( + t: Tensor, +) -> _bool: ... +def _functionalize_are_all_mutations_under_no_grad_or_inference_mode( + t: Tensor, +) -> _bool: ... +def _functionalize_commit_update(t: Tensor) -> None: ... +def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ... +def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ... +def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ... +def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ... +def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ... +def _functionalize_sync(t: Tensor) -> None: ... +def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ... 
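+# Illustrative sketch: the in-place variants ending in "_" above are what optimizer and
+# gradient utilities can reach for when touching many parameters at once; `model` below
+# is a placeholder for any nn.Module whose parameters have populated .grad fields:
+#
+#     >>> grads = [p.grad for p in model.parameters() if p.grad is not None]
+#     >>> torch._foreach_zero_(grads)       # ~ for g in grads: g.zero_()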
+def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ... +def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ... +@overload +def _fused_adagrad_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + state_sums: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: Tensor, + lr_decay: _float, + weight_decay: _float, + eps: _float, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adagrad_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + state_sums: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: _float, + lr_decay: _float, + weight_decay: _float, + eps: _float, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adam_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: Tensor, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adam_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: _float, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adamw_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: Tensor, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adamw_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: _float, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +def _fused_dropout( + input: Tensor, + p: _float, + generator: Generator | None = None, +) -> tuple[Tensor, Tensor]: ... 
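+# Hedged sketch: the _fused_adagrad_/_fused_adam_/_fused_adamw_ entries above are the
+# single-kernel backends behind the public optimizers' fused mode and are normally
+# selected indirectly; `model` and `loss` are placeholders, and fused=True assumes
+# parameters on a supported device such as CUDA:
+#
+#     >>> opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
+#     >>> loss.backward()
+#     >>> opt.step()                        # dispatches to torch._fused_adamw_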
+def _fused_moving_avg_obs_fq_helper( + input: Tensor, + observer_on: Tensor, + fake_quant_on: Tensor, + running_min: Tensor, + running_max: Tensor, + scale: Tensor, + zero_point: Tensor, + averaging_const: _float, + quant_min: _int, + quant_max: _int, + ch_axis: _int, + per_row_fake_quant: _bool = False, + symmetric_quant: _bool = False, +) -> torch.return_types._fused_moving_avg_obs_fq_helper: ... +def _fused_rms_norm( + input: Tensor, + normalized_shape_ndim: _int, + weight: Tensor, + eps: _float, +) -> Tensor: ... +def _fused_sdp_choice( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: _float = 0.0, + is_causal: _bool = False, + *, + scale: _float | None = None, + enable_gqa: _bool = False, +) -> _int: ... +@overload +def _fused_sgd_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + momentum_buffer_list: tuple[Tensor, ...] | list[Tensor] | None, + *, + weight_decay: _float, + momentum: _float, + lr: Tensor, + dampening: _float, + nesterov: _bool, + maximize: _bool, + is_first_step: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_sgd_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + momentum_buffer_list: tuple[Tensor, ...] | list[Tensor] | None, + *, + weight_decay: _float, + momentum: _float, + lr: _float, + dampening: _float, + nesterov: _bool, + maximize: _bool, + is_first_step: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +def _fw_primal_copy( + input: Tensor, + level: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _grid_sampler_2d_cpu_fallback( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... +def _grouped_mm( + input: Tensor, + mat2: Tensor, + offs: Tensor | None = None, + bias: Tensor | None = None, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _has_compatible_shallow_copy_type( + input: Tensor, + from_: Tensor, +) -> _bool: ... +def _histogramdd_bin_edges( + input: Tensor, + bins: _size, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> tuple[Tensor, ...]: ... +def _histogramdd_from_bin_cts( + input: Tensor, + bins: _size, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> Tensor: ... +def _histogramdd_from_bin_tensors( + input: Tensor, + bins: tuple[Tensor, ...] | list[Tensor] | None, + *, + weight: Tensor | None = None, + density: _bool = False, +) -> Tensor: ... +def _index_put_impl_( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, + unsafe: _bool = False, +) -> Tensor: ... +def _indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _int_mm( + input: Tensor, + mat2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _is_all_true(input: Tensor) -> Tensor: ... +def _is_any_true(input: Tensor) -> Tensor: ... +def _is_functional_tensor(t: Tensor) -> _bool: ... +def _is_functional_tensor_base(t: Tensor) -> _bool: ... +def _is_zerotensor(input: Tensor) -> _bool: ... +def _lazy_clone(input: Tensor) -> Tensor: ... +def _linalg_check_errors( + info: Tensor, + api_name: str, + *, + is_matrix: _bool, +) -> None: ... +def _linalg_det( + A: Tensor, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types._linalg_det: ... +def _linalg_eigh( + A: Tensor, + UPLO: str = "L", + compute_v: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_eigh: ... +def _linalg_slogdet( + A: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_slogdet: ... +def _linalg_solve_ex( + A: Tensor, + B: Tensor, + *, + left: _bool = True, + check_errors: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_solve_ex: ... +def _linalg_svd( + A: Tensor, + full_matrices: _bool = False, + compute_uv: _bool = True, + *, + driver: str | None = None, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_svd: ... +def _log_softmax( + input: Tensor, + dim: _int, + half_to_float: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _log_softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input_dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _logcumsumexp( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _lstm_mps( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _lu_with_info( + input: Tensor, + pivot: _bool = True, + check_errors: _bool = True, +) -> torch.return_types._lu_with_info: ... +def _make_dep_token( + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _make_dual(primal: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _make_dual_copy( + primal: Tensor, + tangent: Tensor, + level: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _make_per_channel_quantized_tensor( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + axis: _int, +) -> Tensor: ... +def _make_per_tensor_quantized_tensor( + input: Tensor, + scale: _float, + zero_point: _int, +) -> Tensor: ... +def _masked_scale(input: Tensor, mask: Tensor, scale: _float) -> Tensor: ... +def _masked_softmax( + input: Tensor, + mask: Tensor, + dim: _int | None = None, + mask_type: _int | None = None, +) -> Tensor: ... +def _mixed_dtypes_linear( + input: Tensor, + weight: Tensor, + scale: Tensor, + *, + bias: Tensor | None = None, + activation: str | None = None, +) -> Tensor: ... +def _mkldnn_reshape(input: Tensor, shape: _size) -> Tensor: ... +def _mkldnn_transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mkldnn_transpose_(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mps_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def _mps_convolution_transpose( + input: Tensor, + weight: Tensor, + padding: Sequence[_int | SymInt], + output_padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... 
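+# Hedged sketch: the _linalg_* kernels above are usually reached through the public
+# torch.linalg wrappers, which add shape and error checking on top; A is assumed square:
+#
+#     >>> A = torch.randn(3, 3)
+#     >>> sign, logabsdet = torch.linalg.slogdet(A)
+#     >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False)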
+@overload +def _native_batch_norm_legit( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor, + running_var: Tensor, + training: _bool, + momentum: _float, + eps: _float, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> tuple[Tensor, Tensor, Tensor]: ... +@overload +def _native_batch_norm_legit( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _native_batch_norm_legit_no_training( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor, + running_var: Tensor, + momentum: _float, + eps: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _native_multi_head_attention( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim: _int, + num_head: _int, + qkv_weight: Tensor, + qkv_bias: Tensor, + proj_weight: Tensor, + proj_bias: Tensor, + mask: Tensor | None = None, + need_weights: _bool = True, + average_attn_weights: _bool = True, + mask_type: _int | None = None, +) -> tuple[Tensor, Tensor]: ... +def _neg_view(input: Tensor) -> Tensor: ... +def _neg_view_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _nested_compute_contiguous_strides_offsets( + nested_size: Tensor, +) -> tuple[Tensor, Tensor]: ... +def _nested_from_padded( + padded: Tensor, + cpu_nested_shape_example: Tensor, + fuse_transform_0213: _bool = False, +) -> Tensor: ... +def _nested_from_padded_and_nested_example( + padded: Tensor, + nt_example: Tensor, +) -> Tensor: ... +def _nested_from_padded_tensor( + padded: Tensor, + offsets: Tensor, + dummy: Tensor, + ragged_idx: _int = 1, + min_seqlen: Tensor | None = None, + max_seqlen: Tensor | None = None, + sum_S: _int | SymInt | None = None, +) -> Tensor: ... +def _nested_get_jagged_dummy(any: Tensor) -> Tensor: ... +def _nested_get_lengths(input: Tensor) -> Tensor: ... +def _nested_get_max_seqlen(input: Tensor) -> Tensor: ... +def _nested_get_min_seqlen(input: Tensor) -> Tensor: ... +def _nested_get_offsets(input: Tensor) -> Tensor: ... +def _nested_get_ragged_idx(input: Tensor) -> _int: ... +def _nested_get_values(input: Tensor) -> Tensor: ... +def _nested_get_values_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _nested_tensor_from_mask( + t: Tensor, + mask: Tensor, + mask_check: _bool = True, +) -> Tensor: ... +def _nested_tensor_from_mask_left_aligned(t: Tensor, mask: Tensor) -> _bool: ... +def _nested_tensor_from_tensor_list( + list: tuple[Tensor, ...] | list[Tensor] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = None, +) -> Tensor: ... +def _nested_tensor_softmax_with_shape( + input: Tensor, + query: Tensor, +) -> Tensor: ... +def _nested_view_from_buffer( + input: Tensor, + nested_size: Tensor, + nested_strides: Tensor, + offsets: Tensor, +) -> Tensor: ... +def _nested_view_from_buffer_copy( + input: Tensor, + nested_size: Tensor, + nested_strides: Tensor, + offsets: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _nested_view_from_jagged( + input: Tensor, + offsets: Tensor, + dummy: Tensor, + lengths: Tensor | None = None, + ragged_idx: _int = 1, + min_seqlen: Tensor | None = None, + max_seqlen: Tensor | None = None, +) -> Tensor: ... 
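+# Hedged sketch: the _nested_* getters and views above back the jagged nested-tensor
+# layout; user code normally goes through the public constructor (shapes are arbitrary):
+#
+#     >>> a, b = torch.randn(2, 8), torch.randn(5, 8)
+#     >>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
+#     >>> nt.is_nested
+#     True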
+def _nested_view_from_jagged_copy( + input: Tensor, + offsets: Tensor, + dummy: Tensor, + lengths: Tensor | None = None, + ragged_idx: _int = 1, + min_seqlen: Tensor | None = None, + max_seqlen: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _nnpack_available() -> _bool: ... +def _nnpack_spatial_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: _int | SymInt | Sequence[_int | SymInt], + stride: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def _pack_padded_sequence( + input: Tensor, + lengths: Tensor, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def _pad_packed_sequence( + data: Tensor, + batch_sizes: Tensor, + batch_first: _bool, + padding_value: Number | _complex, + total_length: _int, +) -> tuple[Tensor, Tensor]: ... +def _pin_memory( + input: Tensor, + device: DeviceLikeType | None = None, +) -> Tensor: ... +def _prelu_kernel(input: Tensor, weight: Tensor) -> Tensor: ... +def _print(s: str) -> None: ... +def _propagate_xla_data(input: Tensor, output: Tensor) -> None: ... +def _remove_batch_dim( + input: Tensor, + level: _int, + batch_size: _int | SymInt, + out_dim: _int, +) -> Tensor: ... +def _reshape_alias_copy( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + *, + out: Tensor | None = None, +) -> Tensor: ... +def _reshape_from_tensor(input: Tensor, shape: Tensor) -> Tensor: ... +def _resize_output_( + input: Tensor, + size: Sequence[_int | SymInt], + device: DeviceLikeType | None, +) -> Tensor: ... +def _rowwise_prune( + weight: Tensor, + mask: Tensor, + compressed_indices_dtype: _dtype, +) -> tuple[Tensor, Tensor]: ... +def _safe_softmax( + input: Tensor, + dim: _int, + dtype: _dtype | None = None, +) -> Tensor: ... +def _sample_dirichlet( + input: Tensor, + generator: Generator | None = None, +) -> Tensor: ... +def _saturate_weight_to_fp16(weight: Tensor) -> Tensor: ... +def _scaled_dot_product_attention_math( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: _float = 0.0, + is_causal: _bool = False, + dropout_mask: Tensor | None = None, + *, + scale: _float | None = None, + enable_gqa: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def _scaled_dot_product_attention_math_for_mps( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: _float = 0.0, + is_causal: _bool = False, + dropout_mask: Tensor | None = None, + *, + scale: _float | None = None, +) -> tuple[Tensor, Tensor]: ... +def _scaled_dot_product_cudnn_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor | None, + compute_log_sumexp: _bool, + dropout_p: _float = 0.0, + is_causal: _bool = False, + return_debug_mask: _bool = False, + *, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_cudnn_attention: ... +def _scaled_dot_product_efficient_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor | None, + compute_log_sumexp: _bool, + dropout_p: _float = 0.0, + is_causal: _bool = False, + *, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_efficient_attention: ... +def _scaled_dot_product_flash_attention( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: _float = 0.0, + is_causal: _bool = False, + return_debug_mask: _bool = False, + *, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_flash_attention: ... 
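+# Hedged sketch: the _scaled_dot_product_* functions above are alternative attention
+# backends, and _fused_sdp_choice picks among them; the usual entry point is the public
+# call below (shape is illustrative: batch, heads, seq, head_dim):
+#
+#     >>> q = k = v = torch.randn(2, 4, 16, 8)
+#     >>> out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)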
+def _scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: _float = 0.0, + is_causal: _bool = False, + *, + attn_mask: Tensor | None = None, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_flash_attention_for_cpu: ... +def _scaled_grouped_mm( + input: Tensor, + mat2: Tensor, + scale_a: Tensor, + scale_b: Tensor, + offs: Tensor | None = None, + bias: Tensor | None = None, + scale_result: Tensor | None = None, + out_dtype: _dtype | None = None, + use_fast_accum: _bool = False, +) -> Tensor: ... +def _scaled_mm( + input: Tensor, + mat2: Tensor, + scale_a: Tensor, + scale_b: Tensor, + bias: Tensor | None = None, + scale_result: Tensor | None = None, + out_dtype: _dtype | None = None, + use_fast_accum: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _shape_as_tensor(input: Tensor) -> Tensor: ... +def _sobol_engine_draw( + quasi: Tensor, + n: _int, + sobolstate: Tensor, + dimension: _int, + num_generated: _int, + dtype: _dtype | None, +) -> tuple[Tensor, Tensor]: ... +def _sobol_engine_ff_( + input: Tensor, + n: _int, + sobolstate: Tensor, + dimension: _int, + num_generated: _int, +) -> Tensor: ... +def _sobol_engine_initialize_state_( + input: Tensor, + dimension: _int, +) -> Tensor: ... +def _sobol_engine_scramble_( + input: Tensor, + ltm: Tensor, + dimension: _int, +) -> Tensor: ... +def _softmax( + input: Tensor, + dim: _int, + half_to_float: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input_dtype: _dtype, + *, + grad_input: Tensor | None = None, +) -> Tensor: ... +def _sparse_broadcast_to(input: Tensor, size: _size) -> Tensor: ... +def _sparse_broadcast_to_copy( + input: Tensor, + size: _size, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _sparse_csr_prod( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_csr_sum( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_log_softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input: Tensor, +) -> Tensor: ... +def _sparse_semi_structured_addmm( + input: Tensor, + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + *, + alpha: Number | _complex = 1, + beta: Number | _complex = 1, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_semi_structured_apply( + input: Tensor, + thread_masks: Tensor, +) -> tuple[Tensor, Tensor]: ... +def _sparse_semi_structured_apply_dense( + input: Tensor, + thread_masks: Tensor, +) -> Tensor: ... +def _sparse_semi_structured_linear( + input: Tensor, + weight: Tensor, + meta: Tensor, + *, + bias: Tensor | None = None, + activation: str | None = None, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_semi_structured_mm( + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + *, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_semi_structured_tile( + input: Tensor, + algorithm: str = "", + use_cutlass: _bool = True, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _sparse_softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input: Tensor, +) -> Tensor: ... +def _sparse_sparse_matmul(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor) -> Tensor: ... 
+@overload +def _sparse_sum(input: Tensor, *, dtype: _dtype) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, dim: _int | _size) -> Tensor: ... +@overload +def _sparse_sum( + input: Tensor, + dim: _int | _size, + *, + dtype: _dtype, +) -> Tensor: ... +def _stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _standard_gamma( + input: Tensor, + generator: Generator | None = None, +) -> Tensor: ... +def _standard_gamma_grad(input: Tensor, output: Tensor) -> Tensor: ... +def _sync(t: Tensor) -> None: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor) -> Tensor: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor, b: _bool) -> Tensor: ... +def _test_autograd_multiple_dispatch_view(input: Tensor) -> Tensor: ... +def _test_autograd_multiple_dispatch_view_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _test_check_tensor(input: Tensor) -> Tensor: ... +def _test_functorch_fallback(input: Tensor, other: Tensor) -> Tensor: ... +def _test_parallel_materialize( + input: Tensor, + num_parallel: _int, + skip_first: _bool = False, +) -> Tensor: ... +def _test_serialization_subcmul( + input: Tensor, + other: Tensor, + alpha: Number | _complex = 1, +) -> Tensor: ... +def _to_cpu( + tensors: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +def _to_functional_tensor(t: Tensor) -> Tensor: ... +def _to_sparse_semi_structured(dense: Tensor) -> tuple[Tensor, Tensor]: ... +def _transform_bias_rescale_qkv( + qkv: Tensor, + qkv_bias: Tensor, + num_heads: _int, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _transformer_encoder_layer_fwd( + src: Tensor, + embed_dim: _int, + num_heads: _int, + qkv_weight: Tensor, + qkv_bias: Tensor, + proj_weight: Tensor, + proj_bias: Tensor, + use_gelu: _bool, + norm_first: _bool, + eps: _float, + norm_weight_1: Tensor, + norm_bias_1: Tensor, + norm_weight_2: Tensor, + norm_bias_2: Tensor, + ffn_weight_1: Tensor, + ffn_bias_1: Tensor, + ffn_weight_2: Tensor, + ffn_bias_2: Tensor, + mask: Tensor | None = None, + mask_type: _int | None = None, +) -> Tensor: ... +def _trilinear( + i1: Tensor, + i2: Tensor, + i3: Tensor, + expand1: _size, + expand2: _size, + expand3: _size, + sumdim: _size, + unroll_dim: _int = 1, +) -> Tensor: ... +def _triton_multi_head_attention( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim: _int, + num_head: _int, + qkv_weight: Tensor, + qkv_bias: Tensor, + proj_weight: Tensor, + proj_bias: Tensor, + mask: Tensor | None = None, +) -> Tensor: ... +def _triton_scaled_dot_attention( + q: Tensor, + k: Tensor, + v: Tensor, + dropout_p: _float = 0.0, +) -> Tensor: ... +def _unique( + input: Tensor, + sorted: _bool = True, + return_inverse: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def _unique2( + input: Tensor, + sorted: _bool = True, + return_inverse: _bool = False, + return_counts: _bool = False, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _unpack_dual( + dual: Tensor, + level: _int, +) -> torch.return_types._unpack_dual: ... +def _unsafe_index( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, +) -> Tensor: ... +def _unsafe_index_put( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def _unsafe_masked_index( + input: Tensor, + mask: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + fill: Number | _complex, +) -> Tensor: ... 
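+# Hedged sketch: _unique/_unique2 are the kernels underneath the public torch.unique;
+# the extra outputs are requested with keyword flags:
+#
+#     >>> x = torch.tensor([1, 3, 2, 3, 1])
+#     >>> vals, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)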
+def _unsafe_masked_index_put_accumulate( + input: Tensor, + mask: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, +) -> Tensor: ... +@overload +def _use_cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int, +) -> _bool: ... +@overload +def _use_cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int, +) -> _bool: ... +def _use_cudnn_rnn_flatten_weight() -> _bool: ... +def _validate_compressed_sparse_indices( + is_crow: _bool, + compressed_idx: Tensor, + plain_idx: Tensor, + cdim: _int, + dim: _int, + nnz: _int, +) -> None: ... +def _validate_sparse_bsc_tensor_args( + ccol_indices: Tensor, + row_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_bsr_tensor_args( + crow_indices: Tensor, + col_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_compressed_tensor_args( + compressed_indices: Tensor, + plain_indices: Tensor, + values: Tensor, + size: _size, + layout: _layout, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_coo_tensor_args( + indices: Tensor, + values: Tensor, + size: _size, + is_coalesced: _bool | None = None, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_csc_tensor_args( + ccol_indices: Tensor, + row_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_csr_tensor_args( + crow_indices: Tensor, + col_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _values_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _weight_int4pack_mm( + input: Tensor, + mat2: Tensor, + qGroupSize: _int, + qScaleAndZeros: Tensor, +) -> Tensor: ... +def _weight_int4pack_mm_for_cpu( + input: Tensor, + mat2: Tensor, + qGroupSize: _int, + qScaleAndZeros: Tensor, +) -> Tensor: ... +def _weight_int4pack_mm_with_scales_and_zeros( + input: Tensor, + mat2: Tensor, + qGroupSize: _int, + qScale: Tensor, + qZeros: Tensor, +) -> Tensor: ... +def _weight_int8pack_mm( + input: Tensor, + mat2: Tensor, + scales: Tensor, +) -> Tensor: ... +def _weight_norm(v: Tensor, g: Tensor, dim: _int = 0) -> Tensor: ... +def _weight_norm_interface( + v: Tensor, + g: Tensor, + dim: _int = 0, +) -> tuple[Tensor, Tensor]: ... +def _wrapped_linear_prepack( + weight: Tensor, + weight_scale: Tensor, + weight_zero_point: Tensor, + bias: Tensor, +) -> Tensor: ... +def _wrapped_quantized_linear_prepacked( + input: Tensor, + input_scale: Tensor, + input_zero_point: Tensor, + packed_weight: Tensor, + output_scale: Tensor, + output_zero_point: Tensor, + out_channel: _int, +) -> Tensor: ... +def abs(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + abs(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the absolute value of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = |\text{input}_{i}| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.abs(torch.tensor([-1, -2, 3])) + tensor([ 1, 2, 3]) + """ + +def abs_(input: Tensor) -> Tensor: ... 
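+# Hedged note: throughout this stub a trailing underscore marks the in-place variant of
+# the preceding function, as with abs/abs_ just above:
+#
+#     >>> x = torch.tensor([-1.5, 2.0, -3.0])
+#     >>> torch.abs_(x)                     # modifies x in place and returns it
+#     tensor([1.5000, 2.0000, 3.0000])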
+def absolute(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + absolute(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.abs` + """ + +def acos(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + acos(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the inverse cosine of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \cos^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) + >>> torch.acos(a) + tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) + """ + +def acos_(input: Tensor) -> Tensor: ... +def acosh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + acosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) + + Note: + The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range + will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(1, 2) + >>> a + tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) + >>> torch.acosh(a) + tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) + """ + +def acosh_(input: Tensor) -> Tensor: ... +def adaptive_avg_pool1d(input: Tensor, output_size: _int | _size) -> Tensor: ... +def adaptive_max_pool1d( + input: Tensor, + output_size: _int | _size, +) -> tuple[Tensor, Tensor]: ... +@overload +def add( + input: Tensor | Number | _complex, + other: Tensor | Number | _complex, + *, + alpha: Number | _complex | None = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + +@overload +def add(self: Tensor, alpha: Number | _complex, other: Tensor) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + +@overload +def add( + self: Tensor, + alpha: Number | _complex, + other: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + input: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. 
+ + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addcdiv( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, +) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + +@overload +def addcdiv( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + +@overload +def addcdiv( + input: Tensor, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + +@overload +def addcmul( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, +) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + +@overload +def addcmul( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + +@overload +def addcmul( + input: Tensor, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. 
math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat1: Tensor, + mat2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. 
+ + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + out_dtype: _dtype, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. 
If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat: Tensor, + vec: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. 
math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat: Tensor, + vec: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + input: Tensor, + mat: Tensor, + vec: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. 
math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + mat: Tensor, + vec: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + mat: Tensor, + vec: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv_( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat: Tensor, + vec: Tensor, +) -> Tensor: ... +@overload +def addmv_( + input: Tensor, + mat: Tensor, + vec: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def addmv_( + beta: Number | _complex, + self: Tensor, + mat: Tensor, + vec: Tensor, +) -> Tensor: ... +@overload +def addr( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + vec1: Tensor, + vec2: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + vec1: Tensor, + vec2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. 
+ + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + input: Tensor, + vec1: Tensor, + vec2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + beta: Number | _complex, + self: Tensor, + vec1: Tensor, + vec2: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. 
+ + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + beta: Number | _complex, + self: Tensor, + vec1: Tensor, + vec2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +def adjoint(input: Tensor) -> Tensor: + r""" + adjoint(input: Tensor) -> Tensor + Returns a view of the tensor conjugated and with the last two dimensions transposed. + + ``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and + to ``x.transpose(-2, -1)`` for real tensors. + + Args: + {input} + + Example:: + + >>> x = torch.arange(4, dtype=torch.float) + >>> A = torch.complex(x, x).reshape(2, 2) + >>> A + tensor([[0.+0.j, 1.+1.j], + [2.+2.j, 3.+3.j]]) + >>> A.adjoint() + tensor([[0.-0.j, 2.-2.j], + [1.-1.j, 3.-3.j]]) + >>> (A.adjoint() == A.mH).all() + tensor(True) + """ + +def affine_grid_generator( + theta: Tensor, + size: Sequence[_int | SymInt], + align_corners: _bool, +) -> Tensor: ... +def alias_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.alias`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def all(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. 
+ + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +@overload +def all( + input: Tensor, + dim: _size | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +@overload +def all( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +@overload +def all( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. 
+ Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +def allclose( + input: Tensor, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, +) -> _bool: + r""" + allclose(input: Tensor, other: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool + + This function checks if :attr:`input` and :attr:`other` satisfy the condition: + + .. math:: + \lvert \text{input}_i - \text{other}_i \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other}_i \rvert + + elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to + `numpy.allclose `_ + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` + + Example:: + + >>> torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08])) + False + >>> torch.allclose(torch.tensor([10000., 1e-08]), torch.tensor([10000.1, 1e-09])) + True + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')])) + False + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) + True + """ + +def alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def alpha_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def amax( + input: Tensor, + dim: _int | _size = (), + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + amax(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the maximum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. 
Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.8177, 1.4878, -0.2491, 0.9130], + [-0.7158, 1.1775, 2.0992, 0.4817], + [-0.0053, 0.0164, -1.3738, -0.0507], + [ 1.9700, 1.1106, -1.0318, -1.0816]]) + >>> torch.amax(a, 1) + tensor([1.4878, 2.0992, 0.0164, 1.9700]) + """ + +def amin( + input: Tensor, + dim: _int | _size = (), + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + amin(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the minimum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.6451, -0.4866, 0.2987, -1.3312], + [-0.5744, 1.2980, 1.8397, -0.2713], + [ 0.9128, 0.9214, -1.7268, -0.2995], + [ 0.9023, 0.4853, 0.9075, -1.6165]]) + >>> torch.amin(a, 1) + tensor([-1.3312, -0.5744, -1.7268, -1.6165]) + """ + +def aminmax( + input: Tensor, + *, + dim: _int | None = None, + keepdim: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.aminmax: + r""" + aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) + + Computes the minimum and maximum values of the :attr:`input` tensor. + + Args: + input (Tensor): + The input tensor + + Keyword Args: + dim (Optional[int]): + The dimension along which to compute the values. If `None`, + computes the values over the entire :attr:`input` tensor. + Default is `None`. + keepdim (bool): + If `True`, the reduced dimensions will be kept in the output + tensor as dimensions with size 1 for broadcasting, otherwise + they will be removed, as if calling (:func:`torch.squeeze`). + Default is `False`. + out (Optional[Tuple[Tensor, Tensor]]): + Optional tensors on which to write the result. Must have the same + shape and dtype as the expected output. + Default is `None`. + + Returns: + A named tuple `(min, max)` containing the minimum and maximum values. + + Raises: + RuntimeError + If any of the dimensions to compute the values over has size 0. + + .. note:: + NaN values are propagated to the output if at least one value is NaN. + + .. 
seealso:: + :func:`torch.amin` computes just the minimum value + :func:`torch.amax` computes just the maximum value + + Example:: + + >>> torch.aminmax(torch.tensor([1, -3, 5])) + torch.return_types.aminmax( + min=tensor(-3), + max=tensor(5)) + + >>> # aminmax propagates NaNs + >>> torch.aminmax(torch.tensor([1, -3, 5, torch.nan])) + torch.return_types.aminmax( + min=tensor(nan), + max=tensor(nan)) + + >>> t = torch.arange(10).view(2, 5) + >>> t + tensor([[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9]]) + >>> t.aminmax(dim=0, keepdim=True) + torch.return_types.aminmax( + min=tensor([[0, 1, 2, 3, 4]]), + max=tensor([[5, 6, 7, 8, 9]])) + """ + +def angle(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + angle(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + + .. math:: + \text{out}_{i} = angle(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + .. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. + + Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, -45]) + """ + +@overload +def any(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def any( + input: Tensor, + dim: _size | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. 
note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def any( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def any( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def arange( + start: Number, + end: Number, + step: Number, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. 
math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + start: Number, + end: Number, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. 
Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + end: Number, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + end: Number | _complex, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + start: Number | _complex, + end: Number | _complex, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + start: Number | _complex, + end: Number | _complex, + step: Number | _complex = 1, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +def arccos(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arccos(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.acos`. + """ + +def arccos_(input: Tensor) -> Tensor: ... +def arccosh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arccosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.acosh`. + """ + +def arccosh_(input: Tensor) -> Tensor: ... 
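+# --- Illustrative note (not part of the generated stubs) ---------------------
+# A minimal sketch of the dtype-inference rule described in the ``arange``
+# docstrings above: integer-only arguments infer ``torch.int64``, while any
+# floating-point argument falls back to the default dtype (assumed here to be
+# ``torch.float32``, i.e. the default has not been changed). Shown as
+# doctest-style comments only, since a .pyi stub should not contain
+# executable statements.
+#
+#     >>> import torch
+#     >>> torch.arange(5).dtype            # all-integer args -> torch.int64
+#     torch.int64
+#     >>> torch.arange(0, 2.5, 0.5).dtype  # any float arg -> default dtype
+#     torch.float32
+#     >>> torch.arange(1, 2.5, 0.5)        # non-integer step; end stays exclusive
+#     tensor([1.0000, 1.5000, 2.0000])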
+def arcsin(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arcsin(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.asin`. + """ + +def arcsin_(input: Tensor) -> Tensor: ... +def arcsinh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arcsinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.asinh`. + """ + +def arcsinh_(input: Tensor) -> Tensor: ... +def arctan(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arctan(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.atan`. + """ + +def arctan2( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + arctan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + Alias for :func:`torch.atan2`. + """ + +def arctan_(input: Tensor) -> Tensor: ... +def arctanh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arctanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.atanh`. + """ + +def arctanh_(input: Tensor) -> Tensor: ... +def argmax( + input: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + argmax(input) -> LongTensor + + Returns the indices of the maximum value of all elements in the :attr:`input` tensor. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple maximal values then the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a) + tensor(0) + + .. function:: argmax(input, dim, keepdim=False) -> LongTensor + :noindex: + + Returns the indices of the maximum values of a tensor across a dimension. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, the argmax of the flattened input is returned. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a, dim=1) + tensor([ 0, 2, 0, 1]) + """ + +def argmin( + input: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + argmin(input, dim=None, keepdim=False) -> LongTensor + + Returns the indices of the minimum value(s) of the flattened tensor or along a dimension + + This is the second value returned by :meth:`torch.min`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, the argmin of the flattened input is returned. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], + [ 1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) + >>> torch.argmin(a, dim=1) + tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) + """ + +@overload +def argsort( + input: Tensor, + *, + stable: _bool, + dim: _int = -1, + descending: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + +@overload +def argsort( + input: Tensor, + dim: _int = -1, + descending: _bool = False, +) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + +@overload +def argsort( + input: Tensor, + dim: str | EllipsisType | None, + descending: _bool = False, +) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. 
+ + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + +def argwhere(input: Tensor) -> Tensor: + r""" + argwhere(input) -> Tensor + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + .. note:: + This function is similar to NumPy's `argwhere`. + + When :attr:`input` is on CUDA, this function causes host-device synchronization. + + Args: + {input} + + Example:: + + >>> t = torch.tensor([1, 0, 1]) + >>> torch.argwhere(t) + tensor([[0], + [2]]) + >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) + >>> torch.argwhere(t) + tensor([[0, 0], + [0, 2], + [1, 1], + [1, 2]]) + """ + +def as_strided( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, +) -> Tensor: + r""" + as_strided(input, size, stride, storage_offset=None) -> Tensor + + Create a view of an existing `torch.Tensor` :attr:`input` with specified + :attr:`size`, :attr:`stride` and :attr:`storage_offset`. + + .. warning:: + Prefer using other view functions, like :meth:`torch.Tensor.view` or + :meth:`torch.Tensor.expand`, to setting a view's strides manually with + `as_strided`, as this function will throw an error on non-standard Pytorch + backends (that do not have a concept of stride) and the result will depend + on the current layout in memory. The constructed view must only refer to + elements within the Tensor's storage or a runtime error will be thrown. + If the generated view is "overlapped" (with multiple indices referring to + the same element in memory), the behavior of inplace operations on this view + is undefined (and might not throw runtime errors). + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. + + Example:: + + >>> x = torch.randn(3, 3) + >>> x + tensor([[ 0.9039, 0.6291, 1.0795], + [ 0.1586, 2.1939, -0.4900], + [-0.1909, -0.7503, 1.9355]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2)) + >>> t + tensor([[0.9039, 1.0795], + [0.6291, 0.1586]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) + tensor([[0.6291, 0.1586], + [1.0795, 2.1939]]) + """ + +def as_strided_( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, +) -> Tensor: ... 
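+# --- Illustrative note (not part of the generated stubs) ---------------------
+# A minimal sketch of the ``argsort``/``sort`` relationship and the ``stable``
+# flag documented above: gathering with the returned indices reproduces
+# ``torch.sort``, and ``stable=True`` keeps equal elements in their original
+# order. Doctest-style comments only; the tensor values are a small example
+# chosen for illustration.
+#
+#     >>> import torch
+#     >>> a = torch.tensor([3, 1, 2, 1])
+#     >>> idx = torch.argsort(a, stable=True)
+#     >>> idx                              # the two equal 1s keep index order 1, 3
+#     tensor([1, 3, 2, 0])
+#     >>> torch.equal(a[idx], torch.sort(a, stable=True).values)
+#     True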
+def as_strided_copy( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.as_strided`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def as_strided_scatter( + input: Tensor, + src: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, +) -> Tensor: + r""" + as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the elements corresponding to the result of calling + input.as_strided(size, stride, storage_offset). + + This function returns a tensor with fresh storage; it does not + return a view. + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + `torch.as_strided(input, size, stride, storage_offset)` + + Example:: + + >>> a = torch.arange(4).reshape(2, 2) + 1 + >>> a + tensor([[1, 2], + [3, 4]]) + >>> b = torch.zeros(3, 3) + >>> b + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2)) + tensor([[1., 3., 2.], + [4., 0., 0.], + [0., 0., 0.]]) + """ + +def as_tensor( + data: Any, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, +) -> Tensor: + r""" + as_tensor(data: Any, dtype: Optional[dtype] = None, device: Optional[DeviceLikeType]) -> Tensor + + Converts :attr:`data` into a tensor, sharing data and preserving autograd + history if possible. + + If :attr:`data` is already a tensor with the requested dtype and device + then :attr:`data` itself is returned, but if :attr:`data` is a + tensor with a different dtype or device then it's copied as if using + `data.to(dtype=dtype, device=device)`. + + If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device then a + tensor is constructed using :func:`torch.from_numpy`. + + If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless + specifically overwritten by :attr:`device` or a default device. + + .. seealso:: + + :func:`torch.tensor` never shares its data and creates a new "leaf tensor" (see :doc:`/notes/autograd`). + + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. 
+ + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) + """ + +def asarray( + obj: Any, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + copy: _bool | None = None, + requires_grad: _bool = False, +) -> Tensor: + r""" + asarray(obj: Any, *, dtype: Optional[dtype], device: Optional[DeviceLikeType], copy: Optional[bool] = None, requires_grad: bool = False) -> Tensor # noqa: B950 + + Converts :attr:`obj` to a tensor. + + :attr:`obj` can be one of: + + 1. a tensor + 2. a NumPy array or a NumPy scalar + 3. a DLPack capsule + 4. an object that implements Python's buffer protocol + 5. a scalar + 6. a sequence of scalars + + When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, + by default, not require a gradient, have the same datatype as :attr:`obj`, be on the + same device, and share memory with it. These properties can be controlled with the + :attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. + If the returned tensor is of a different datatype, on a different device, or a copy is + requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` + is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is + also a tensor with an autograd history then the returned tensor will have the same history. + + When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's + buffer protocol then the buffer is interpreted as an array of bytes grouped according to + the size of the datatype passed to the :attr:`dtype` keyword argument. (If no datatype is + passed then the default floating point datatype is used, instead.) The returned tensor + will have the specified datatype (or default floating point datatype if none is specified) + and, by default, be on the CPU device and share memory with the buffer. + + When :attr:`obj` is a NumPy scalar, the returned tensor will be a 0-dimensional tensor on + the CPU and that doesn't share its memory (i.e. ``copy=True``). By default datatype will + be the PyTorch datatype corresponding to the NumPy's scalar's datatype. + + When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the + returned tensor will, by default, infer its datatype from the scalar values, be on the + current default device, and not share its memory. + + .. seealso:: + + :func:`torch.tensor` creates a tensor that always copies the data from the input object. + :func:`torch.from_numpy` creates a tensor that always shares memory from NumPy arrays. + :func:`torch.frombuffer` creates a tensor that always shares memory from objects that + implement the buffer protocol. + :func:`torch.from_dlpack` creates a tensor that always shares memory from + DLPack capsules. + + Args: + obj (object): a tensor, NumPy array, DLPack Capsule, object that implements Python's + buffer protocol, scalar, or sequence of scalars. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the datatype of the returned tensor. + Default: ``None``, which causes the datatype of the returned tensor to be + inferred from :attr:`obj`. + copy (bool, optional): controls whether the returned tensor shares memory with :attr:`obj`. 
+ Default: ``None``, which causes the returned tensor to share memory with :attr:`obj` + whenever possible. If ``True`` then the returned tensor does not share its memory. + If ``False`` then the returned tensor shares its memory with :attr:`obj` and an + error is thrown if it cannot. + device (:class:`torch.device`, optional): the device of the returned tensor. + Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if + :attr:`obj` is a Python sequence, the current default device will be used. + requires_grad (bool, optional): whether the returned tensor requires grad. + Default: ``False``, which causes the returned tensor not to require a gradient. + If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` + is also a tensor with an autograd history then the returned tensor will have + the same history. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> # Shares memory with tensor 'a' + >>> b = torch.asarray(a) + >>> a.data_ptr() == b.data_ptr() + True + >>> # Forces memory copy + >>> c = torch.asarray(a, copy=True) + >>> a.data_ptr() == c.data_ptr() + False + + >>> a = torch.tensor([1., 2., 3.], requires_grad=True) + >>> b = a + 2 + >>> b + tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', with no grad + >>> c = torch.asarray(b) + >>> c + tensor([3., 4., 5.]) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> d = torch.asarray(b, requires_grad=True) + >>> d + tensor([3., 4., 5.], grad_fn=) + + >>> array = numpy.array([1, 2, 3]) + >>> # Shares memory with array 'array' + >>> t1 = torch.asarray(array) + >>> array.__array_interface__['data'][0] == t1.data_ptr() + True + >>> # Copies memory due to dtype mismatch + >>> t2 = torch.asarray(array, dtype=torch.float32) + >>> array.__array_interface__['data'][0] == t2.data_ptr() + False + + >>> scalar = numpy.float64(0.5) + >>> torch.asarray(scalar) + tensor(0.5000, dtype=torch.float64) + """ + +def asin(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + asin(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the arcsine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5962, 1.4985, -0.4396, 1.4525]) + >>> torch.asin(a) + tensor([-0.6387, nan, -0.4552, nan]) + """ + +def asin_(input: Tensor) -> Tensor: ... +def asinh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + asinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) + >>> torch.asinh(a) + tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) + """ + +def asinh_(input: Tensor) -> Tensor: ... +def atan(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + atan(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the arctangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) + >>> torch.atan(a) + tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) + """ + +def atan2( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + atan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Element-wise arctangent of :math:`\text{input}_{i} / \text{other}_{i}` + with consideration of the quadrant. Returns a new tensor with the signed angles + in radians between vector :math:`(\text{other}_{i}, \text{input}_{i})` + and vector :math:`(1, 0)`. (Note that :math:`\text{other}_{i}`, the second + parameter, is the x-coordinate, while :math:`\text{input}_{i}`, the first + parameter, is the y-coordinate.) + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) + >>> torch.atan2(a, torch.randn(4)) + tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) + """ + +def atan_(input: Tensor) -> Tensor: ... +def atanh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + atanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + + Note: + The domain of the inverse hyperbolic tangent is `(-1, 1)` and values outside this range + will be mapped to ``NaN``, except for the values `1` and `-1` for which the output is + mapped to `+/-INF` respectively. + + .. math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(-1, 1) + >>> a + tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) + >>> torch.atanh(a) + tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) + """ + +def atanh_(input: Tensor) -> Tensor: ... +def avg_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + ceil_mode: _bool = False, + count_include_pad: _bool = True, +) -> Tensor: ... +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + input: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + input: Tensor, + batch1: Tensor, + batch2: Tensor, + out_dtype: _dtype, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. 
math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def bartlett_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. 
Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def bartlett_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). 
:attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +def batch_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + cudnn_enabled: _bool, +) -> Tensor: ... +def batch_norm_backward_elemt( + grad_out: Tensor, + input: Tensor, + mean: Tensor, + invstd: Tensor, + weight: Tensor | None, + sum_dy: Tensor, + sum_dy_xmu: Tensor, + count: Tensor, +) -> Tensor: ... +def batch_norm_backward_reduce( + grad_out: Tensor, + input: Tensor, + mean: Tensor, + invstd: Tensor, + weight: Tensor | None, + input_g: _bool, + weight_g: _bool, + bias_g: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def batch_norm_elemt( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + mean: Tensor, + invstd: Tensor, + eps: _float, + *, + out: Tensor | None = None, +) -> Tensor: ... +def batch_norm_gather_stats( + input: Tensor, + mean: Tensor, + invstd: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + momentum: _float, + eps: _float, + count: _int, +) -> tuple[Tensor, Tensor]: ... +def batch_norm_gather_stats_with_counts( + input: Tensor, + mean: Tensor, + invstd: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + momentum: _float, + eps: _float, + counts: Tensor, +) -> tuple[Tensor, Tensor]: ... +def batch_norm_stats(input: Tensor, eps: _float) -> tuple[Tensor, Tensor]: ... +def batch_norm_update_stats( + input: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + momentum: _float, +) -> tuple[Tensor, Tensor]: ... +@overload +def bernoulli( + input: Tensor, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. + + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + +@overload +def bernoulli( + input: Tensor, + p: _float, + *, + generator: Generator | None = None, +) -> Tensor: + r""" + bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. + + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + +def bilinear( + input1: Tensor, + input2: Tensor, + weight: Tensor, + bias: Tensor | None = None, +) -> Tensor: ... +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Tensor | None = None, + pos_weight: Tensor | None = None, + reduction: _int = 1, +) -> Tensor: ... +def bincount( + input: Tensor, + weights: Tensor | None = None, + minlength: _int | SymInt = 0, +) -> Tensor: + r""" + bincount(input, weights=None, minlength=0) -> Tensor + + Count the frequency of each value in an array of non-negative ints. + + The number of bins (size 1) is one larger than the largest value in + :attr:`input` unless :attr:`input` is empty, in which case the result is a + tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least + :attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size + :attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``, + ``out[n] += weights[i]`` if :attr:`weights` is specified else + ``out[n] += 1``. 
+ + Note: + This operation may produce nondeterministic gradients when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + Arguments: + input (Tensor): 1-d int tensor + weights (Tensor): optional, weight for each value in the input tensor. + Should be of same size as input tensor. + minlength (int): optional, minimum number of bins. Should be non-negative. + + Returns: + output (Tensor): a tensor of shape ``Size([max(input) + 1])`` if + :attr:`input` is non-empty, else ``Size(0)`` + + Example:: + + >>> input = torch.randint(0, 8, (5,), dtype=torch.int64) + >>> weights = torch.linspace(0, 1, steps=5) + >>> input, weights + (tensor([4, 3, 6, 3, 4]), + tensor([ 0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + >>> torch.bincount(input) + tensor([0, 0, 0, 2, 2, 0, 1]) + + >>> input.bincount(weights) + tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) + """ + +def binomial( + count: Tensor, + prob: Tensor, + generator: Generator | None = None, +) -> Tensor: ... +@overload +def bitwise_and( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + +@overload +def bitwise_and(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + +@overload +def bitwise_and( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + +@overload +def bitwise_left_shift( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + +@overload +def bitwise_left_shift(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + +@overload +def bitwise_left_shift( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + +def bitwise_not(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + bitwise_not(input, *, out=None) -> Tensor + + Computes the bitwise NOT of the given input tensor. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical NOT. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) + tensor([ 0, 1, -4], dtype=torch.int8) + """ + +@overload +def bitwise_or( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + +@overload +def bitwise_or(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + +@overload +def bitwise_or( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + +@overload +def bitwise_right_shift( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + +@overload +def bitwise_right_shift(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + +@overload +def bitwise_right_shift( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + +@overload +def bitwise_xor( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + +@overload +def bitwise_xor(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + +@overload +def bitwise_xor( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + +@overload +def blackman_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1]``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def blackman_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1]``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def bmm( + input: Tensor, + mat2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bmm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored in :attr:`input` + and :attr:`mat2`. + + :attr:`input` and :attr:`mat2` must be 3-D tensors each containing + the same number of matrices. + + If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a + :math:`(b \times m \times p)` tensor, :attr:`out` will be a + :math:`(b \times n \times p)` tensor. + + .. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. 
+ + Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword Args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) + """ + +@overload +def bmm( + input: Tensor, + mat2: Tensor, + out_dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bmm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored in :attr:`input` + and :attr:`mat2`. + + :attr:`input` and :attr:`mat2` must be 3-D tensors each containing + the same number of matrices. + + If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a + :math:`(b \times m \times p)` tensor, :attr:`out` will be a + :math:`(b \times n \times p)` tensor. + + .. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword Args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) + """ + +def broadcast_to(input: Tensor, size: Sequence[_int | SymInt]) -> Tensor: + r""" + broadcast_to(input, shape) -> Tensor + + Broadcasts :attr:`input` to the shape :attr:`\shape`. + Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + + Args: + input (Tensor): the input tensor. + shape (list, tuple, or :class:`torch.Size`): the new shape. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) + """ + +@overload +def bucketize( + input: Tensor, + boundaries: Tensor, + *, + out_int32: _bool = False, + right: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). 
+ boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): determines the behavior for values in :attr:`boundaries`. See the table above. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + +@overload +def bucketize( + self: Number | _complex, + boundaries: Tensor, + *, + out_int32: _bool = False, + right: _bool = False, +) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): determines the behavior for values in :attr:`boundaries`. See the table above. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + +def can_cast(from_: _dtype, to: _dtype) -> _bool: + r""" + can_cast(from_, to) -> bool + + Determines if a type conversion is allowed under PyTorch casting rules + described in the type promotion :ref:`documentation `. + + Args: + from\_ (dtype): The original :class:`torch.dtype`. + to (dtype): The target :class:`torch.dtype`. + + Example:: + + >>> torch.can_cast(torch.double, torch.float) + True + >>> torch.can_cast(torch.float, torch.int) + False + """ + +@overload +def cat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. 
+ All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): Non-empty tensors provided must have the same shape, + except in the cat dimension. + + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + +@overload +def cat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): Non-empty tensors provided must have the same shape, + except in the cat dimension. + + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + +def ccol_indices_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def ceil(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + ceil(input, *, out=None) -> Tensor + + Returns a new tensor with the ceil of the elements of :attr:`input`, + the smallest integer greater than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + >>> torch.ceil(a) + tensor([-0., -1., -1., 1.]) + """ + +def ceil_(input: Tensor) -> Tensor: ... +def celu(input: Tensor, alpha: Number | _complex = 1.0) -> Tensor: ... +def celu_(input: Tensor, alpha: Number | _complex = 1.0) -> Tensor: ... +def channel_shuffle(input: Tensor, groups: _int | SymInt) -> Tensor: ... +def cholesky( + input: Tensor, + upper: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cholesky(input, upper=False, *, out=None) -> Tensor + + Computes the Cholesky decomposition of a symmetric positive-definite + matrix :math:`A` or for batches of symmetric positive-definite matrices. + + If :attr:`upper` is ``True``, the returned matrix ``U`` is upper-triangular, and + the decomposition has the form: + + .. math:: + + A = U^TU + + If :attr:`upper` is ``False``, the returned matrix ``L`` is lower-triangular, and + the decomposition has the form: + + .. math:: + + A = LL^T + + If :attr:`upper` is ``True``, and :math:`A` is a batch of symmetric positive-definite + matrices, then the returned tensor will be composed of upper-triangular Cholesky factors + of each of the individual matrices. Similarly, when :attr:`upper` is ``False``, the returned + tensor will be composed of lower-triangular Cholesky factors of each of the individual + matrices. + + .. warning:: + + :func:`torch.cholesky` is deprecated in favor of :func:`torch.linalg.cholesky` + and will be removed in a future PyTorch release. + + ``L = torch.cholesky(A)`` should be replaced with + + .. code:: python + + L = torch.linalg.cholesky(A) + + ``U = torch.cholesky(A, upper=True)`` should be replaced with + + .. code:: python + + U = torch.linalg.cholesky(A).mH + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. + + Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. + upper (bool, optional): flag that indicates whether to return a + upper or lower triangular matrix. Default: ``False`` + + Keyword args: + out (Tensor, optional): the output matrix + + Example:: + + >>> a = torch.randn(3, 3) + >>> a = a @ a.mT + 1e-3 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> l @ l.mT + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> a = torch.randn(3, 2, 2) # Example for batched input + >>> a = a @ a.mT + 1e-03 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> z = l @ l.mT + >>> torch.dist(z, a) + tensor(2.3842e-07) + """ + +def cholesky_inverse( + input: Tensor, + upper: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cholesky_inverse(L, upper=False, *, out=None) -> Tensor + + Computes the inverse of a complex Hermitian or real symmetric + positive-definite matrix given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. 
+ + Computes the inverse matrix :math:`A^{-1}`. + + Supports input of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` is a batch of matrices + then the output has the same batch dimensions. + + Args: + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False`` + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> torch.cholesky_inverse(L) + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + >>> A.inverse() + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(torch.inverse(A), torch.cholesky_inverse(L)) + tensor(5.6358e-7) + """ + +def cholesky_solve( + input: Tensor, + input2: Tensor, + upper: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cholesky_solve(B, L, upper=False, *, out=None) -> Tensor + + Computes the solution of a system of linear equations with complex Hermitian + or real symmetric positive-definite lhs given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Returns the solution :math:`X` of the following linear system: + + .. math:: + + AX = B + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices + then the output has the same batch dimensions. + + Args: + B (Tensor): right-hand side tensor of shape `(*, n, k)` + where :math:`*` is zero or more batch dimensions + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False``. + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> B = torch.randn(3, 2) + >>> torch.cholesky_solve(B, L) + tensor([[ -8.1625, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + >>> A.inverse() @ B + tensor([[ -8.1626, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> B = torch.randn(2, 1, dtype=torch.complex64) + >>> X = torch.cholesky_solve(B, L) + >>> torch.dist(X, A.inverse() @ B) + tensor(1.6881e-5) + """ + +def choose_qparams_optimized( + input: Tensor, + numel: _int, + n_bins: _int, + ratio: _float, + bit_width: _int, +) -> tuple[Tensor, Tensor]: ... +def chunk(input: Tensor, chunks: _int, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + chunk(input: Tensor, chunks: int, dim: int = 0) -> Tuple[Tensor, ...] + + Attempts to split a tensor into the specified number of chunks. Each chunk is a view of + the input tensor. + + + .. note:: + + This function may return fewer than the specified number of chunks! + + .. seealso:: + + :func:`torch.tensor_split` a function that always returns exactly the specified number of chunks + + If the tensor size along the given dimension :attr:`dim` is divisible by :attr:`chunks`, + all returned chunks will be the same size. + If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`, + all returned chunks will be the same size, except the last one. + If such division is not possible, this function may return fewer + than the specified number of chunks. + + Arguments: + input (Tensor): the tensor to split + chunks (int): number of chunks to return + dim (int): dimension along which to split the tensor + + Example: + >>> torch.arange(11).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10])) + >>> torch.arange(12).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10, 11])) + >>> torch.arange(13).chunk(6) + (tensor([0, 1, 2]), + tensor([3, 4, 5]), + tensor([6, 7, 8]), + tensor([ 9, 10, 11]), + tensor([12])) + """ + +@overload +def clamp( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + +@overload +def clamp( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + +@overload +def clamp_( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, +) -> Tensor: ... +@overload +def clamp_max( + input: Tensor, + max: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_max( + input: Tensor, + max: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Tensor) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Number | _complex) -> Tensor: ... +@overload +def clamp_min( + input: Tensor, + min: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_min( + input: Tensor, + min: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Tensor) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Number | _complex) -> Tensor: ... +@overload +def clip( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + +@overload +def clip( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + +@overload +def clip_( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, +) -> Tensor: ... +@overload +def clip_( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, +) -> Tensor: ... 
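+# Editor's note: an illustrative usage sketch (not part of the upstream stub),
+# assuming a standard ``torch`` install. It mirrors the ``clamp``/``clip``
+# documentation above: scalar bounds, a per-element tensor bound, and the fact
+# that ``clip`` is an alias of ``clamp``.
+#
+#     >>> import torch
+#     >>> x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
+#     >>> torch.clamp(x, min=-1.0, max=1.0)
+#     tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])
+#     >>> lo = torch.linspace(-1.0, 1.0, steps=5)   # per-element lower bound
+#     >>> torch.clamp(x, min=lo)                    # no upper bound given
+#     tensor([-1.0000, -0.5000,  0.0000,  0.5000,  2.0000])
+#     >>> torch.clip(x, -1.0, 1.0)                  # alias of torch.clamp
+#     tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])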
+def clone( + input: Tensor, + *, + memory_format: memory_format | None = None, +) -> Tensor: + r""" + clone(input, *, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of :attr:`input`. + + .. note:: + + This function is differentiable, so gradients will flow back from the + result of this operation to :attr:`input`. To create a tensor without an + autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned tensor. Default: ``torch.preserve_format``. + """ + +def col_indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.col_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def column_stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + column_stack(tensors, *, out=None) -> Tensor + + Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + + Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` + in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + """ + +def combinations( + input: Tensor, + r: _int = 2, + with_replacement: _bool = False, +) -> Tensor: + r""" + combinations(input: Tensor, r: int = 2, with_replacement: bool = False) -> seq + + Compute combinations of length :math:`r` of the given tensor. The behavior is similar to + python's `itertools.combinations` when `with_replacement` is set to `False`, and + `itertools.combinations_with_replacement` when `with_replacement` is set to `True`. + + Arguments: + input (Tensor): 1D vector. + r (int, optional): number of elements to combine + with_replacement (bool, optional): whether to allow duplication in combination + + Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, do + `itertools.combinations` or `itertools.combinations_with_replacement` on these + lists, and finally convert the resulting list into tensor. + + Example:: + + >>> a = [1, 2, 3] + >>> list(itertools.combinations(a, r=2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(itertools.combinations(a, r=3)) + [(1, 2, 3)] + >>> list(itertools.combinations_with_replacement(a, r=2)) + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + >>> tensor_a = torch.tensor(a) + >>> torch.combinations(tensor_a) + tensor([[1, 2], + [1, 3], + [2, 3]]) + >>> torch.combinations(tensor_a, r=3) + tensor([[1, 2, 3]]) + >>> torch.combinations(tensor_a, with_replacement=True) + tensor([[1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3]]) + """ + +def complex( + real: Tensor, + imag: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + complex(real, imag, *, out=None) -> Tensor + + Constructs a complex tensor with its real part equal to :attr:`real` and its + imaginary part equal to :attr:`imag`. 
+ + Args: + real (Tensor): The real part of the complex tensor. Must be half, float or double. + imag (Tensor): The imaginary part of the complex tensor. Must be same dtype + as :attr:`real`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> real = torch.tensor([1, 2], dtype=torch.float32) + >>> imag = torch.tensor([3, 4], dtype=torch.float32) + >>> z = torch.complex(real, imag) + >>> z + tensor([(1.+3.j), (2.+4.j)]) + >>> z.dtype + torch.complex64 + """ + +@overload +def concat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +@overload +def concat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +@overload +def concatenate( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +@overload +def concatenate( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +def conj(input: Tensor) -> Tensor: + r""" + conj(input) -> Tensor + + Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, + this function just returns :attr:`input`. + + .. note:: + :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized + at any time using :func:`torch.resolve_conj`. + + .. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> x.is_conj() + False + >>> y = torch.conj(x) + >>> y.is_conj() + True + """ + +def conj_physical(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + conj_physical(input, *, out=None) -> Tensor + + Computes the element-wise conjugate of the given :attr:`input` tensor. + If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`. + + .. note:: + This performs the conjugate operation regardless of the fact conjugate bit is set or not. + + .. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + .. math:: + \text{out}_{i} = conj(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + """ + +def conj_physical_(input: Tensor) -> Tensor: ... 
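+# Editor's note: an illustrative sketch (not part of the upstream stub) of the
+# lazy-vs-physical conjugation described in the ``conj``/``conj_physical``
+# docstrings above, assuming a standard ``torch`` install.
+#
+#     >>> import torch
+#     >>> z = torch.tensor([1 + 1j, 2 - 2j])
+#     >>> zc = torch.conj(z)            # lazy view: only the conjugate bit is flipped
+#     >>> zc.is_conj()
+#     True
+#     >>> torch.resolve_conj(zc)        # materialize the lazy conjugate
+#     tensor([1.-1.j, 2.+2.j])
+#     >>> torch.conj_physical(z)        # eager elementwise conjugate, no conjugate bit set
+#     tensor([1.-1.j, 2.+2.j])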
+def constant_pad_nd( + input: Tensor, + pad: Sequence[_int | SymInt], + value: Number | _complex = 0, +) -> Tensor: ... +@overload +def conv1d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv1d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: str = "valid", + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv2d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv2d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: str = "valid", + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv3d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv3d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: str = "valid", + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +def conv_tbc( + input: Tensor, + weight: Tensor, + bias: Tensor, + pad: _int = 0, +) -> Tensor: ... +def conv_transpose1d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + output_padding: _int | SymInt | Sequence[_int | SymInt] = 0, + groups: _int | SymInt = 1, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def conv_transpose2d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + output_padding: _int | SymInt | Sequence[_int | SymInt] = 0, + groups: _int | SymInt = 1, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def conv_transpose3d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + output_padding: _int | SymInt | Sequence[_int | SymInt] = 0, + groups: _int | SymInt = 1, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + transposed: _bool, + output_padding: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... 
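+# --- Illustrative usage sketch (editor addition, not part of the generated stub) ---
+# The conv1d/conv2d/conv3d overloads above accept either numeric padding or a
+# padding string ("valid"/"same" in recent releases). A minimal conv2d call,
+# assuming a standard PyTorch install; the shapes are arbitrary examples:
+# input is (N, C_in, H, W) and weight is (C_out, C_in // groups, kH, kW).
+import torch
+
+x = torch.randn(1, 3, 8, 8)
+w = torch.randn(6, 3, 3, 3)
+b = torch.randn(6)
+
+y = torch.conv2d(x, w, b, stride=1, padding=1)    # numeric-padding overload
+assert y.shape == (1, 6, 8, 8)
+
+y_same = torch.conv2d(x, w, b, padding="same")    # string-padding overload (stride stays 1)
+assert y_same.shape == (1, 6, 8, 8)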
+@overload +def copysign( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + +@overload +def copysign( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. 
+ """ + +def corrcoef(input: Tensor) -> Tensor: + r""" + corrcoef(input) -> Tensor + + Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, + where rows are the variables and columns are the observations. + + .. note:: + + The correlation coefficient matrix R is computed using the covariance matrix C as given by + :math:`R_{ij} = \frac{ C_{ij} } { \sqrt{ C_{ii} * C_{jj} } }` + + .. note:: + + Due to floating point rounding, the resulting array may not be Hermitian and its diagonal elements may not be 1. + The real and imaginary values are clipped to the interval [-1, 1] in an attempt to improve this situation. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Returns: + (Tensor) The correlation coefficient matrix of the variables. + + .. seealso:: + + :func:`torch.cov` covariance matrix. + + Example:: + + >>> x = torch.tensor([[0, 1, 2], [2, 1, 0]]) + >>> torch.corrcoef(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> x = torch.randn(2, 4) + >>> x + tensor([[-0.2678, -0.0908, -0.3766, 0.2780], + [-0.5812, 0.1535, 0.2387, 0.2350]]) + >>> torch.corrcoef(x) + tensor([[1.0000, 0.3582], + [0.3582, 1.0000]]) + >>> torch.corrcoef(x[0]) + tensor(1.) + """ + +def cos(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + cos(input, *, out=None) -> Tensor + + Returns a new tensor with the cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cos(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) + >>> torch.cos(a) + tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) + """ + +def cos_(input: Tensor) -> Tensor: ... +def cosh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + cosh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic cosine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) + >>> torch.cosh(a) + tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + +def cosh_(input: Tensor) -> Tensor: ... +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: _float = 0.0, + reduction: _int = 1, +) -> Tensor: ... +def cosine_similarity( + x1: Tensor, + x2: Tensor, + dim: _int = 1, + eps: _float = 1e-08, +) -> Tensor: ... +@overload +def count_nonzero(input: Tensor, dim: _int | None = None) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. 
+ + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + +@overload +def count_nonzero(input: Tensor, dim: _size) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + +def cov( + input: Tensor, + *, + correction: _int = 1, + fweights: Tensor | None = None, + aweights: Tensor | None = None, +) -> Tensor: + r""" + cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor + + Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are + the variables and columns are the observations. + + A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains + the variance of each variable (covariance of a variable with itself). By definition, if :attr:`input` represents + a single variable (Scalar or 1D) then its variance is returned. + + The sample covariance of the variables :math:`x` and :math:`y` is given by: + + .. math:: + \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{\max(0,~N~-~\delta N)} + + where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of the :math:`x` and :math:`y` respectively, and + :math:`\delta N` is the :attr:`correction`. + + If :attr:`fweights` and/or :attr:`aweights` are provided, the weighted covariance + is calculated, which is given by: + + .. math:: + \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)} + {\max(0,~\sum^{N}_{i = 1}w_i~-~\frac{\sum^{N}_{i = 1}w_ia_i}{\sum^{N}_{i = 1}w_i}~\delta N)} + + where :math:`w` denotes :attr:`fweights` or :attr:`aweights` (``f`` and ``a`` for brevity) based on whichever is + provided, or :math:`w = f \times a` if both are provided, and + :math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable. If not + provided, ``f`` and/or ``a`` can be seen as a :math:`\mathbb{1}` vector of appropriate size. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Keyword Args: + correction (int, optional): difference between the sample size and sample degrees of freedom. + Defaults to Bessel's correction, ``correction = 1`` which returns the unbiased estimate, + even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0`` + will return the simple average. Defaults to ``1``. + fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the number of + times each observation should be repeated. Its numel must equal the number of columns of :attr:`input`. + Must have integral dtype. Ignored if ``None``. Defaults to ``None``. + aweights (tensor, optional): A Scalar or 1D array of observation vector weights. 
+ These relative weights are typically large for observations considered "important" and smaller for + observations considered less "important". Its numel must equal the number of columns of :attr:`input`. + Must have floating point dtype. Ignored if ``None``. Defaults to ``None``. + + Returns: + (Tensor) The covariance matrix of the variables. + + .. seealso:: + + :func:`torch.corrcoef` normalized covariance matrix. + + Example:: + + >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T + >>> x + tensor([[0, 1, 2], + [2, 1, 0]]) + >>> torch.cov(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> torch.cov(x, correction=0) + tensor([[ 0.6667, -0.6667], + [-0.6667, 0.6667]]) + >>> fw = torch.randint(1, 10, (3,)) + >>> fw + tensor([1, 6, 9]) + >>> aw = torch.rand(3) + >>> aw + tensor([0.4282, 0.0255, 0.4144]) + >>> torch.cov(x, fweights=fw, aweights=aw) + tensor([[ 0.4169, -0.4169], + [-0.4169, 0.4169]]) + """ + +def cross( + input: Tensor, + other: Tensor, + dim: _int | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cross(input, other, dim=None, *, out=None) -> Tensor + + + Returns the cross product of vectors in dimension :attr:`dim` of :attr:`input` + and :attr:`other`. + + Supports input of float, double, cfloat and cdouble dtypes. Also supports batches + of vectors, for which it computes the product along the dimension :attr:`dim`. + In this case, the output has the same batch dimensions as the inputs. + + .. warning:: + If :attr:`dim` is not given, it defaults to the first dimension found + with the size 3. Note that this might be unexpected. + + This behavior is deprecated and will be changed to match that of :func:`torch.linalg.cross` + in a future release. + + .. seealso:: + :func:`torch.linalg.cross` which has dim=-1 as default. + + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + dim (int, optional): the dimension to take the cross-product in. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.cross(a, b, dim=1) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> torch.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + """ + +def crow_indices_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.crow_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int = 0, + reduction: _int = 1, + zero_infinity: _bool = False, +) -> Tensor: ... +@overload +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int = 0, + reduction: _int = 1, + zero_infinity: _bool = False, +) -> Tensor: ... +def cudnn_affine_grid_generator( + theta: Tensor, + N: _int, + C: _int, + H: _int, + W: _int, +) -> Tensor: ... 
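+# --- Illustrative usage sketch (editor addition, not part of the generated stub) ---
+# As the torch.cross docstring above warns, its default dim is the first
+# dimension of size 3, while torch.linalg.cross defaults to dim=-1. A short
+# sketch, assuming a standard PyTorch install, that passes dim explicitly so
+# the two agree:
+import torch
+
+a = torch.randn(4, 3)
+b = torch.randn(4, 3)
+
+c1 = torch.cross(a, b, dim=1)        # explicit dim avoids the deprecated inference
+c2 = torch.linalg.cross(a, b)        # dim=-1 by default
+assert torch.allclose(c1, c2)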
+def cudnn_batch_norm( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + exponential_average_factor: _float, + epsilon: _float, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def cudnn_convolution( + input: Tensor, + weight: Tensor, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + allow_tf32: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def cudnn_convolution_add_relu( + input: Tensor, + weight: Tensor, + z: Tensor, + alpha: Number | _complex | None, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def cudnn_convolution_relu( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def cudnn_convolution_transpose( + input: Tensor, + weight: Tensor, + padding: Sequence[_int | SymInt], + output_padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + allow_tf32: _bool, +) -> Tensor: ... +def cudnn_grid_sampler(input: Tensor, grid: Tensor) -> Tensor: ... +def cudnn_is_acceptable(input: Tensor) -> _bool: ... +@overload +def cummax( + input: Tensor, + dim: _int, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + +@overload +def cummax( + input: Tensor, + dim: str | EllipsisType | None, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + +@overload +def cummin( + input: Tensor, + dim: _int, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + +@overload +def cummin( + input: Tensor, + dim: str | EllipsisType | None, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + +@overload +def cumprod( + input: Tensor, + dim: _int, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. 
This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + +@overload +def cumprod( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + +@overload +def cumsum( + input: Tensor, + dim: _int, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + +@overload +def cumsum( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + +@overload +def cumulative_trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + +@overload +def cumulative_trapezoid( + y: Tensor, + *, + dx: Number | _complex = 1, + dim: _int = -1, +) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + +def deg2rad(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + deg2rad(input, *, out=None) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in degrees to radians. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]) + >>> torch.deg2rad(a) + tensor([[ 3.1416, -3.1416], + [ 6.2832, -6.2832], + [ 1.5708, -1.5708]]) + """ + +def deg2rad_(input: Tensor) -> Tensor: ... +@overload +def dequantize(input: Tensor) -> Tensor: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + +@overload +def dequantize( + tensors: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + +def det(input: Tensor) -> Tensor: + r""" + det(input) -> Tensor + + Alias for :func:`torch.linalg.det` + """ + +def detach(input: Tensor) -> Tensor: ... +def detach_(input: Tensor) -> Tensor: ... 
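+# --- Illustrative usage sketch (editor addition, not part of the generated stub) ---
+# detach/detach_ above are declared without docstrings. A short sketch, assuming
+# a standard PyTorch install: detach() shares storage with the input but drops
+# the autograd relationship, matching the note under clone() earlier in this file.
+import torch
+
+x = torch.randn(3, requires_grad=True)
+y = x.detach()                            # same storage, no autograd history
+assert not y.requires_grad
+assert y.data_ptr() == x.data_ptr()       # no copy is made
+
+z = x.detach().clone()                    # independent copy, still autograd-free
+assert z.data_ptr() != x.data_ptr()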
+def detach_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.detach`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def diag( + input: Tensor, + diagonal: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + diag(input, diagonal=0, *, out=None) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a matrix (2-D tensor), then returns a 1-D tensor with + the diagonal elements of :attr:`input`. + + The argument :attr:`diagonal` controls which diagonal to consider: + + - If :attr:`diagonal` = 0, it is the main diagonal. + - If :attr:`diagonal` > 0, it is above the main diagonal. + - If :attr:`diagonal` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.diagonal` always returns the diagonal of its input. + + :func:`torch.diagflat` always constructs a tensor with diagonal elements + specified by the input. + + Examples: + + Get the square matrix where the input vector is the diagonal:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.5950,-0.0872, 2.3298]) + >>> torch.diag(a) + tensor([[ 0.5950, 0.0000, 0.0000], + [ 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 2.3298]]) + >>> torch.diag(a, 1) + tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], + [ 0.0000, 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 0.0000, 2.3298], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + Get the k-th diagonal of a given matrix:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-0.4264, 0.0255,-0.1064], + [ 0.8795,-0.2429, 0.1374], + [ 0.1029,-0.6482,-1.6300]]) + >>> torch.diag(a, 0) + tensor([-0.4264,-0.2429,-1.6300]) + >>> torch.diag(a, 1) + tensor([ 0.0255, 0.1374]) + """ + +def diag_embed( + input: Tensor, + offset: _int = 0, + dim1: _int = -2, + dim2: _int = -1, +) -> Tensor: + r""" + diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor + + Creates a tensor whose diagonals of certain 2D planes (specified by + :attr:`dim1` and :attr:`dim2`) are filled by :attr:`input`. + To facilitate creating batched diagonal matrices, the 2D planes formed by + the last two dimensions of the returned tensor are chosen by default. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + The size of the new matrix will be calculated to make the specified diagonal + of the size of the last input dimension. + Note that for :attr:`offset` other than :math:`0`, the order of :attr:`dim1` + and :attr:`dim2` matters. Exchanging them is equivalent to changing the + sign of :attr:`offset`. + + Applying :meth:`torch.diagonal` to the output of this function with + the same arguments yields a matrix identical to input. However, + :meth:`torch.diagonal` has different default dimensions, so those + need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 1-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: -2. + dim2 (int, optional): second dimension with respect to which to + take diagonal. 
Default: -1. + + Example:: + + >>> a = torch.randn(2, 3) + >>> torch.diag_embed(a) + tensor([[[ 1.5410, 0.0000, 0.0000], + [ 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -2.1788]], + + [[ 0.5684, 0.0000, 0.0000], + [ 0.0000, -1.0845, 0.0000], + [ 0.0000, 0.0000, -1.3986]]]) + + >>> torch.diag_embed(a, offset=1, dim1=0, dim2=2) + tensor([[[ 0.0000, 1.5410, 0.0000, 0.0000], + [ 0.0000, 0.5684, 0.0000, 0.0000]], + + [[ 0.0000, 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -1.0845, 0.0000]], + + [[ 0.0000, 0.0000, 0.0000, -2.1788], + [ 0.0000, 0.0000, 0.0000, -1.3986]], + + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]) + """ + +def diagflat(input: Tensor, offset: _int = 0) -> Tensor: + r""" + diagflat(input, offset=0) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a tensor with more than one dimension, then returns a + 2-D tensor with diagonal elements equal to a flattened :attr:`input`. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + offset (int, optional): the diagonal to consider. Default: 0 (main + diagonal). + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([-0.2956, -0.9068, 0.1695]) + >>> torch.diagflat(a) + tensor([[-0.2956, 0.0000, 0.0000], + [ 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.1695]]) + >>> torch.diagflat(a, 1) + tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.1695], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + >>> a = torch.randn(2, 2) + >>> a + tensor([[ 0.2094, -0.3018], + [-0.1516, 1.9342]]) + >>> torch.diagflat(a) + tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.3018, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1516, 0.0000], + [ 0.0000, 0.0000, 0.0000, 1.9342]]) + """ + +@overload +def diagonal( + input: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, +) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. 
+ + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + >>> b = torch.randn(2, 5) + >>> b + tensor([[-1.7948, -1.2731, -0.3181, 2.0200, -1.6745], + [ 1.8262, -1.5049, 0.4114, 1.0704, -1.2607]]) + + >>> torch.diagonal(b, 1, 1, 0) + tensor([1.8262]) + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + +@overload +def diagonal( + input: Tensor, + *, + outdim: str | EllipsisType | None, + dim1: str | EllipsisType | None, + dim2: str | EllipsisType | None, + offset: _int = 0, +) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. + + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + >>> b = torch.randn(2, 5) + >>> b + tensor([[-1.7948, -1.2731, -0.3181, 2.0200, -1.6745], + [ 1.8262, -1.5049, 0.4114, 1.0704, -1.2607]]) + + >>> torch.diagonal(b, 1, 1, 0) + tensor([1.8262]) + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + +def diagonal_copy( + input: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.diagonal`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def diagonal_scatter( + input: Tensor, + src: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, +) -> Tensor: + r""" + diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the diagonal elements of :attr:`input`, with respect to :attr:`dim1` + and :attr:`dim2`. 
+ + This function returns a tensor with fresh storage; it does not + return a view. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + src (Tensor): the tensor to embed into :attr:`input`. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.diagonal(input, offset, dim1, dim2)`` + + Examples:: + + >>> a = torch.zeros(3, 3) + >>> a + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + + >>> torch.diagonal_scatter(a, torch.ones(3), 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + >>> torch.diagonal_scatter(a, torch.ones(2), 1) + tensor([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) + """ + +def diff( + input: Tensor, + n: _int = 1, + dim: _int = -1, + prepend: Tensor | None = None, + append: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor + + Computes the n-th forward difference along the given dimension. + + The first-order differences are given by `out[i] = input[i + 1] - input[i]`. Higher-order + differences are calculated by using :func:`torch.diff` recursively. + + Args: + input (Tensor): the tensor to compute the differences on + n (int, optional): the number of times to recursively compute the difference + dim (int, optional): the dimension to compute the difference along. + Default is the last dimension. + prepend, append (Tensor, optional): values to prepend or append to + :attr:`input` along :attr:`dim` before computing the difference. + Their dimensions must be equivalent to that of input, and their shapes + must match input's shape except on :attr:`dim`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 3, 2]) + >>> torch.diff(a) + tensor([ 2, -1]) + >>> b = torch.tensor([4, 5]) + >>> torch.diff(a, append=b) + tensor([ 2, -1, 2, 1]) + >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]]) + >>> torch.diff(c, dim=0) + tensor([[2, 2, 2]]) + >>> torch.diff(c, dim=1) + tensor([[1, 1], + [1, 1]]) + """ + +def digamma(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + digamma(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.digamma`. + """ + +def dist(input: Tensor, other: Tensor, p: Number | _complex = 2) -> Tensor: + r""" + dist(input, other, p=2) -> Tensor + + Returns the p-norm of (:attr:`input` - :attr:`other`) + + The shapes of :attr:`input` and :attr:`other` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + other (Tensor): the Right-hand-side input tensor + p (float, optional): the norm to be computed + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([-1.5393, -0.8675, 0.5916, 1.6321]) + >>> y = torch.randn(4) + >>> y + tensor([ 0.0967, -1.0511, 0.6295, 0.8360]) + >>> torch.dist(x, y, 3.5) + tensor(1.6727) + >>> torch.dist(x, y, 3) + tensor(1.6973) + >>> torch.dist(x, y, 0) + tensor(4.) 
+ >>> torch.dist(x, y, 1) + tensor(2.6537) + """ + +def div( + input: Tensor | Number, + other: Tensor | Number, + *, + rounding_mode: str | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + div(input, other, *, rounding_mode=None, out=None) -> Tensor + + Divides each element of the input ``input`` by the corresponding element of + :attr:`other`. + + .. math:: + \text{out}_i = \frac{\text{input}_i}{\text{other}_i} + + .. note:: + By default, this performs a "true" division like Python 3. + See the :attr:`rounding_mode` argument for floor division. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + Always promotes integer types to the default scalar type. + + Args: + input (Tensor): the dividend + other (Tensor or Number): the divisor + + Keyword args: + rounding_mode (str, optional): Type of rounding applied to the result: + + * None - default behavior. Performs no rounding and, if both :attr:`input` and + :attr:`other` are integer types, promotes the inputs to the default scalar type. + Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``. + * ``"trunc"`` - rounds the results of the division towards zero. + Equivalent to C-style integer division. + * ``"floor"`` - rounds the results of the division down. + Equivalent to floor division in Python (the ``//`` operator) and NumPy's ``np.floor_divide``. + + out (Tensor, optional): the output tensor. + + Examples:: + + >>> x = torch.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637]) + >>> torch.div(x, 0.5) + tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274]) + + >>> a = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917], + ... [ 0.1815, -1.0111, 0.9805, -1.5923], + ... [ 0.1062, 1.4581, 0.7759, -1.2344], + ... [-0.1830, -0.0313, 1.1908, -1.4757]]) + >>> b = torch.tensor([ 0.8032, 0.2930, -0.8113, -0.2308]) + >>> torch.div(a, b) + tensor([[-0.4620, -6.6051, 0.5676, 1.2639], + [ 0.2260, -3.4509, -1.2086, 6.8990], + [ 0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + + >>> torch.div(a, b, rounding_mode='trunc') + tensor([[-0., -6., 0., 1.], + [ 0., -3., -1., 6.], + [ 0., 4., -0., 5.], + [-0., -0., -1., 6.]]) + + >>> torch.div(a, b, rounding_mode='floor') + tensor([[-1., -7., 0., 1.], + [ 0., -4., -2., 6.], + [ 0., 4., -1., 5.], + [-1., -1., -2., 6.]]) + """ + +@overload +def divide( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +@overload +def divide( + input: Tensor, + other: Tensor, + *, + rounding_mode: str | None, + out: Tensor | None = None, +) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +@overload +def divide( + input: Tensor, + other: Number | _complex, + *, + rounding_mode: str | None, +) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +@overload +def divide(input: Tensor, other: Number | _complex) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +def dot( + input: Tensor, + tensor: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + dot(input, tensor, *, out=None) -> Tensor + + Computes the dot product of two 1D tensors. + + .. 
note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. + tensor (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + + >>> t1, t2 = torch.tensor([0, 1]), torch.tensor([2, 3]) + >>> torch.dot(t1, t2) + tensor(3) + """ + +def dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ... +@overload +def dsplit(input: Tensor, sections: _int) -> tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + +@overload +def dsplit(input: Tensor, indices: _size) -> tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + +def dstack( + tensors: tuple[Tensor, ...] 
| list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + dstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence depthwise (along third axis). + + This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by :func:`torch.atleast_3d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.dstack((a,b)) + tensor([[[1, 4], + [2, 5], + [3, 6]]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.dstack((a,b)) + tensor([[[1, 4]], + [[2, 5]], + [[3, 6]]]) + """ + +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: _int | SymInt = -1, + scale_grad_by_freq: _bool = False, + sparse: _bool = False, +) -> Tensor: ... +@overload +def embedding_bag( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool, + mode: _int, + sparse: _bool, + per_sample_weights: Tensor | None, + include_last_offset: _bool, + padding_idx: _int | None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def embedding_bag( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool = False, + mode: _int = 0, + sparse: _bool = False, + per_sample_weights: Tensor | None = None, + include_last_offset: _bool = False, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def embedding_renorm_( + input: Tensor, + indices: Tensor, + max_norm: _float, + norm_type: _float, +) -> Tensor: ... +@overload +def empty( + size: Sequence[_int | SymInt], + *, + memory_format: memory_format | None = None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +@overload +def empty( + *size: _int | SymInt, + memory_format: memory_format | None = None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +@overload +def empty( + size: _size, + *, + names: Sequence[str | EllipsisType | None] | None, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. 
The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +@overload +def empty( + *size: _int, + names: Sequence[str | EllipsisType | None] | None, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +def empty_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns an uninitialized tensor with the same size as :attr:`input`. + ``torch.empty_like(input)`` is equivalent to + ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> a=torch.empty((2,3), dtype=torch.int32, device = 'cuda') + >>> torch.empty_like(a) + tensor([[0, 0, 0], + [0, 0, 0]], device='cuda:0', dtype=torch.int32) + """ + +def empty_permuted( + size: Sequence[_int | SymInt], + physical_layout: _size, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates an uninitialized, non-overlapping and dense tensor with the + specified :attr:`size`, with :attr:`physical_layout` specifying how the + dimensions are physically laid out in memory (each logical dimension is listed + from outermost to innermost). 
:attr:`physical_layout` is a generalization + of NCHW/NHWC notation: if each dimension is assigned a number according to + what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)`` + while NHWC is ``(0, 2, 3, 1)``. Equivalently, the strides of the output + tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]`` + (notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``). + + Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense + tensor with no overlaps. If possible, prefer using this function over + :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + physical_layout (tuple of int): the ordering of dimensions physically in memory + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Examples: + + >>> torch.empty((2, 3, 5, 7)).stride() + (105, 35, 7, 1) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride() + (105, 35, 7, 1) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order() + (0, 2, 3, 1) + """ + +def empty_quantized( + size: _size, + qtensor: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def empty_strided( + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. + + .. warning:: + If the constructed tensor is "overlapped" (with multiple indices referring to the same element + in memory) its behavior is undefined. + + .. 
note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> a = torch.empty_strided((2, 3), (1, 2)) + >>> a + tensor([[8.9683e-44, 4.4842e-44, 5.1239e+07], + [0.0000e+00, 0.0000e+00, 3.0705e-41]]) + >>> a.stride() + (1, 2) + >>> a.size() + torch.Size([2, 3]) + """ + +@overload +def eq( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + +@overload +def eq( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + +def equal(input: Tensor, other: Tensor) -> _bool: + r""" + equal(input, other) -> bool + + ``True`` if two tensors have the same size and elements, ``False`` otherwise. + + .. note:: + + Tensors containing NaNs are never equal to each other. Additionally, this function does not + differentiate between the data types of the tensors during comparison. For more thorough tensor checks, + use :meth:`torch.testing.assert_close`. 
+ + Example:: + + >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) + True + >>> torch.equal(torch.tensor([3, torch.nan]), torch.tensor([3, torch.nan])) + False + >>> torch.equal(torch.tensor([1, 2, 3], dtype=torch.int32), torch.tensor([1, 2, 3], dtype=torch.float32)) + True + """ + +def erf(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + erf(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erf`. + """ + +def erf_(input: Tensor) -> Tensor: ... +def erfc(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + erfc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfc`. + """ + +def erfc_(input: Tensor) -> Tensor: ... +def erfinv(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + erfinv(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfinv`. + """ + +def exp(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + exp(input, *, out=None) -> Tensor + + Returns a new tensor with the exponential of the elements + of the input tensor :attr:`input`. + + .. math:: + y_{i} = e^{x_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.exp(torch.tensor([0, math.log(2.)])) + tensor([ 1., 2.]) + """ + +def exp2(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + exp2(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.exp2`. + """ + +def exp2_(input: Tensor) -> Tensor: ... +def exp_(input: Tensor) -> Tensor: ... +def expand_copy( + input: Tensor, + size: Sequence[_int | SymInt], + *, + implicit: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.Tensor.expand`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def expm1(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + expm1(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expm1`. + """ + +def expm1_(input: Tensor) -> Tensor: ... +@overload +def eye( + n: _int | SymInt, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + +@overload +def eye( + n: _int | SymInt, + m: _int | SymInt, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + +def fake_quantize_per_channel_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + axis: _int, + quant_min: _int, + quant_max: _int, +) -> Tensor: + r""" + fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`. + + .. 
math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), in ``torch.float32`` + scale (Tensor): quantization scale, per channel in ``torch.float32`` + zero_point (Tensor): quantization zero_point, per channel in ``torch.int32`` or ``torch.half`` or ``torch.float32`` + axis (int32): channel axis + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized per channel ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(2, 2, 2) + >>> x + tensor([[[-0.2525, -0.0466], + [ 0.3491, -0.2168]], + + [[-0.5906, 1.6258], + [ 0.6444, -0.0542]]]) + >>> scales = (torch.randn(2) + 1) * 0.05 + >>> scales + tensor([0.0475, 0.0486]) + >>> zero_points = torch.zeros(2).to(torch.int32) + >>> zero_points + tensor([0, 0]) + >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255) + tensor([[[0.0000, 0.0000], + [0.3405, 0.0000]], + + [[0.0000, 1.6134], + [0.6323, 0.0000]]]) + """ + +@overload +def fake_quantize_per_tensor_affine( + input: Tensor, + scale: _float, + zero_point: _int, + quant_min: _int, + quant_max: _int, +) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + +@overload +def fake_quantize_per_tensor_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + quant_min: _int, + quant_max: _int, +) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. 
math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + +def fbgemm_linear_fp16_weight( + input: Tensor, + packed_weight: Tensor, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_fp16_weight_fp32_activation( + input: Tensor, + packed_weight: Tensor, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_int8_weight( + input: Tensor, + weight: Tensor, + packed: Tensor, + col_offsets: Tensor, + weight_scale: Number | _complex, + weight_zero_point: Number | _complex, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_int8_weight_fp32_activation( + input: Tensor, + weight: Tensor, + packed: Tensor, + col_offsets: Tensor, + weight_scale: Number | _complex, + weight_zero_point: Number | _complex, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_quantize_weight( + input: Tensor, +) -> tuple[Tensor, Tensor, _float, _int]: ... +def fbgemm_pack_gemm_matrix_fp16(input: Tensor) -> Tensor: ... +@overload +def fbgemm_pack_quantized_matrix(input: Tensor) -> Tensor: ... +@overload +def fbgemm_pack_quantized_matrix(input: Tensor, K: _int, N: _int) -> Tensor: ... +def feature_alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_alpha_dropout_( + input: Tensor, + p: _float, + train: _bool, +) -> Tensor: ... +def feature_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +@overload +def fill(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill(input: Tensor, value: Number | _complex) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Number | _complex) -> Tensor: ... +def fix(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + fix(input, *, out=None) -> Tensor + + Alias for :func:`torch.trunc` + """ + +def fix_(input: Tensor) -> Tensor: ... +@overload +def flatten( + input: Tensor, + start_dim: _int = 0, + end_dim: _int = -1, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. 
Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +@overload +def flatten( + input: Tensor, + start_dim: _int, + end_dim: _int, + out_dim: str | EllipsisType | None, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +@overload +def flatten( + input: Tensor, + start_dim: str | EllipsisType | None, + end_dim: str | EllipsisType | None, + out_dim: str | EllipsisType | None, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... 
[7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +@overload +def flatten( + input: Tensor, + dims: Sequence[str | EllipsisType | None], + out_dim: str | EllipsisType | None, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +def flip(input: Tensor, dims: _size) -> Tensor: + r""" + flip(input, dims) -> Tensor + + Reverse the order of an n-D tensor along given axis in dims. + + .. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. + + Args: + input (Tensor): the input tensor. + dims (a list or tuple): axis to flip on + + Example:: + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]]]) + >>> torch.flip(x, [0, 1]) + tensor([[[ 6, 7], + [ 4, 5]], + + [[ 2, 3], + [ 0, 1]]]) + """ + +def fliplr(input: Tensor) -> Tensor: + r""" + fliplr(input) -> Tensor + + Flip tensor in the left/right direction, returning a new tensor. + + Flip the entries in each row in the left/right direction. + Columns are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 2-D. + + .. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. + + Args: + input (Tensor): Must be at least 2-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.fliplr(x) + tensor([[1, 0], + [3, 2]]) + """ + +def flipud(input: Tensor) -> Tensor: + r""" + flipud(input) -> Tensor + + Flip tensor in the up/down direction, returning a new tensor. + + Flip the entries in each column in the up/down direction. + Rows are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 1-D. + + .. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. 
This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. + + Args: + input (Tensor): Must be at least 1-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flipud(x) + tensor([[2, 3], + [0, 1]]) + """ + +@overload +def float_power( + input: Tensor, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + +@overload +def float_power( + self: Number | _complex, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + +@overload +def float_power( + input: Tensor, + exponent: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + +def floor(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + floor(input, *, out=None) -> Tensor + + Returns a new tensor with the floor of the elements of :attr:`input`, + the largest integer less than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.8166, 1.5308, -0.2530, -0.2091]) + >>> torch.floor(a) + tensor([-1., 1., -1., -1.]) + """ + +def floor_(input: Tensor) -> Tensor: ... +def floor_divide( + input: Tensor | Number, + other: Tensor | Number, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + floor_divide(input, other, *, out=None) -> Tensor + + .. note:: + + Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly performed + truncation division. To restore the previous behavior use + :func:`torch.div` with ``rounding_mode='trunc'``. + + Computes :attr:`input` divided by :attr:`other`, elementwise, and floors + the result. + + .. math:: + \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) + + + + Supports broadcasting to a common shape, type promotion, and integer and float inputs. + + Args: + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([4.0, 3.0]) + >>> b = torch.tensor([2.0, 2.0]) + >>> torch.floor_divide(a, b) + tensor([2.0, 1.0]) + >>> torch.floor_divide(a, 1.4) + tensor([2.0, 2.0]) + """ + +def fmax( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmax(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.maximum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')]) + >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) + >>> torch.fmax(a, b) + tensor([9.7000, 0.5000, 3.1000, nan]) + """ + +def fmin( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmin(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.minimum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) + """ + +@overload +def fmod( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. + The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. 
+ + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + +@overload +def fmod( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. + The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + +def frac(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + frac(input, *, out=None) -> Tensor + + Computes the fractional portion of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) + + Example:: + + >>> torch.frac(torch.tensor([1, 2.5, -3.2])) + tensor([ 0.0000, 0.5000, -0.2000]) + """ + +def frac_(input: Tensor) -> Tensor: ... +def frexp( + input: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.frexp: + r""" + frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) + + Decomposes :attr:`input` into mantissa and exponent tensors + such that :math:`\text{input} = \text{mantissa} \times 2^{\text{exponent}}`. + + The range of mantissa is the open interval (-1, 1). + + Supports float inputs. + + Args: + input (Tensor): the input tensor + + + Keyword args: + out (tuple, optional): the output tensors + + Example:: + + >>> x = torch.arange(9.) + >>> mantissa, exponent = torch.frexp(x) + >>> mantissa + tensor([0.0000, 0.5000, 0.5000, 0.7500, 0.5000, 0.6250, 0.7500, 0.8750, 0.5000]) + >>> exponent + tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) + >>> torch.ldexp(mantissa, exponent) + tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) + """ + +def frobenius_norm( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... 
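# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the stub file above): the fmod docstrings
# earlier in this hunk define the operation via torch.div with
# rounding_mode="trunc" and point to torch.remainder for the floor-based
# (Python-style) modulus. A minimal, hedged check of both identities,
# assuming only the public torch API:
# ---------------------------------------------------------------------------
import torch

a = torch.tensor([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0])
b = torch.tensor(2.0)

trunc_mod = torch.fmod(a, b)        # sign follows the dividend -> [-1., -0., -1., 1., 0., 1.]
floor_mod = torch.remainder(a, b)   # sign follows the divisor  -> [ 1.,  0.,  1., 1., 0., 1.]

# Both results can be rebuilt from torch.div with the matching rounding mode,
# as stated in the fmod docstring and the floor_divide note above.
assert torch.allclose(trunc_mod, a - a.div(b, rounding_mode="trunc") * b)
assert torch.allclose(floor_mod, a - a.div(b, rounding_mode="floor") * b)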
+def from_file( + filename: str, + shared: _bool | None = None, + size: _int | None = 0, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + from_file(filename, shared=None, size=0, *, dtype=None, layout=None, device=None, pin_memory=False) + + Creates a CPU tensor with a storage backed by a memory-mapped file. + + If ``shared`` is True, then memory is shared between processes. All changes are written to the file. + If ``shared`` is False, then changes to the tensor do not affect the file. + + ``size`` is the number of elements in the Tensor. If ``shared`` is ``False``, then the file must contain + at least ``size * sizeof(dtype)`` bytes. If ``shared`` is ``True`` the file will be created if needed. + + .. note:: + Only CPU tensors can be mapped to files. + + .. note:: + For now, tensors with storages backed by a memory-mapped file cannot be created in pinned memory. + + + Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> t = torch.randn(2, 5, dtype=torch.float64) + >>> t.numpy().tofile('storage.pt') + >>> t_mapped = torch.from_file('storage.pt', shared=False, size=10, dtype=torch.float64) + """ + +def from_numpy(ndarray) -> Tensor: + r""" + from_numpy(ndarray) -> Tensor + + Creates a :class:`Tensor` from a :class:`numpy.ndarray`. + + The returned tensor and :attr:`ndarray` share the same memory. Modifications to + the tensor will be reflected in the :attr:`ndarray` and vice versa. The returned + tensor is not resizable. + + It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, + ``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``, + ``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, + and ``bool``. + + .. warning:: + Writing to a tensor created from a read-only NumPy array is not supported and will result in undefined behavior. + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.from_numpy(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + """ + +def frombuffer( + buffer: Any, + *, + dtype: _dtype, + count: int = -1, + offset: int = 0, + requires_grad: _bool = False, +) -> Tensor: + r""" + frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor + + Creates a 1-dimensional :class:`Tensor` from an object that implements + the Python buffer protocol. 
+ + Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of + the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count` + elements. + + Note that either of the following must be true: + + 1. :attr:`count` is a positive non-zero number, and the total number of bytes + in the buffer is more than :attr:`offset` plus :attr:`count` times the size + (in bytes) of :attr:`dtype`. + + 2. :attr:`count` is negative, and the length (number of bytes) of the buffer + subtracted by the :attr:`offset` is a multiple of the size (in bytes) of + :attr:`dtype`. + + The returned tensor and buffer share the same memory. Modifications to + the tensor will be reflected in the buffer and vice versa. The returned + tensor is not resizable. + + .. note:: + This function increments the reference count for the object that + owns the shared memory. Therefore, such memory will not be deallocated + before the returned tensor goes out of scope. + + .. warning:: + This function's behavior is undefined when passed an object implementing + the buffer protocol whose data is not on the CPU. Doing so is likely to + cause a segmentation fault. + + .. warning:: + This function does not try to infer the :attr:`dtype` (hence, it is not + optional). Passing a different :attr:`dtype` than its source may result + in unexpected behavior. + + Args: + buffer (object): a Python object that exposes the buffer interface. + + Keyword args: + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + count (int, optional): the number of desired elements to be read. + If negative, all the elements (until the end of the buffer) will be + read. Default: -1. + offset (int, optional): the number of bytes to skip at the start of + the buffer. Default: 0. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> import array + >>> a = array.array('i', [1, 2, 3]) + >>> t = torch.frombuffer(a, dtype=torch.int32) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> # Interprets the signed char bytes as 32-bit integers. + >>> # Each 4 signed char elements will be interpreted as + >>> # 1 signed 32-bit integer. + >>> import array + >>> a = array.array('b', [-1, 0, 0, 0]) + >>> torch.frombuffer(a, dtype=torch.int32) + tensor([255], dtype=torch.int32) + """ + +@overload +def full( + size: _size, + fill_value: Number | _complex, + *, + out: Tensor | None = None, + layout: _layout = strided, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +@overload +def full( + size: _size, + fill_value: Number | _complex, + *, + names: list[str | None], + layout: _layout = strided, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +@overload +def full( + size: Sequence[_int | SymInt], + fill_value: Number | _complex, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +@overload +def full( + size: _size, + fill_value: Number | _complex, + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +def full_like( + input: Tensor, + fill_value: Number | _complex, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. + ``torch.full_like(input, fill_value)`` is equivalent to + ``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + fill_value: the number to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. 
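+
+    Example (an illustrative sketch; the printed values assume the default
+    floating point dtype, ``torch.float32``)::
+
+        >>> x = torch.zeros(2, 3)
+        >>> torch.full_like(x, 5.)
+        tensor([[5., 5., 5.],
+                [5., 5., 5.]])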
+ """ + +def fused_moving_avg_obs_fake_quant( + input: Tensor, + observer_on: Tensor, + fake_quant_on: Tensor, + running_min: Tensor, + running_max: Tensor, + scale: Tensor, + zero_point: Tensor, + averaging_const: _float, + quant_min: _int, + quant_max: _int, + ch_axis: _int, + per_row_fake_quant: _bool = False, + symmetric_quant: _bool = False, +) -> Tensor: ... +@overload +def gather( + input: Tensor, + dim: _int, + index: Tensor, + *, + sparse_grad: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. + + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + +@overload +def gather( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + *, + sparse_grad: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. + + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + +def gcd( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + gcd(input, other, *, out=None) -> Tensor + + Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`gcd(0, 0) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.gcd(a, b) + tensor([1, 2, 5]) + >>> c = torch.tensor([3]) + >>> torch.gcd(a, c) + tensor([1, 1, 3]) + """ + +def gcd_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def ge( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + +@overload +def ge( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + +def geqrf( + input: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.geqrf: + r""" + geqrf(input, *, out=None) -> (Tensor, Tensor) + + This is a low-level function for calling LAPACK's geqrf directly. This function + returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . + + Computes a QR decomposition of :attr:`input`. + Both `Q` and `R` matrices are stored in the same output tensor `a`. + The elements of `R` are stored on and above the diagonal. + Elementary reflectors (or Householder vectors) implicitly defining matrix `Q` + are stored below the diagonal. + The results of this function can be used together with :func:`torch.linalg.householder_product` + to obtain the `Q` matrix or + with :func:`torch.ormqr`, which uses an implicit representation of the `Q` matrix, + for an efficient matrix-matrix multiplication. + + See `LAPACK documentation for geqrf`_ for further details. + + .. note:: + See also :func:`torch.linalg.qr`, which computes Q and R matrices, and :func:`torch.linalg.lstsq` + with the ``driver="gels"`` option for a function that can solve matrix equations using a QR decomposition. + + Args: + input (Tensor): the input matrix + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, Tensor). Ignored if `None`. Default: `None`. + + .. _LAPACK documentation for geqrf: + http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html + """ + +def ger( + input: Tensor, + vec2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ger(input, vec2, *, out=None) -> Tensor + + Alias of :func:`torch.outer`. + + .. 
warning:: + This function is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.outer` instead. + """ + +def get_default_dtype() -> _dtype: + r""" + get_default_dtype() -> torch.dtype + + Get the current default floating point :class:`torch.dtype`. + + Example:: + + >>> torch.get_default_dtype() # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_dtype(torch.float64) + >>> torch.get_default_dtype() # default is now changed to torch.float64 + torch.float64 + """ + +def get_num_interop_threads() -> _int: + r""" + get_num_interop_threads() -> int + + Returns the number of threads used for inter-op parallelism on CPU + (e.g. in JIT interpreter) + """ + +def get_num_threads() -> _int: + r""" + get_num_threads() -> int + + Returns the number of threads used for parallelizing CPU operations + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Number | _complex | None = None, + dim: _int | None = None, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. 
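+
+    As an illustrative aside on the formula above: with uniform spacing
+    :math:`h_l = h_r = h` it reduces to the familiar central difference,
+
+    .. math::
+        f'(x) \approx \frac{h^2 \bigl(f(x+h) - f(x-h)\bigr)}{2 h^3}
+        = \frac{f(x+h) - f(x-h)}{2h}.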
+ + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Sequence[Number | _complex], + dim: _int | None = None, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. 
For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. 
+ >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Sequence[Number | _complex], + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. 
For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: tuple[Tensor, ...] 
| list[Tensor] | None, + dim: _int | None = None, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. 
Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Number | _complex, + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. 
By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. 
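+
+    The effect of :attr:`edge_order` can be seen on samples of :math:`f(x) = x^2`
+    (an illustrative check; the boundary values shown assume the standard
+    one-sided first- and second-order difference formulas)::
+
+        >>> x = torch.tensor([0., 1., 4., 9.])  # f(x) = x**2 sampled at 0, 1, 2, 3
+        >>> torch.gradient(x, edge_order=1)
+        (tensor([1., 2., 4., 5.]),)
+        >>> torch.gradient(x, edge_order=2)  # second-order edge estimates are exact for a parabola
+        (tensor([0., 2., 4., 6.]),)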
+ + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: tuple[Tensor, ...] | list[Tensor] | None, + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. 
For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. 
Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. 
This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. 
+ >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def greater( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + +@overload +def greater( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + +@overload +def greater_equal( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + +@overload +def greater_equal( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + +def grid_sampler( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... +def grid_sampler_2d( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... +def grid_sampler_3d( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... 
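+# The grid_sampler stubs above are the low-level ops used by the public
+# torch.nn.functional.grid_sample API; the snippet below is only an
+# illustrative sketch (shapes chosen arbitrarily), not part of the generated stub.
+#
+#   >>> import torch.nn.functional as F
+#   >>> inp = torch.randn(1, 3, 8, 8)           # (N, C, H_in, W_in)
+#   >>> grid = torch.rand(1, 4, 4, 2) * 2 - 1   # (N, H_out, W_out, 2), values in [-1, 1]
+#   >>> out = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+#   >>> out.shape
+#   torch.Size([1, 3, 4, 4])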
+def group_norm( + input: Tensor, + num_groups: _int, + weight: Tensor | None = None, + bias: Tensor | None = None, + eps: _float = 1e-05, + cudnn_enabled: _bool = True, +) -> Tensor: ... +@overload +def gru( + data: Tensor, + batch_sizes: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def gru( + input: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def gru_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> Tensor: ... +@overload +def gt( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + +@overload +def gt( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + +@overload +def hamming_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. 
note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hamming_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hamming_window( + window_length: _int, + periodic: _bool, + alpha: _float, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hamming_window( + window_length: _int, + periodic: _bool, + alpha: _float, + beta: _float, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. 
math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hann_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. 
If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def hann_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +def hardshrink( + input: Tensor, + lambd: Number | _complex = 0.5, + *, + out: Tensor | None = None, +) -> Tensor: ... 
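The Hamming and Hann docstrings above state that a periodic window of length ``L`` equals the symmetric window of length ``L + 1`` with its last sample dropped, and that the Hann window is the generalized Hamming formula with :math:`\alpha = \beta = 0.5`. A minimal sketch of both checks, assuming a PyTorch build where ``alpha`` and ``beta`` can be passed by keyword::

    import torch

    L = 8
    # periodic window == symmetric window of length L + 1 with the last sample dropped
    periodic = torch.hamming_window(L, periodic=True)
    symmetric = torch.hamming_window(L + 1, periodic=False)[:-1]
    print(torch.allclose(periodic, symmetric))   # expected: True

    # hann_window is the generalized Hamming window with alpha = beta = 0.5
    hann = torch.hann_window(L, periodic=False)
    hamming_half = torch.hamming_window(L, periodic=False, alpha=0.5, beta=0.5)
    print(torch.allclose(hann, hamming_half))    # expected: True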
+def heaviside( + input: Tensor, + values: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + heaviside(input, values, *, out=None) -> Tensor + + Computes the Heaviside step function for each element in :attr:`input`. + The Heaviside step function is defined as: + + .. math:: + \text{{heaviside}}(input, values) = \begin{cases} + 0, & \text{if input < 0}\\ + values, & \text{if input == 0}\\ + 1, & \text{if input > 0} + \end{cases} + + + Args: + input (Tensor): the input tensor. + values (Tensor): The values to use where :attr:`input` is zero. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + """ + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: _float = 1.0, + reduction: _int = 1, +) -> Tensor: ... +def histc( + input: Tensor, + bins: _int = 100, + min: Number | _complex = 0, + max: Number | _complex = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor + + Computes the histogram of a tensor. + + The elements are sorted into equal width bins between :attr:`min` and + :attr:`max`. If :attr:`min` and :attr:`max` are both zero, the minimum and + maximum values of the data are used. + + Elements lower than min and higher than max and ``NaN`` elements are ignored. + + Args: + input (Tensor): the input tensor. + bins (int): number of histogram bins + min (Scalar): lower end of the range (inclusive) + max (Scalar): upper end of the range (inclusive) + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: Histogram represented as a tensor + + Example:: + + >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) + tensor([ 0., 2., 1., 0.]) + """ + +@overload +def histogram( + input: Tensor, + bins: Tensor, + *, + weight: Tensor | None = None, + density: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. 
+ If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + +@overload +def histogram( + input: Tensor, + bins: _int = 100, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + +@overload +def histogramdd( + input: Tensor, + bins: _int, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. 
+ + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. 
+ bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + +@overload +def histogramdd( + input: Tensor, + bins: _size, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. 
+ While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + +@overload +def histogramdd( + input: Tensor, + bins: tuple[Tensor, ...] | list[Tensor] | None, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. 
+ + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + +def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ... 
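With ``density=True``, the ``torch.histogramdd`` docstring above says each weight is divided by the total weight and then by the bin volume, so the histogram integrates to one over the binned range. A minimal sketch of that check; the variable names are illustrative only::

    import torch

    points = torch.randn(100, 2)                       # 100 samples in R^2
    hist, edges = torch.histogramdd(points, bins=[5, 5], density=True)
    # bin volume = width along dim 0 times width along dim 1
    volumes = edges[0].diff().unsqueeze(1) * edges[1].diff().unsqueeze(0)
    print((hist * volumes).sum())                      # expected: tensor(1.) up to rounding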
+@overload +def hsplit(input: Tensor, sections: _int) -> tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + +@overload +def hsplit(input: Tensor, indices: _size) -> tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + +def hspmm( + mat1: Tensor, + mat2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + hspmm(mat1, mat2, *, out=None) -> Tensor + + Performs a matrix multiplication of a :ref:`sparse COO matrix + ` :attr:`mat1` and a strided matrix :attr:`mat2`. The + result is a (1 + 1)-dimensional :ref:`hybrid COO matrix + `. + + Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + + Keyword args: + out (Tensor, optional): the output tensor. 
+ """ + +def hstack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + hstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence horizontally (column wise). + + This is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.hstack((a,b)) + tensor([1, 2, 3, 4, 5, 6]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.hstack((a,b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + """ + +def hypot( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + hypot(input, other, *, out=None) -> Tensor + + Given the legs of a right triangle, return its hypotenuse. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}^{2} + \text{other}_{i}^{2}} + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([5.0000, 5.6569, 6.4031]) + """ + +def i0(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + i0(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.i0`. + """ + +def i0_(input: Tensor) -> Tensor: ... +def igamma( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + igamma(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammainc`. + """ + +def igammac( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + igammac(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammaincc`. + """ + +def imag(input: Tensor) -> Tensor: + r""" + imag(input) -> Tensor + + Returns a new tensor containing imaginary values of the :attr:`self` tensor. + The returned tensor and :attr:`self` share the same underlying storage. + + .. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + """ + +@overload +def index_add( + input: Tensor, + dim: _int, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + index_add(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor]) -> Tensor # noqa: B950 + + See :meth:`~Tensor.index_add_` for function description. + """ + +@overload +def index_add( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, +) -> Tensor: + r""" + index_add(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor]) -> Tensor # noqa: B950 + + See :meth:`~Tensor.index_add_` for function description. 
+ """ + +@overload +def index_copy( + input: Tensor, + dim: _int, + index: Tensor, + source: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + index_copy(input: Tensor, dim: int, index: Tensor, source: Tensor, *, out: Optional[Tensor]) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + +@overload +def index_copy( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, +) -> Tensor: + r""" + index_copy(input: Tensor, dim: int, index: Tensor, source: Tensor, *, out: Optional[Tensor]) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + +@overload +def index_fill( + input: Tensor, + dim: _int, + index: Tensor, + value: Tensor, +) -> Tensor: ... +@overload +def index_fill( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + value: Tensor, +) -> Tensor: ... +@overload +def index_fill( + input: Tensor, + dim: _int, + index: Tensor, + value: Number | _complex, +) -> Tensor: ... +@overload +def index_fill( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, +) -> Tensor: ... +def index_put( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def index_put_( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def index_reduce( + input: Tensor, + dim: _int, + index: Tensor, + source: Tensor, + reduce: str, + *, + include_self: _bool = True, + out: Tensor | None = None, +) -> Tensor: + r""" + index_reduce(input: Tensor, dim: int, index: Tensor, source: Tensor, reduce: str, *, include_self: bool = True, out: Optional[Tensor]) -> Tensor # noqa: B950 + + See :meth:`~Tensor.index_reduce_` for function description. + """ + +@overload +def index_select( + input: Tensor, + dim: _int, + index: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + +@overload +def index_select( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + +def indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def init_num_threads() -> None: ... +def inner( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + inner(input, other, *, out=None) -> Tensor + + Computes the dot product for 1D tensors. For higher dimensions, sums the product + of elements from :attr:`input` and :attr:`other` along their last dimension. + + .. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. + + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + + Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + + Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. 
+ + Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) + """ + +def instance_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + use_input_stats: _bool, + momentum: _float, + eps: _float, + cudnn_enabled: _bool, +) -> Tensor: ... +def int_repr(input: Tensor) -> Tensor: ... +def inverse(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + inverse(input, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.inv` + """ + +def is_complex(input: Tensor) -> _bool: + r""" + is_complex(input) -> (bool) + + Returns True if the data type of :attr:`input` is a complex data type i.e., + one of ``torch.complex64``, and ``torch.complex128``. + + Args: + input (Tensor): the input tensor. + """ + +def is_conj(input: Tensor) -> _bool: + r""" + is_conj(input) -> (bool) + + Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. + + Args: + input (Tensor): the input tensor. + """ + +def is_distributed(input: Tensor) -> _bool: ... +def is_floating_point(input: Tensor) -> _bool: + r""" + is_floating_point(input) -> (bool) + + Returns True if the data type of :attr:`input` is a floating point data type i.e., + one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. + + Args: + input (Tensor): the input tensor. + """ + +def is_grad_enabled() -> _bool: + r""" + is_grad_enabled() -> (bool) + + Returns True if grad mode is currently enabled. + """ + +def is_inference(input: Tensor) -> _bool: + r""" + is_inference(input) -> (bool) + + Returns True if :attr:`input` is an inference tensor. + + A non-view tensor is an inference tensor if and only if it was + allocated during inference mode. A view tensor is an inference + tensor if and only if the tensor it is a view of is an inference tensor. + + For details on inference mode please see + `Inference Mode `_. + + Args: + input (Tensor): the input tensor. + """ + +def is_inference_mode_enabled() -> _bool: + r""" + is_inference_mode_enabled() -> (bool) + + Returns True if inference mode is currently enabled. + """ + +def is_neg(input: Tensor) -> _bool: ... +def is_nonzero(input: Tensor) -> _bool: + r""" + is_nonzero(input) -> (bool) + + Returns True if the :attr:`input` is a single element tensor which is not equal to zero + after type conversions. + i.e. not equal to ``torch.tensor([0.])`` or ``torch.tensor([0])`` or + ``torch.tensor([False])``. + Throws a ``RuntimeError`` if ``torch.numel() != 1`` (even in case + of sparse tensors). + + Args: + input (Tensor): the input tensor. 
+ + Examples:: + + >>> torch.is_nonzero(torch.tensor([0.])) + False + >>> torch.is_nonzero(torch.tensor([1.5])) + True + >>> torch.is_nonzero(torch.tensor([False])) + False + >>> torch.is_nonzero(torch.tensor([3])) + True + >>> torch.is_nonzero(torch.tensor([1, 3, 5])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous + """ + +def is_same_size(input: Tensor, other: Tensor) -> _bool: ... +def is_signed(input: Tensor) -> _bool: ... +def is_vulkan_available() -> _bool: ... +def isclose( + input: Tensor, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, +) -> Tensor: + r""" + isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + Returns a new tensor with boolean elements representing if each element of + :attr:`input` is "close" to the corresponding element of :attr:`other`. + Closeness is defined as: + + .. math:: + \lvert \text{input}_i - \text{other}_i \rvert \leq \texttt{rtol} \times \lvert \text{other}_i \rvert + \texttt{atol} + + + where :attr:`input` and :attr:`other` are finite. Where :attr:`input` + and/or :attr:`other` are nonfinite they are close if and only if + they are equal, with NaNs being considered equal to each other when + :attr:`equal_nan` is True. + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + rtol (float, optional): relative tolerance. Default: 1e-05 + atol (float, optional): absolute tolerance. Default: 1e-08 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` + + Examples:: + + >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4))) + tensor([ True, False, False]) + >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) + tensor([True, True]) + """ + +def isfinite(input: Tensor) -> Tensor: + r""" + isfinite(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element is `finite` or not. + + Real values are finite when they are not NaN, negative infinity, or infinity. + Complex values are finite when both their real and imaginary parts are finite. + + Args: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is finite and False elsewhere + + Example:: + + >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([True, False, True, False, False]) + """ + +@overload +def isin( + elements: Tensor, + test_elements: Tensor, + *, + assume_unique: _bool = False, + invert: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. 
Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + +@overload +def isin( + element: Number | _complex, + test_elements: Tensor, + *, + assume_unique: _bool = False, + invert: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + +@overload +def isin( + elements: Tensor, + test_element: Number | _complex, + *, + assume_unique: _bool = False, + invert: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + +def isinf(input: Tensor) -> Tensor: + r""" + isinf(input) -> Tensor + + Tests if each element of :attr:`input` is infinite + (positive or negative infinity) or not. + + .. note:: + Complex values are infinite when their real or imaginary part is + infinite. + + Args: + input (Tensor): the input tensor. 
+ + Returns: + A boolean tensor that is True where :attr:`input` is infinite and False elsewhere + + Example:: + + >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) + """ + +def isnan(input: Tensor) -> Tensor: + r""" + isnan(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` + is NaN or not. Complex values are considered NaN when either their real + and/or imaginary part is NaN. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is NaN and False elsewhere + + Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([False, True, False]) + """ + +def isneginf(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + isneginf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is negative infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isneginf(a) + tensor([ True, False, False]) + """ + +def isposinf(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + isposinf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is positive infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isposinf(a) + tensor([False, True, False]) + """ + +def isreal(input: Tensor) -> Tensor: + r""" + isreal(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. + All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is real and False elsewhere + + Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) + """ + +def istft( + input: Tensor, + n_fft: _int, + hop_length: _int | None = None, + win_length: _int | None = None, + window: Tensor | None = None, + center: _bool = True, + normalized: _bool = False, + onesided: _bool | None = None, + length: _int | None = None, + return_complex: _bool = False, +) -> Tensor: ... +@overload +def kaiser_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. 
+ The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + +@overload +def kaiser_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + +@overload +def kaiser_window( + window_length: _int, + periodic: _bool, + beta: _float, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + +def kl_div( + input: Tensor, + target: Tensor, + reduction: _int = 1, + *, + log_target: _bool = False, +) -> Tensor: ... +def kron( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + kron(input, other, *, out=None) -> Tensor + + Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + + If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a + :math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a + :math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + + .. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + + where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. + If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + + Supports real-valued and complex-valued inputs. + + .. 
note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + + Arguments: + input (Tensor) + other (Tensor) + + Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + + Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) + """ + +@overload +def kthvalue( + input: Tensor, + k: _int | SymInt, + dim: _int = -1, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + +@overload +def kthvalue( + input: Tensor, + k: _int | SymInt, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + +def layer_norm( + input: Tensor, + normalized_shape: Sequence[_int | SymInt], + weight: Tensor | None = None, + bias: Tensor | None = None, + eps: _float = 1e-05, + cudnn_enable: _bool = True, +) -> Tensor: ... +def lcm( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lcm(input, other, *, out=None) -> Tensor + + Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`lcm(0, 0) = 0` and :math:`lcm(0, a) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.lcm(a, b) + tensor([15, 20, 15]) + >>> c = torch.tensor([3]) + >>> torch.lcm(a, c) + tensor([15, 30, 15]) + """ + +def lcm_(input: Tensor, other: Tensor) -> Tensor: ... +def ldexp( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ldexp(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by 2 ** :attr:`other`. + + .. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i + + + Typically this function is used to construct floating point numbers by multiplying + mantissas in :attr:`input` with integral powers of two created from the exponents + in :attr:`other`. + + Args: + input (Tensor): the input tensor. + other (Tensor): a tensor of exponents, typically integers. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + """ + +def ldexp_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def le( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + +@overload +def le( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + +@overload +def lerp( + input: Tensor, + end: Tensor, + weight: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + +@overload +def lerp( + input: Tensor, + end: Tensor, + weight: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. 
math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + +@overload +def less( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + +@overload +def less( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + +@overload +def less_equal( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + +@overload +def less_equal( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + +def lgamma(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + lgamma(input, *, out=None) -> Tensor + + Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + + .. math:: + \text{out}_{i} = \ln |\Gamma(\text{input}_{i})| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.lgamma(a) + tensor([ 0.5724, 0.0000, -0.1208]) + """ + +@overload +def linspace( + start: Number, + end: Number, + steps: _int | None = None, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. 
+ Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Tensor, + end: Tensor, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Number | _complex, + end: Tensor, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Tensor, + end: Number | _complex, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. 
math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Number | _complex, + end: Number | _complex, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +def log(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{e} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) * 5 + >>> a + tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + >>> torch.log(a) + tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) + """ + +def log10(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log10(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the logarithm to the base 10 of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{10} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251]) + + + >>> torch.log10(a) + tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) + """ + +def log10_(input: Tensor) -> Tensor: ... +def log1p(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log1p(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of (1 + :attr:`input`). + + .. math:: + y_i = \log_{e} (x_i + 1) + + .. note:: This function is more accurate than :func:`torch.log` for small + values of :attr:`input` + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) + >>> torch.log1p(a) + tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) + """ + +def log1p_(input: Tensor) -> Tensor: ... +def log2(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log2(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the logarithm to the base 2 of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{2} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490]) + + + >>> torch.log2(a) + tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) + """ + +def log2_(input: Tensor) -> Tensor: ... +def log_(input: Tensor) -> Tensor: ... +@overload +def log_softmax( + input: Tensor, + dim: _int, + dtype: _dtype | None = None, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def log_softmax( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, +) -> Tensor: ... 
+def logaddexp( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logaddexp(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs. + + Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful + in statistics where the calculated probabilities of events may be so small as to + exceed the range of normal floating point numbers. In such cases the logarithm + of the calculated probability is stored. This function allows adding + probabilities stored in such a fashion. + + This op should be disambiguated with :func:`torch.logsumexp` which performs a + reduction on a single tensor. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1.0, -2, -3])) + tensor([-0.3069, -0.6867, -0.8731]) + >>> torch.logaddexp(torch.tensor([-100.0, -200, -300]), torch.tensor([-1.0, -2, -3])) + tensor([-1., -2., -3.]) + >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) + tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) + """ + +def logaddexp2( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logaddexp2(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs in base-2. + + Calculates pointwise :math:`\log_2\left(2^x + 2^y\right)`. See + :func:`torch.logaddexp` for more details. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + """ + +@overload +def logcumsumexp( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{k=0}^{j} \exp(x_{ik}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + +@overload +def logcumsumexp( + input: Tensor, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{k=0}^{j} \exp(x_{ik}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + +def logdet(input: Tensor) -> Tensor: + r""" + logdet(input) -> Tensor + + Calculates log determinant of a square matrix or batches of square matrices. + + It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has + a negative determinant. + + .. note:: + Backward through :meth:`logdet` internally uses SVD results when :attr:`input` + is not invertible. In this case, double backward through :meth:`logdet` will + be unstable in when :attr:`input` doesn't have distinct singular values. See + :func:`torch.linalg.svd` for details. + + .. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. complex) square matrices. + + Arguments: + input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more + batch dimensions. + + Example:: + + >>> A = torch.randn(3, 3) + >>> torch.det(A) + tensor(0.2611) + >>> torch.logdet(A) + tensor(-1.3430) + >>> A + tensor([[[ 0.9254, -0.6213], + [-0.5787, 1.6843]], + + [[ 0.3242, -0.9665], + [ 0.4539, -0.0887]], + + [[ 1.1336, -0.4025], + [-0.7089, 0.9032]]]) + >>> A.det() + tensor([1.1990, 0.4099, 0.7386]) + >>> A.det().log() + tensor([ 0.1815, -0.8917, -0.3031]) + """ + +def logical_and( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logical_and(input, other, *, out=None) -> Tensor + + Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute AND with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_and(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, False]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_and(a, b) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b.double()) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b) + tensor([False, False, True, False]) + >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([False, False, True, False]) + """ + +def logical_not(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + logical_not(input, *, out=None) -> Tensor + + Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool + dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.logical_not(torch.tensor([True, False])) + tensor([False, True]) + >>> torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1.5, -10.], dtype=torch.double)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) + tensor([1, 0, 0], dtype=torch.int16) + """ + +def logical_or( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logical_or(input, other, *, out=None) -> Tensor + + Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute OR with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_or(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_or(a, b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b.double()) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, True, False]) + """ + +def logical_xor( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logical_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute XOR with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_xor(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([False, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_xor(a, b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b.double()) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, False, False]) + """ + +def logit( + input: Tensor, + eps: _float | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logit(input, eps=None, *, out=None) -> Tensor + + Alias for :func:`torch.special.logit`. + """ + +def logit_(input: Tensor, eps: _float | None = None) -> Tensor: ... +@overload +def logspace( + start: Number, + end: Number, + steps: _int | None = None, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. 
That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Tensor, + end: Tensor, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. 
+ + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Number | _complex, + end: Tensor, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Tensor, + end: Number | _complex, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Number | _complex, + end: Number | _complex, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logsumexp( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. 
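+
+    A minimal sketch of what "numerically stabilized" buys here (illustrative
+    values only): the naive formulation overflows in float32, while
+    :func:`torch.logsumexp` returns the finite value :math:`1000 + \log 2`::
+
+        >>> x = torch.tensor([1000., 1000.])
+        >>> torch.log(torch.sum(torch.exp(x)))  # exp(1000.) overflows to inf
+        tensor(inf)
+        >>> torch.isfinite(torch.logsumexp(x, dim=0))
+        tensor(True)
+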
+ + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + +@overload +def logsumexp( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + +@overload +def lstm( + data: Tensor, + batch_sizes: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor, Tensor]: ... +@overload +def lstm( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor, Tensor]: ... +def lstm_cell( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> tuple[Tensor, Tensor]: ... +@overload +def lt( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. 
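+
+    For instance, the second argument may be a plain Python number, which is
+    broadcast against every element of the first (a small illustrative sketch;
+    the values are arbitrary)::
+
+        >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), 3)
+        tensor([[ True,  True],
+                [False, False]])
+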
+ + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + +@overload +def lt( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + +def lu_solve( + input: Tensor, + LU_data: Tensor, + LU_pivots: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor + + Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted + LU factorization of A from :func:`~linalg.lu_factor`. + + This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + + .. warning:: + + :func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`. + :func:`torch.lu_solve` will be removed in a future PyTorch release. + ``X = torch.lu_solve(B, LU, pivots)`` should be replaced with + + .. code:: python + + X = linalg.lu_solve(LU, pivots, B) + + Arguments: + b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` + is zero or more batch dimensions. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`, + where :math:`*` is zero or more batch dimensions. + LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`, + where :math:`*` is zero or more batch dimensions. + The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of + :attr:`LU_data`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(2, 3, 1) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> x = torch.lu_solve(b, LU, pivots) + >>> torch.dist(A @ x, b) + tensor(1.00000e-07 * + 2.8312) + """ + +def lu_unpack( + LU_data: Tensor, + LU_pivots: Tensor, + unpack_data: _bool = True, + unpack_pivots: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.lu_unpack: + r""" + lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. + + .. seealso:: + + :func:`~linalg.lu` returns the matrices from the LU decomposition. Its gradient formula is more efficient + than that of doing :func:`~linalg.lu_factor` followed by :func:`~linalg.lu_unpack`. 
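+
+    As a rough sketch of the two flags below (assuming a square, well-conditioned
+    input; values are random): passing ``unpack_data=False`` skips materialising
+    ``L`` and ``U``, returning them as empty tensors, while ``P`` is still built::
+
+        >>> A = torch.randn(3, 3)
+        >>> LU, pivots = torch.linalg.lu_factor(A)
+        >>> P, L, U = torch.lu_unpack(LU, pivots, unpack_data=False)
+        >>> L.numel(), U.numel()
+        (0, 0)
+        >>> P.shape
+        torch.Size([3, 3])
+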
+ + Args: + LU_data (Tensor): the packed LU factorization data + LU_pivots (Tensor): the packed LU factorization pivots + unpack_data (bool): flag indicating if the data should be unpacked. + If ``False``, then the returned ``L`` and ``U`` are empty tensors. + Default: ``True`` + unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``. + If ``False``, then the returned ``P`` is an empty tensor. + Default: ``True`` + + Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + + Returns: + A namedtuple ``(P, L, U)`` + + Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # We can recover A from the factorization + >>> A_ = P @ L @ U + >>> torch.allclose(A, A_) + True + + >>> # LU factorization of a rectangular matrix: + >>> A = torch.randn(2, 3, 2) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # P, L, U are the same as returned by linalg.lu + >>> P_, L_, U_ = torch.linalg.lu(A) + >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) + True + """ + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: _float = 0.0, + reduction: _int = 1, +) -> Tensor: ... +@overload +def masked_fill(input: Tensor, mask: Tensor, value: Tensor) -> Tensor: ... +@overload +def masked_fill( + input: Tensor, + mask: Tensor, + value: Number | _complex, +) -> Tensor: ... +def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor: ... +def masked_select( + input: Tensor, + mask: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + masked_select(input, mask, *, out=None) -> Tensor + + Returns a new 1-D tensor which indexes the :attr:`input` tensor according to + the boolean mask :attr:`mask` which is a `BoolTensor`. + + The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need + to match, but they must be :ref:`broadcastable `. + + .. note:: The returned tensor does **not** use the same storage + as the original tensor + + Args: + input (Tensor): the input tensor. + mask (BoolTensor): the tensor containing the binary mask to index with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], + [-1.2035, 1.2252, 0.5002, 0.6248], + [ 0.1307, -2.0608, 0.1244, 2.0139]]) + >>> mask = x.ge(0.5) + >>> mask + tensor([[False, False, False, False], + [False, True, True, True], + [False, False, False, True]]) + >>> torch.masked_select(x, mask) + tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) + """ + +def matmul( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + matmul(input, other, *, out=None) -> Tensor + + Matrix product of two tensors. + + The behavior depends on the dimensionality of the tensors as follows: + + - If both tensors are 1-dimensional, the dot product (scalar) is returned. + - If both arguments are 2-dimensional, the matrix-matrix product is returned. + - If the first argument is 1-dimensional and the second argument is 2-dimensional, + a 1 is prepended to its dimension for the purpose of the matrix multiply. + After the matrix multiply, the prepended dimension is removed. + - If the first argument is 2-dimensional and the second argument is 1-dimensional, + the matrix-vector product is returned. 
+ - If both arguments are at least 1-dimensional and at least one argument is + N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first + argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the + batched matrix multiply and removed after. If the second argument is 1-dimensional, a + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. + + This operation has support for arguments with :ref:`sparse layouts`. In particular the + matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions + as :func:`torch.mm` + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: + + The 1-dimensional dot product version of this function does not support an :attr:`out` parameter. + + Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> # vector x vector + >>> tensor1 = torch.randn(3) + >>> tensor2 = torch.randn(3) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([]) + >>> # matrix x vector + >>> tensor1 = torch.randn(3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([3]) + >>> # batched matrix x broadcasted vector + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3]) + >>> # batched matrix x batched matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(10, 4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + >>> # batched matrix x broadcasted matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + """ + +def matrix_exp(input: Tensor) -> Tensor: + r""" + matrix_exp(A) -> Tensor + + Alias for :func:`torch.linalg.matrix_exp`. + """ + +def matrix_power( + input: Tensor, + n: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + matrix_power(input, n, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.matrix_power` + """ + +@overload +def max(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. 
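+
+    A deterministic sketch (arbitrary values) of the whole-tensor reduction::
+
+        >>> torch.max(torch.tensor([[1., 5.], [3., 2.]]))
+        tensor(5.)
+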
+ + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +@overload +def max( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +@overload +def max( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.max: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +@overload +def max( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types.max: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +def max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def max_pool1d_with_indices( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def maximum( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + maximum(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`maximum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.maximum(a, b) + tensor([3, 2, 4]) + """ + +@overload +def mean( + input: Tensor, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + .. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + +@overload +def mean( + input: Tensor, + dim: _int | _size | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + .. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. 
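+
+    One illustrative use of the :attr:`dtype` keyword described below is to
+    compute the reduction of a low-precision input in a wider type (a minimal
+    sketch)::
+
+        >>> t = torch.ones(4, dtype=torch.float16)
+        >>> torch.mean(t, dtype=torch.float32).dtype
+        torch.float32
+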
+ + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + +@overload +def mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + .. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. 
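+
+    For example, reducing over a tuple of dimensions at once (illustrative
+    values)::
+
+        >>> t = torch.arange(8.).reshape(2, 2, 2)
+        >>> torch.mean(t, dim=(0, 2))
+        tensor([2.5000, 4.5000])
+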
+ + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + +@overload +def median(input: Tensor) -> Tensor: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. 
+ + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + +@overload +def median( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
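+
+    A small deterministic illustration of the "lower median" behaviour noted
+    above (arbitrary values)::
+
+        >>> torch.median(torch.tensor([1., 2., 3., 4.]))
+        tensor(2.)
+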
+ + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + +@overload +def median( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
+ + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + +@overload +def min(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +@overload +def min( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. 
note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +@overload +def min( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.min: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +@overload +def min( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.min: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +def minimum( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + minimum(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`minimum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.minimum(a, b) + tensor([1, 0, -1]) + """ + +def miopen_batch_norm( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + exponential_average_factor: _float, + epsilon: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +def miopen_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, +) -> Tensor: ... +def miopen_convolution_add_relu( + input: Tensor, + weight: Tensor, + z: Tensor, + alpha: Number | _complex | None, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def miopen_convolution_relu( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... 
+def miopen_convolution_transpose( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + output_padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, +) -> Tensor: ... +def miopen_depthwise_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, +) -> Tensor: ... +def miopen_rnn( + input: Tensor, + weight: tuple[Tensor, ...] | list[Tensor] | None, + weight_stride0: _int, + hx: Tensor, + cx: Tensor | None, + mode: _int, + hidden_size: _int, + num_layers: _int, + batch_first: _bool, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_sizes: _size, + dropout_state: Tensor | None, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def mkldnn_adaptive_avg_pool2d( + input: Tensor, + output_size: _int | _size, + *, + out: Tensor | None = None, +) -> Tensor: ... +def mkldnn_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def mkldnn_linear_backward_weights( + grad_output: Tensor, + input: Tensor, + weight: Tensor, + bias_defined: _bool, +) -> tuple[Tensor, Tensor]: ... +def mkldnn_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def mkldnn_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def mkldnn_rnn_layer( + input: Tensor, + weight0: Tensor, + weight1: Tensor, + weight2: Tensor, + weight3: Tensor, + hx_: Tensor, + cx_: Tensor, + reverse: _bool, + batch_sizes: _size, + mode: _int, + hidden_size: _int, + num_layers: _int, + has_biases: _bool, + bidirectional: _bool, + batch_first: _bool, + train: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def mm(input: Tensor, mat2: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + mm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Supports strided and sparse 2-D tensors as inputs, autograd with + respect to strided inputs. + + This operation has support for arguments with :ref:`sparse layouts`. + If :attr:`out` is provided its layout will be used. Otherwise, the result + layout will be deduced from that of :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) + """ + +@overload +def mm( + input: Tensor, + mat2: Tensor, + out_dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + mm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Supports strided and sparse 2-D tensors as inputs, autograd with + respect to strided inputs. + + This operation has support for arguments with :ref:`sparse layouts`. + If :attr:`out` is provided its layout will be used. Otherwise, the result + layout will be deduced from that of :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) + """ + +@overload +def mode( + input: Tensor, + dim: _int = -1, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor([[0, 0, 0, 2, 0, 0, 2], + ... [0, 3, 0, 0, 2, 0, 1], + ... [2, 2, 2, 0, 0, 0, 3], + ... [2, 2, 3, 0, 1, 1, 0], + ... [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + +@overload +def mode( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor([[0, 0, 0, 2, 0, 0, 2], + ... [0, 3, 0, 0, 2, 0, 1], + ... [2, 2, 2, 0, 0, 0, 3], + ... [2, 2, 3, 0, 1, 1, 0], + ... [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + +@overload +def moveaxis(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +@overload +def moveaxis(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. 
+ + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +@overload +def movedim(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +@overload +def movedim(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +def msort(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + msort(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Sorts the elements of the :attr:`input` tensor along its first dimension + in ascending order by value. + + .. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) + """ + +def mul( + input: Tensor | Number | _complex, + other: Tensor | Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + mul(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by :attr:`other`. + + + .. math:: + \text{out}_i = \text{input}_i \times \text{other}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number) - the tensor or number to multiply input by. + + Keyword args: + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.2015, -0.4255, 2.6087]) + >>> torch.mul(a, 100) + tensor([ 20.1494, -42.5491, 260.8663]) + + >>> b = torch.randn(4, 1) + >>> b + tensor([[ 1.1207], + [-0.3137], + [ 0.0700], + [ 0.8378]]) + >>> c = torch.randn(1, 4) + >>> c + tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) + >>> torch.mul(b, c) + tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], + [-0.1614, -0.0382, 0.1645, -0.7021], + [ 0.0360, 0.0085, -0.0367, 0.1567], + [ 0.4312, 0.1019, -0.4394, 1.8753]]) + """ + +def multinomial( + input: Tensor, + num_samples: _int | SymInt, + replacement: _bool = False, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor + + Returns a tensor where each row contains :attr:`num_samples` indices sampled + from the multinomial (a stricter definition would be multivariate, + refer to :class:`torch.distributions.multinomial.Multinomial` for more details) + probability distribution located in the corresponding row + of tensor :attr:`input`. + + .. note:: + The rows of :attr:`input` do not need to sum to one (in which case we use + the values as weights), but must be non-negative, finite and have + a non-zero sum. + + Indices are ordered from left to right according to when each was sampled + (first samples are placed in first column). + + If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. + + If :attr:`input` is a matrix with `m` rows, :attr:`out` is an matrix of shape + :math:`(m \times \text{num\_samples})`. + + If replacement is ``True``, samples are drawn with replacement. + + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + + .. note:: + When drawn without replacement, :attr:`num_samples` must be lower than + number of non-zero elements in :attr:`input` (or the min number of non-zero + elements in each row of :attr:`input` if it is a matrix). + + Args: + input (Tensor): the input tensor containing probabilities + num_samples (int): number of samples to draw + replacement (bool, optional): whether to draw with replacement or not + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights + >>> torch.multinomial(weights, 2) + tensor([1, 2]) + >>> torch.multinomial(weights, 5) # ERROR! + RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement + >>> torch.multinomial(weights, 4, replacement=True) + tensor([ 2, 1, 1, 1]) + """ + +@overload +def multiply( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + +@overload +def multiply(input: Tensor, other: Number | _complex) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + +def mv(input: Tensor, vec: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + mv(input, vec, *, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`input` and the vector + :attr:`vec`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size :math:`m`, :attr:`out` will be 1-D of size :math:`n`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): matrix to be multiplied + vec (Tensor): vector to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.mv(mat, vec) + tensor([ 1.0404, -0.6361]) + """ + +def mvlgamma( + input: Tensor, + p: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + mvlgamma(input, p, *, out=None) -> Tensor + + Alias for :func:`torch.special.multigammaln`. + """ + +def nan_to_num( + input: Tensor, + nan: _float | None = None, + posinf: _float | None = None, + neginf: _float | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + + Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` + with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. + By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the + greatest finite value representable by :attr:`input`'s dtype, and negative infinity + is replaced with the least finite value representable by :attr:`input`'s dtype. + + Args: + input (Tensor): the input tensor. + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + """ + +def nan_to_num_( + input: Tensor, + nan: _float | None = None, + posinf: _float | None = None, + neginf: _float | None = None, +) -> Tensor: ... 
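+# ``nan_to_num_`` above is the in-place counterpart of ``nan_to_num``: per the
+# usual trailing-underscore convention it rewrites its input tensor instead of
+# allocating a new one. A small sketch (illustrative values, float32 defaults):
+#
+#     import torch
+#     x = torch.tensor([float('nan'), float('inf'), -float('inf'), 1.0])
+#     y = torch.nan_to_num(x)   # returns a new tensor; x is left unchanged
+#     torch.nan_to_num_(x)      # modifies x in place and returns it
+#     # both now equal tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  1.0000e+00])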
+def nanmean( + input: Tensor, + dim: _int | _size | None = None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + + Computes the mean of all `non-NaN` elements along the specified dimensions. + Input must be floating point or complex. + + This function is identical to :func:`torch.mean` when there are no `NaN` values + in the :attr:`input` tensor. In the presence of `NaN`, :func:`torch.mean` will + propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the + `NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`). + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor, either of floating point or complex dtype + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.mean` computes the mean value, propagating `NaN`. + + Example:: + + >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) + >>> x.mean() + tensor(nan) + >>> x.nanmean() + tensor(1.8000) + >>> x.mean(dim=0) + tensor([ nan, 1.5000, 2.5000]) + >>> x.nanmean(dim=0) + tensor([1.0000, 1.5000, 2.5000]) + + # If all elements in the reduced dimensions are NaN then the result is NaN + >>> torch.tensor([torch.nan]).nanmean() + tensor(nan) + """ + +@overload +def nanmedian(input: Tensor) -> Tensor: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. 
If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + +@overload +def nanmedian( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + +@overload +def nanmedian( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + +@overload +def nanquantile( + input: Tensor, + q: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. 
+ out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + +@overload +def nanquantile( + input: Tensor, + q: _float, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + +def nansum( + input: Tensor, + dim: _int | _size | None = None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + nansum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.tensor([1., 2., float('nan'), 4.]) + >>> torch.nansum(a) + tensor(7.) + + .. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. + If :attr:`dim` is a list of dimensions, reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> torch.nansum(torch.tensor([1., float("nan")])) + tensor(1.) + >>> a = torch.tensor([[1, 2], [3., float("nan")]]) + >>> torch.nansum(a) + tensor(6.) + >>> torch.nansum(a, dim=0) + tensor([4., 2.]) + >>> torch.nansum(a, dim=1) + tensor([3., 3.]) + """ + +@overload +def narrow( + input: Tensor, + dim: _int, + start: Tensor, + length: _int | SymInt, +) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + +@overload +def narrow( + input: Tensor, + dim: _int, + start: _int | SymInt, + length: _int | SymInt, +) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + +def narrow_copy( + input: Tensor, + dim: _int, + start: _int | SymInt, + length: _int | SymInt, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + narrow_copy(input, dim, start, length, *, out=None) -> Tensor + + Same as :meth:`Tensor.narrow` except this returns a copy rather + than shared storage. This is primarily for sparse tensors, which + do not have a shared-storage narrow method. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int): index of the element to start the narrowed dimension from. 
Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow_copy(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow_copy(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) + >>> torch.narrow_copy(s, 0, 0, 1) + tensor(indices=tensor([[0, 0], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + + .. seealso:: + + :func:`torch.narrow` for a non copy variant + """ + +def native_batch_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> tuple[Tensor, Tensor, Tensor]: ... +def native_channel_shuffle(input: Tensor, groups: _int | SymInt) -> Tensor: ... +def native_dropout( + input: Tensor, + p: _float, + train: _bool | None, +) -> tuple[Tensor, Tensor]: ... +def native_group_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + N: _int | SymInt, + C: _int | SymInt, + HxW: _int | SymInt, + group: _int, + eps: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +def native_layer_norm( + input: Tensor, + normalized_shape: Sequence[_int | SymInt], + weight: Tensor | None, + bias: Tensor | None, + eps: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +@overload +def native_norm( + input: Tensor, + p: Number | _complex | None, + dim: _int | _size, + keepdim: _bool, + dtype: _dtype | None, +) -> Tensor: ... +@overload +def native_norm(input: Tensor, p: Number | _complex = 2) -> Tensor: ... +@overload +def ne( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + +@overload +def ne( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + +def neg(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + neg(input, *, out=None) -> Tensor + + Returns a new tensor with the negative of the elements of :attr:`input`. + + .. math:: + \text{out} = -1 \times \text{input} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.neg(a) + tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) + """ + +def neg_(input: Tensor) -> Tensor: ... +def negative(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + negative(input, *, out=None) -> Tensor + + Alias for :func:`torch.neg` + """ + +def negative_(input: Tensor) -> Tensor: ... +def nextafter( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + nextafter(input, other, *, out=None) -> Tensor + + Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> eps = torch.finfo(torch.float32).eps + >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) + tensor([True, True]) + """ + +@overload +def nonzero( + input: Tensor, + *, + as_tuple: Literal[False] = False, + out: Tensor | None = None, +) -> Tensor: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. 
+ + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. + + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + +@overload +def nonzero( + input: Tensor, + *, + as_tuple: Literal[True], +) -> tuple[Tensor, ...]: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. 
+ + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + +def nonzero_static( + input: Tensor, + *, + size: _int | SymInt, + fill_value: _int = -1, + out: Tensor | None = None, +) -> Tensor: ... +def norm_except_dim(v: Tensor, pow: _int = 2, dim: _int = 0) -> Tensor: ... +@overload +def normal( + mean: Tensor, + std: Tensor, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. 
+ + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def normal( + mean: Tensor, + std: _float = 1, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def normal( + mean: _float, + std: Tensor, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def normal( + mean: _float, + std: _float, + size: Sequence[_int | SymInt], + *, + generator: Generator | None = None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def not_equal( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. 
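A minimal usage sketch of the alias documented above; `torch.not_equal` behaves like `torch.ne` and returns a boolean tensor.

    >>> import torch
    >>> torch.not_equal(torch.tensor([1, 2, 3]), torch.tensor([1, 0, 3]))
    tensor([False,  True, False])
    >>> torch.not_equal(torch.tensor([1, 2, 3]), 2)   # scalar `other` is broadcast
    tensor([ True, False,  True])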
+ """ + +@overload +def not_equal( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. + """ + +@overload +def nuclear_norm( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def nuclear_norm( + input: Tensor, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... +def numel(self: Tensor) -> _int: + r""" + numel(input: Tensor) -> int + + Returns the total number of elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 2, 3, 4, 5) + >>> torch.numel(a) + 120 + >>> a = torch.zeros(4,4) + >>> torch.numel(a) + 16 + """ + +@overload +def ones( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +@overload +def ones( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +@overload +def ones( + size: _size, + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +@overload +def ones( + *size: _int, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
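Assuming the named-tensor overload above works as declared, a short sketch of `torch.ones` with the `names` keyword (named tensors are a prototype feature and may emit a warning); the dimension names are illustrative.

    >>> import torch
    >>> t = torch.ones(2, 3, names=('N', 'C'))
    >>> t.names
    ('N', 'C')
    >>> t.shape
    torch.Size([2, 3])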
+ + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +def ones_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the same size as + :attr:`input`. ``torch.ones_like(input)`` is equivalent to + ``torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.ones_like(input, out=output)`` is equivalent to + ``torch.ones(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.ones_like(input) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + +def orgqr( + input: Tensor, + input2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + orgqr(input, tau) -> Tensor + + Alias for :func:`torch.linalg.householder_product`. + """ + +def ormqr( + input: Tensor, + input2: Tensor, + input3: Tensor, + left: _bool = True, + transpose: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor + + Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + + Multiplies a :math:`m \times n` matrix `C` (given by :attr:`other`) with a matrix `Q`, + where `Q` is represented using Householder reflectors `(input, tau)`. + See `Representation of Orthogonal or Unitary Matrices`_ for further details. + + If :attr:`left` is `True` then `op(Q)` times `C` is computed, otherwise the result is `C` times `op(Q)`. + When :attr:`left` is `True`, the implicit matrix `Q` has size :math:`m \times m`. + It has size :math:`n \times n` otherwise. + If :attr:`transpose` is `True` then `op` is the conjugate transpose operation, otherwise it's a no-op. + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + + .. seealso:: + :func:`torch.geqrf` can be used to form the Householder representation `(input, tau)` of matrix `Q` + from the QR decomposition. + + .. 
note:: + This function supports backward but it is only fast when ``(input, tau)`` do not require gradients + and/or ``tau.size(-1)`` is very small. + `` + + Args: + input (Tensor): tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions + and `mn` equals to `m` or `n` depending on the :attr:`left`. + tau (Tensor): tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions. + other (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + left (bool): controls the order of multiplication. + transpose (bool): controls whether the matrix `Q` is conjugate transposed or not. + + Keyword args: + out (Tensor, optional): the output Tensor. Ignored if `None`. Default: `None`. + + .. _Representation of Orthogonal or Unitary Matrices: + https://www.netlib.org/lapack/lug/node128.html + """ + +def outer( + input: Tensor, + vec2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + outer(input, vec2, *, out=None) -> Tensor + + Outer product of :attr:`input` and :attr:`vec2`. + If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of + size :math:`m`, then :attr:`out` must be a matrix of size :math:`(n \times m)`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): 1-D input vector + vec2 (Tensor): 1-D input vector + + Keyword args: + out (Tensor, optional): optional output matrix + + Example:: + + >>> v1 = torch.arange(1., 5.) + >>> v2 = torch.arange(1., 4.) + >>> torch.outer(v1, v2) + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.]]) + """ + +def pairwise_distance( + x1: Tensor, + x2: Tensor, + p: _float = 2, + eps: _float = 1e-06, + keepdim: _bool = False, +) -> Tensor: ... +def pdist(input: Tensor, p: _float = 2) -> Tensor: ... +def permute(input: Tensor, dims: _size) -> Tensor: + r""" + permute(input, dims) -> Tensor + + Returns a view of the original tensor :attr:`input` with its dimensions permuted. + + Args: + input (Tensor): the input tensor. + dims (tuple of int): The desired ordering of dimensions + + Example: + >>> x = torch.randn(2, 3, 5) + >>> x.size() + torch.Size([2, 3, 5]) + >>> torch.permute(x, (2, 0, 1)).size() + torch.Size([5, 2, 3]) + """ + +def permute_copy( + input: Tensor, + dims: _size, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.permute`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def pinverse(input: Tensor, rcond: _float = 1e-15) -> Tensor: + r""" + pinverse(input, rcond=1e-15) -> Tensor + + Alias for :func:`torch.linalg.pinv` + """ + +def pixel_shuffle(input: Tensor, upscale_factor: _int) -> Tensor: ... +def pixel_unshuffle(input: Tensor, downscale_factor: _int) -> Tensor: ... +def poisson(input: Tensor, generator: Generator | None = None) -> Tensor: + r""" + poisson(input, generator=None) -> Tensor + + Returns a tensor of the same size as :attr:`input` with each element + sampled from a Poisson distribution with rate parameter given by the corresponding + element in :attr:`input` i.e., + + .. math:: + \text{out}_i \sim \text{Poisson}(\text{input}_i) + + :attr:`input` must be non-negative. 
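A hedged sketch of how `ormqr` composes with `geqrf`: assuming `(a, tau)` from `torch.geqrf` and `torch.linalg.householder_product` describe the same `Q`, multiplying through `ormqr` should agree with forming `Q` explicitly. Variable names are illustrative.

    >>> import torch
    >>> A = torch.randn(4, 4, dtype=torch.float64)
    >>> C = torch.randn(4, 2, dtype=torch.float64)
    >>> a, tau = torch.geqrf(A)                          # Householder representation of Q
    >>> Q = torch.linalg.householder_product(a, tau)     # form Q explicitly
    >>> torch.allclose(torch.ormqr(a, tau, C), Q @ C)    # left=True: op(Q) @ C
    True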
+ + Args: + input (Tensor): the input tensor containing the rates of the Poisson distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + + Example:: + + >>> rates = torch.rand(4, 4) * 5 # rate parameter between 0 and 5 + >>> torch.poisson(rates) + tensor([[9., 1., 3., 5.], + [8., 6., 6., 0.], + [0., 4., 5., 3.], + [2., 1., 4., 2.]]) + """ + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: _bool, + full: _bool, + eps: _float, + reduction: _int, +) -> Tensor: ... +def polar( + abs: Tensor, + angle: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + polar(abs, angle, *, out=None) -> Tensor + + Constructs a complex tensor whose elements are Cartesian coordinates + corresponding to the polar coordinates with absolute value :attr:`abs` and angle + :attr:`angle`. + + .. math:: + \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j + + .. note:: + `torch.polar` is similar to + `std::polar `_ + and does not compute the polar decomposition + of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do. + The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is + infinite. + + + Args: + abs (Tensor): The absolute value the complex tensor. Must be float or double. + angle (Tensor): The angle of the complex tensor. Must be same dtype as + :attr:`abs`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> import numpy as np + >>> abs = torch.tensor([1, 2], dtype=torch.float64) + >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64) + >>> z = torch.polar(abs, angle) + >>> z + tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) + """ + +def polygamma( + n: _int, + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + polygamma(n, input, *, out=None) -> Tensor + + Alias for :func:`torch.special.polygamma`. + """ + +def positive(input: Tensor) -> Tensor: + r""" + positive(input) -> Tensor + + Returns :attr:`input`. + Throws a runtime error if :attr:`input` is a bool tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.randn(5) + >>> t + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.positive(t) + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + """ + +@overload +def pow( + input: Tensor, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. 
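A small deterministic check of the `polar` relationship above (sketch only; names are illustrative): build a complex tensor from magnitude and phase, then read them back.

    >>> import torch, math
    >>> mag = torch.tensor([1.0, 2.0], dtype=torch.float64)
    >>> phase = torch.tensor([0.0, math.pi / 2], dtype=torch.float64)
    >>> z = torch.polar(mag, phase)
    >>> torch.allclose(z.abs(), mag), torch.allclose(z.angle(), phase)
    (True, True)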
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + +@overload +def pow( + self: Number | _complex, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + +@overload +def pow( + input: Tensor, + exponent: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. 
math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + +def prelu(input: Tensor, weight: Tensor) -> Tensor: ... +@overload +def prod(input: Tensor, *, dtype: _dtype | None = None) -> Tensor: + r""" + prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + +@overload +def prod( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. 
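One case the `pow` examples above do not show is broadcasting between `input` and a tensor `exponent`; a brief sketch with illustrative values:

    >>> import torch
    >>> a = torch.tensor([[1., 2.], [3., 4.]])
    >>> torch.pow(a, torch.tensor([2., 3.]))    # exponent broadcasts across the rows
    tensor([[ 1.,  8.],
            [ 9., 64.]])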
+ + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + +@overload +def prod( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. 
+ + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + +def promote_types(type1: _dtype, type2: _dtype) -> _dtype: + r""" + promote_types(type1, type2) -> dtype + + Returns the :class:`torch.dtype` with the smallest size and scalar kind that is + not smaller nor of lower kind than either `type1` or `type2`. See type promotion + :ref:`documentation ` for more information on the type + promotion logic. + + Args: + type1 (:class:`torch.dtype`) + type2 (:class:`torch.dtype`) + + Example:: + + >>> torch.promote_types(torch.int32, torch.float32) + torch.float32 + >>> torch.promote_types(torch.uint8, torch.long) + torch.long + """ + +def put( + input: Tensor, + index: Tensor, + source: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def q_per_channel_axis(input: Tensor) -> _int: ... +def q_per_channel_scales(input: Tensor) -> Tensor: ... +def q_per_channel_zero_points(input: Tensor) -> Tensor: ... +def q_scale(input: Tensor) -> _float: ... +def q_zero_point(input: Tensor) -> _int: ... +def qr( + input: Tensor, + some: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.qr: + r""" + qr(input: Tensor, some: bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None]) -> (Tensor, Tensor) + + Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, + and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` + with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and + :math:`R` being an upper triangular matrix or batch of upper triangular matrices. + + If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. + Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. + + .. warning:: + + :func:`torch.qr` is deprecated in favor of :func:`torch.linalg.qr` + and will be removed in a future PyTorch release. The boolean parameter :attr:`some` has been + replaced with a string parameter :attr:`mode`. + + ``Q, R = torch.qr(A)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A) + + ``Q, R = torch.qr(A, some=False)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A, mode="complete") + + .. warning:: + If you plan to backpropagate through QR, note that the current backward implementation + is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` + columns of :attr:`input` are linearly independent. + This behavior will probably change once QR supports pivoting. + + .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + + Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + + Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. 
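A migration sketch for the deprecation note above (assuming `torch.linalg.qr` as the replacement, with `mode` in place of `some`):

    >>> import torch
    >>> A = torch.randn(5, 3, dtype=torch.float64)
    >>> Q, R = torch.linalg.qr(A)                        # reduced factorization, like some=True
    >>> Q.shape, R.shape
    (torch.Size([5, 3]), torch.Size([3, 3]))
    >>> Q, R = torch.linalg.qr(A, mode='complete')       # like some=False
    >>> torch.allclose(Q @ R, A)
    True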
+ + Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.qr(a, some=False) + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) + True + """ + +@overload +def quantile( + input: Tensor, + q: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) 
+ """ + +@overload +def quantile( + input: Tensor, + q: _float, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) + """ + +def quantize_per_channel( + input: Tensor, + scales: Tensor, + zero_points: Tensor, + axis: _int, + dtype: _dtype, +) -> Tensor: + r""" + quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor + + Converts a float tensor to a per-channel quantized tensor with given scales and zero points. + + Arguments: + input (Tensor): float tensor to quantize + scales (Tensor): float 1D tensor of scales to use, size should match ``input.size(axis)`` + zero_points (int): integer 1D tensor of offset to use, size should match ``input.size(axis)`` + axis (int): dimension on which apply per-channel quantization + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor + + Example:: + + >>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8) + tensor([[-1., 0.], + [ 1., 2.]], size=(2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_channel_affine, + scale=tensor([0.1000, 0.0100], dtype=torch.float64), + zero_point=tensor([10, 0]), axis=0) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() + tensor([[ 0, 10], + [100, 200]], dtype=torch.uint8) + """ + +@overload +def quantize_per_tensor( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + dtype: _dtype, +) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + +@overload +def quantize_per_tensor( + input: Tensor, + scale: _float, + zero_point: _int, + dtype: _dtype, +) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. 
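A short round-trip sketch tying `quantize_per_tensor` to `Tensor.int_repr` and `Tensor.dequantize` (values chosen so the quantization is exact):

    >>> import torch
    >>> x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
    >>> q = torch.quantize_per_tensor(x, 0.1, 10, torch.quint8)
    >>> q.int_repr()
    tensor([ 0, 10, 20, 30], dtype=torch.uint8)
    >>> q.dequantize()
    tensor([-1.,  0.,  1.,  2.])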
+ + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + +@overload +def quantize_per_tensor( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + scales: Tensor, + zero_points: Tensor, + dtype: _dtype, +) -> tuple[Tensor, ...]: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + +def quantize_per_tensor_dynamic( + input: Tensor, + dtype: _dtype, + reduce_range: _bool, +) -> Tensor: + r""" + quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor + + Converts a float tensor to a quantized tensor with scale and zero_point calculated + dynamically based on the input. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8`` + reduce_range (bool): a flag to indicate whether to reduce the range of quantized + data by 1 bit, it's required to avoid instruction overflow for some hardwares + + Returns: + Tensor: A newly (dynamically) quantized tensor + + Example:: + + >>> t = torch.quantize_per_tensor_dynamic(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.quint8, False) + >>> print(t) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.011764705882352941, + zero_point=85) + >>> t.int_repr() + tensor([ 0, 85, 170, 255], dtype=torch.uint8) + """ + +def quantized_batch_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + mean: Tensor, + var: Tensor, + eps: _float, + output_scale: _float, + output_zero_point: _int, +) -> Tensor: + r""" + quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor + + Applies batch normalization on a 4D (NCHW) quantized tensor. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Arguments: + input (Tensor): quantized tensor + weight (Tensor): float tensor that corresponds to the gamma, size C + bias (Tensor): float tensor that corresponds to the beta, size C + mean (Tensor): float mean value in batch normalization, size C + var (Tensor): float tensor for variance, size C + eps (float): a value added to the denominator for numerical stability. + output_scale (float): output quantized tensor scale + output_zero_point (int): output quantized tensor zero_point + + Returns: + Tensor: A quantized tensor with batch normalization applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_batch_norm(qx, torch.ones(2), torch.zeros(2), torch.rand(2), torch.rand(2), 0.00001, 0.2, 2) + tensor([[[[-0.2000, -0.2000], + [ 1.6000, -0.2000]], + + [[-0.4000, -0.4000], + [-0.4000, 0.6000]]], + + + [[[-0.2000, -0.2000], + [-0.2000, -0.2000]], + + [[ 0.6000, -0.4000], + [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2) + """ + +def quantized_gru_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> Tensor: ... +def quantized_lstm_cell( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> tuple[Tensor, Tensor]: ... +def quantized_max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: + r""" + quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + + Applies a 1D max pooling over an input quantized tensor composed of several input planes. 
+ + Arguments: + input (Tensor): quantized tensor + kernel_size (list of int): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool1d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool1d(qx, [2]) + tensor([[0.0000], + [1.5000]], size=(2, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + +def quantized_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: + r""" + quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + + Applies a 2D max pooling over an input quantized tensor composed of several input planes. + + Arguments: + input (Tensor): quantized tensor + kernel_size (``list of int``): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool2d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool2d(qx, [2,2]) + tensor([[[[1.5000]], + + [[1.5000]]], + + + [[[0.0000]], + + [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + +def quantized_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def quantized_rnn_relu_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> Tensor: ... +def quantized_rnn_tanh_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> Tensor: ... +def rad2deg(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + rad2deg(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in radians to degrees. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. 
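A deterministic sketch of `rad2deg` (and its inverse `deg2rad`), using `allclose` to sidestep floating-point rounding in the printed values:

    >>> import torch, math
    >>> x = torch.tensor([0.0, math.pi / 2, math.pi])
    >>> torch.allclose(torch.rad2deg(x), torch.tensor([0.0, 90.0, 180.0]))
    True
    >>> torch.allclose(torch.deg2rad(torch.rad2deg(x)), x)
    True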
+ + Example:: + + >>> a = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]) + >>> torch.rad2deg(a) + tensor([[ 180.0233, -180.0233], + [ 359.9894, -359.9894], + [ 89.9544, -89.9544]]) + """ + +def rad2deg_(input: Tensor) -> Tensor: ... +@overload +def rand( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
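A reproducibility sketch for the `generator` keyword documented above: reseeding the same `torch.Generator` makes `torch.rand` repeat its draws.

    >>> import torch
    >>> g = torch.Generator()
    >>> _ = g.manual_seed(0)
    >>> a = torch.rand(2, 3, generator=g)
    >>> _ = g.manual_seed(0)
    >>> torch.equal(a, torch.rand(2, 3, generator=g))
    True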
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
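+
+    A minimal illustrative sketch (an editorial addition, not part of the upstream
+    docstring): the size can be passed either as separate integers or as a single
+    sequence, and both spellings request the same shape::
+
+        >>> torch.rand(2, 3).shape
+        torch.Size([2, 3])
+        >>> torch.rand((2, 3)).shape
+        torch.Size([2, 3])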
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
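+
+    A minimal illustrative sketch (an editorial addition, not part of the upstream
+    docstring): passing an explicit :class:`torch.Generator` makes the draw
+    reproducible without touching the global RNG state::
+
+        >>> g1 = torch.Generator().manual_seed(0)
+        >>> g2 = torch.Generator().manual_seed(0)
+        >>> torch.equal(torch.rand(2, 2, generator=g1), torch.rand(2, 2, generator=g2))
+        True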
+ + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + size: Sequence[_int | SymInt], + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +def rand_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a uniform distribution on the interval :math:`[0, 1)`. + ``torch.rand_like(input)`` is equivalent to + ``torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. 
+ Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randint( + low: _int, + high: _int, + size: _size, + *, + generator: Generator | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + high: _int, + size: _size, + *, + generator: Generator | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. 
+ high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
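+
+    A minimal illustrative sketch (an editorial addition, not part of the upstream
+    docstring): the result defaults to ``torch.int64``, but a smaller integer dtype
+    can be requested explicitly::
+
+        >>> torch.randint(0, 100, (4,), dtype=torch.uint8).dtype
+        torch.uint8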
+ + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + low: _int | SymInt, + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + low: _int | SymInt, + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
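+
+    A minimal illustrative sketch (an editorial addition, not part of the upstream
+    docstring): ``low`` may be negative, so symmetric integer noise can be drawn in
+    a single call::
+
+        >>> torch.randint(-3, 4, (5,))  # each entry drawn uniformly from {-3, ..., 3}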
+ + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint_like( + input: Tensor, + low: _int | SymInt, + high: _int | SymInt, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randint_like( + input: Tensor, + high: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randint_like( + input: Tensor, + high: _int | SymInt, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
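+
+    A minimal illustrative sketch (an editorial addition, not part of the upstream
+    docstring): requesting a complex dtype samples from the complex normal
+    distribution described above::
+
+        >>> torch.randn(3, dtype=torch.cfloat).dtype
+        torch.complex64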
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. 
_complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. 
sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. 
Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +def randn_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a normal distribution with mean 0 and variance 1. 
Please refer to :func:`torch.randn` for the + sampling process of complex dtypes. ``torch.randn_like(input)`` is equivalent to + ``torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randperm( + n: _int | SymInt, + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + +@overload +def randperm( + n: _int | SymInt, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. 
+ Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + +def range( + start: Number, + end: Number, + step: Number = 1, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` + with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is + the gap between two values in the tensor. + + .. math:: + \text{out}_{i+1} = \text{out}_i + \text{step}. + + .. warning:: + This function is deprecated and will be removed in a future release because its behavior is inconsistent with + Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). + + Args: + start (float, optional): the starting value for the set of points. Default: ``0``. + end (float): the ending value for the set of points + step (float, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `step` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.range(1, 4) + tensor([ 1., 2., 3., 4.]) + >>> torch.range(1, 4, 0.5) + tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) + """ + +def ravel(input: Tensor) -> Tensor: + r""" + ravel(input) -> Tensor + + Return a contiguous flattened tensor. A copy is made only if needed. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + """ + +def real(input: Tensor) -> Tensor: + r""" + real(input) -> Tensor + + Returns a new tensor containing real values of the :attr:`self` tensor. 
+ The returned tensor and :attr:`self` share the same underlying storage. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + """ + +def reciprocal(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + reciprocal(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the elements of :attr:`input` + + .. math:: + \text{out}_{i} = \frac{1}{\text{input}_{i}} + + .. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.4595, -2.1219, -1.4314, 0.7298]) + >>> torch.reciprocal(a) + tensor([-2.1763, -0.4713, -0.6986, 1.3702]) + """ + +def reciprocal_(input: Tensor) -> Tensor: ... +def relu(input: Tensor) -> Tensor: ... +def relu_(input: Tensor) -> Tensor: ... +@overload +def remainder( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + +@overload +def remainder(self: Number | _complex, other: Tensor) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. 
+ + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + +@overload +def remainder( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + +def renorm( + input: Tensor, + p: Number | _complex, + dim: _int, + maxnorm: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + renorm(input, p, dim, maxnorm, *, out=None) -> Tensor + + Returns a tensor where each sub-tensor of :attr:`input` along dimension + :attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower + than the value :attr:`maxnorm` + + .. note:: If the norm of a row is lower than `maxnorm`, the row is unchanged + + Args: + input (Tensor): the input tensor. + p (float): the power for the norm computation + dim (int): the dimension to slice over to get the sub-tensors + maxnorm (float): the maximum norm to keep each sub-tensor under + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.ones(3, 3) + >>> x[1].fill_(2) + tensor([ 2., 2., 2.]) + >>> x[2].fill_(3) + tensor([ 3., 3., 3.]) + >>> x + tensor([[ 1., 1., 1.], + [ 2., 2., 2.], + [ 3., 3., 3.]]) + >>> torch.renorm(x, 1, 0, 5) + tensor([[ 1.0000, 1.0000, 1.0000], + [ 1.6667, 1.6667, 1.6667], + [ 1.6667, 1.6667, 1.6667]]) + """ + +@overload +def repeat_interleave( + input: Tensor, + repeats: Tensor, + dim: _int | None = None, + *, + output_size: _int | SymInt | None = None, +) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. 
+ By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + +@overload +def repeat_interleave( + repeats: Tensor, + *, + output_size: _int | SymInt | None = None, +) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. 
+ + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + +@overload +def repeat_interleave( + input: Tensor, + repeats: _int | SymInt, + dim: _int | None = None, + *, + output_size: _int | SymInt | None = None, +) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + +def reshape(input: Tensor, shape: Sequence[_int | SymInt]) -> Tensor: + r""" + reshape(input, shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`input`, + but with the specified shape. When possible, the returned tensor will be a view + of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs + with compatible strides can be reshaped without copying, but you should not + depend on the copying vs. viewing behavior. + + See :meth:`torch.Tensor.view` on when it is possible to return a view. + + A single dimension may be -1, in which case it's inferred from the remaining + dimensions and the number of elements in :attr:`input`. + + Args: + input (Tensor): the tensor to be reshaped + shape (tuple of int): the new shape + + Example:: + + >>> a = torch.arange(4.) + >>> torch.reshape(a, (2, 2)) + tensor([[ 0., 1.], + [ 2., 3.]]) + >>> b = torch.tensor([[0, 1], [2, 3]]) + >>> torch.reshape(b, (-1,)) + tensor([ 0, 1, 2, 3]) + """ + +def resize_as_( + input: Tensor, + the_template: Tensor, + *, + memory_format: memory_format | None = None, +) -> Tensor: ... +def resize_as_sparse_(input: Tensor, the_template: Tensor) -> Tensor: ... 
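# Editorial sketch (not part of the upstream stub file): the torch.reshape
# docstring above notes that a view is returned "when possible". For a
# contiguous input that is the case, so writes through the reshaped tensor are
# visible in the original; do not rely on view-vs-copy behavior in general.
import torch

_a = torch.arange(6.)
_b = torch.reshape(_a, (2, 3))      # contiguous input: _b aliases _a's storage
_b[0, 0] = 100.
assert _a[0].item() == 100.0        # the write is visible through _a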
+def resolve_conj(input: Tensor) -> Tensor: + r""" + resolve_conj(input) -> Tensor + + Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> y.is_conj() + True + >>> z = y.resolve_conj() + >>> z + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + >>> z.is_conj() + False + """ + +def resolve_neg(input: Tensor) -> Tensor: + r""" + resolve_neg(input) -> Tensor + + Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its negative bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> z = y.imag + >>> z.is_neg() + True + >>> out = z.resolve_neg() + >>> out + tensor([-1., -2., 3.]) + >>> out.is_neg() + False + """ + +@overload +def result_type(tensor: Tensor, other: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +@overload +def result_type(scalar: Number | _complex, tensor: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +@overload +def result_type(tensor: Tensor, other: Number | _complex) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +@overload +def result_type( + scalar1: Number | _complex, + scalar2: Number | _complex, +) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. 
+ + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +def rms_norm( + input: Tensor, + normalized_shape: Sequence[_int | SymInt], + weight: Tensor | None = None, + eps: _float | None = None, +) -> Tensor: ... +@overload +def rnn_relu( + data: Tensor, + batch_sizes: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def rnn_relu( + input: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def rnn_relu_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> Tensor: ... +@overload +def rnn_tanh( + data: Tensor, + batch_sizes: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def rnn_tanh( + input: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def rnn_tanh_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> Tensor: ... +def roll( + input: Tensor, + shifts: _int | SymInt | Sequence[_int | SymInt], + dims: _int | _size = (), +) -> Tensor: + r""" + roll(input, shifts, dims=None) -> Tensor + + Roll the tensor :attr:`input` along the given dimension(s). Elements that are + shifted beyond the last position are re-introduced at the first position. If + :attr:`dims` is `None`, the tensor will be flattened before rolling and then + restored to the original shape. + + Args: + input (Tensor): the input tensor. + shifts (int or tuple of ints): The number of places by which the elements + of the tensor are shifted. If shifts is a tuple, dims must be a tuple of + the same size, and each dimension will be rolled by the corresponding + value + dims (int or tuple of ints): Axis along which to roll + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + >>> x + tensor([[1, 2], + [3, 4], + [5, 6], + [7, 8]]) + >>> torch.roll(x, 1) + tensor([[8, 1], + [2, 3], + [4, 5], + [6, 7]]) + >>> torch.roll(x, 1, 0) + tensor([[7, 8], + [1, 2], + [3, 4], + [5, 6]]) + >>> torch.roll(x, -1, 0) + tensor([[3, 4], + [5, 6], + [7, 8], + [1, 2]]) + >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) + tensor([[6, 5], + [8, 7], + [2, 1], + [4, 3]]) + """ + +def rot90(input: Tensor, k: _int = 1, dims: _size = (0, 1)) -> Tensor: + r""" + rot90(input, k=1, dims=(0, 1)) -> Tensor + + Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. + Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. + + Args: + input (Tensor): the input tensor. + k (int): number of times to rotate. 
Default value is 1 + dims (a list or tuple): axis to rotate. Default value is [0, 1] + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.rot90(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.rot90(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) + """ + +@overload +def round(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + +@overload +def round( + input: Tensor, + *, + decimals: _int, + out: Tensor | None = None, +) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + +@overload +def round_(input: Tensor) -> Tensor: ... +@overload +def round_(input: Tensor, *, decimals: _int) -> Tensor: ... +def row_indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def row_stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + row_stack(tensors, *, out=None) -> Tensor + + Alias of :func:`torch.vstack`. + """ + +def rrelu( + input: Tensor, + lower: Number | _complex = 0.125, + upper: Number | _complex = 0.3333333333333333, + training: _bool = False, + generator: Generator | None = None, +) -> Tensor: ... +def rrelu_( + input: Tensor, + lower: Number | _complex = 0.125, + upper: Number | _complex = 0.3333333333333333, + training: _bool = False, + generator: Generator | None = None, +) -> Tensor: ... +def rsqrt(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + rsqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the square-root of each of + the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.0370, 0.2970, 1.5420, -0.9105]) + >>> torch.rsqrt(a) + tensor([ nan, 1.8351, 0.8053, nan]) + """ + +def rsqrt_(input: Tensor) -> Tensor: ... +@overload +def rsub( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def rsub( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: ... +def saddmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + out: Tensor | None = None, +) -> Tensor: ... +def scalar_tensor( + s: Number | _complex, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... 
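# Editorial sketch (not part of the upstream stub file): the torch.rsqrt
# docstring above defines out_i = 1 / sqrt(input_i); a quick self-check of that
# identity on positive inputs (negative inputs produce nan, as its example shows).
import torch

_x = torch.tensor([1.0, 4.0, 9.0])
assert torch.allclose(torch.rsqrt(_x), 1.0 / torch.sqrt(_x))   # [1.0000, 0.5000, 0.3333]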
+@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + *, + reduce: str, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + value: Number | _complex, + *, + reduce: str, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + src: Tensor, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + value: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter_add( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + +@overload +def scatter_add( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + src: Tensor, +) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + +def scatter_reduce( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + reduce: str, + *, + include_self: _bool = True, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` + """ + +@overload +def searchsorted( + sorted_sequence: Tensor, + input: Tensor, + *, + out_int32: _bool = False, + right: _bool = False, + side: str | None = None, + sorter: Tensor | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. 
list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. "left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. 
+ sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + +@overload +def searchsorted( + sorted_sequence: Tensor, + self: Number | _complex, + *, + out_int32: _bool = False, + right: _bool = False, + side: str | None = None, + sorter: Tensor | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. 
"left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. + sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + +def segment_reduce( + data: Tensor, + reduce: str, + *, + lengths: Tensor | None = None, + indices: Tensor | None = None, + offsets: Tensor | None = None, + axis: _int = 0, + unsafe: _bool = False, + initial: Number | _complex | None = None, +) -> Tensor: ... +@overload +def select(input: Tensor, dim: _int, index: _int | SymInt) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + +@overload +def select( + input: Tensor, + dim: str | EllipsisType | None, + index: _int, +) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + +def select_copy( + input: Tensor, + dim: _int, + index: _int | SymInt, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.select`, but all output tensors + are freshly created instead of aliasing the input. 
+ """ + +def select_scatter( + input: Tensor, + src: Tensor, + dim: _int, + index: _int | SymInt, +) -> Tensor: + r""" + select_scatter(input, src, dim, index) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into. + index (int): the index to select with + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.select(input, dim, index)`` + + Example:: + + >>> a = torch.zeros(2, 2) + >>> b = torch.ones(2) + >>> a.select_scatter(b, 0, 0) + tensor([[1., 1.], + [0., 0.]]) + """ + +def selu(input: Tensor) -> Tensor: ... +def selu_(input: Tensor) -> Tensor: ... +def set_flush_denormal(mode: _bool) -> _bool: + r""" + set_flush_denormal(mode) -> bool + + Disables denormal floating numbers on CPU. + + Returns ``True`` if your system supports flushing denormal numbers and it + successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal` + is supported on x86 architectures supporting SSE3 and AArch64 architecture. + + Args: + mode (bool): Controls whether to enable flush denormal mode or not + + Example:: + + >>> torch.set_flush_denormal(True) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor([ 0.], dtype=torch.float64) + >>> torch.set_flush_denormal(False) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor(9.88131e-324 * + [ 1.0000], dtype=torch.float64) + """ + +def set_num_interop_threads(num: _int) -> None: + r""" + set_num_interop_threads(int) + + Sets the number of threads used for interop parallelism + (e.g. in JIT interpreter) on CPU. + + .. warning:: + Can only be called once and before any inter-op parallel work + is started (e.g. JIT execution). + """ + +def set_num_threads(num: _int) -> None: + r""" + set_num_threads(int) + + Sets the number of threads used for intraop parallelism on CPU. + + .. warning:: + To ensure that the correct number of threads is used, set_num_threads + must be called before running eager, JIT or autograd code. + """ + +def sgn(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sgn(input, *, out=None) -> Tensor + + This function is an extension of torch.sign() to complex tensors. + It computes a new tensor whose elements have + the same angles as the corresponding elements of :attr:`input` and + absolute values (i.e. magnitudes) of one for complex tensors and + is equivalent to torch.sign() for non-complex tensors. + + .. math:: + \text{out}_{i} = \begin{cases} + 0 & |\text{{input}}_i| == 0 \\ + \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} + \end{cases} + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> t.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) + """ + +def sigmoid(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sigmoid(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expit`. + """ + +def sigmoid_(input: Tensor) -> Tensor: ... +def sign(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sign(input, *, out=None) -> Tensor + + Returns a new tensor with the signs of the elements of :attr:`input`. 
+ + .. math:: + \text{out}_{i} = \operatorname{sgn}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> a + tensor([ 0.7000, -1.2000, 0.0000, 2.3000]) + >>> torch.sign(a) + tensor([ 1., -1., 0., 1.]) + """ + +def signbit(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + signbit(input, *, out=None) -> Tensor + + Tests if each element of :attr:`input` has its sign bit set or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> torch.signbit(a) + tensor([ False, True, False, False]) + >>> a = torch.tensor([-0.0, 0.0]) + >>> torch.signbit(a) + tensor([ True, False]) + + .. note:: + signbit handles signed zeros, so negative zero (-0) returns True. + """ + +def sin(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sin(input, *, out=None) -> Tensor + + Returns a new tensor with the sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) + """ + +def sin_(input: Tensor) -> Tensor: ... +def sinc(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sinc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.sinc`. + """ + +def sinc_(input: Tensor) -> Tensor: ... +def sinh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sinh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic sine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) + >>> torch.sinh(a) + tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + +def sinh_(input: Tensor) -> Tensor: ... +def slice_copy( + input: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.slice`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def slice_inverse( + input: Tensor, + src: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, +) -> Tensor: ... +def slice_scatter( + input: Tensor, + src: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given + dimension. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. 
+ src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into + start (Optional[int]): the start index of where to insert the slice + end (Optional[int]): the end index of where to insert the slice + step (int): the how many elements to skip in + + Example:: + + >>> a = torch.zeros(8, 8) + >>> b = torch.ones(2, 8) + >>> a.slice_scatter(b, start=6) + tensor([[0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1.]]) + + >>> b = torch.ones(8, 2) + >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2) + tensor([[0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.]]) + """ + +def slogdet( + input: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.slogdet: + r""" + slogdet(input) -> (Tensor, Tensor) + + Alias for :func:`torch.linalg.slogdet` + """ + +def smm(input: Tensor, mat2: Tensor) -> Tensor: + r""" + smm(input, mat) -> Tensor + + Performs a matrix multiplication of the sparse matrix :attr:`input` + with the dense matrix :attr:`mat`. + + Args: + input (Tensor): a sparse matrix to be matrix multiplied + mat (Tensor): a dense matrix to be matrix multiplied + """ + +@overload +def softmax( + input: Tensor, + dim: _int, + dtype: _dtype | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + softmax(input, dim, *, dtype=None) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + +@overload +def softmax( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, +) -> Tensor: + r""" + softmax(input, dim, *, dtype=None) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + +@overload +def sort( + input: Tensor, + *, + stable: _bool | None, + dim: _int = -1, + descending: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +@overload +def sort( + input: Tensor, + dim: _int = -1, + descending: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +@overload +def sort( + input: Tensor, + *, + stable: _bool | None, + dim: str | EllipsisType | None, + descending: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +@overload +def sort( + input: Tensor, + dim: str | EllipsisType | None, + descending: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
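Editor's note: the sort examples further down do not exercise the ``descending`` flag or the ``out`` buffers described in this signature. A minimal illustrative sketch (not part of the original stub; tensor values are made up, standard ``torch.sort`` semantics assumed):

    >>> import torch
    >>> x = torch.tensor([[3., 1., 2.]])
    >>> torch.sort(x, dim=-1, descending=True).values   # largest value first
    tensor([[3., 2., 1.]])
    >>> buffers = (torch.empty_like(x), torch.empty(x.shape, dtype=torch.long))
    >>> _ = torch.sort(x, out=buffers)                   # reuse preallocated (values, indices) buffers
    >>> buffers[0]
    tensor([[1., 2., 3.]])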
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +def sparse_bsc_tensor( + ccol_indices: Tensor | list, + row_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse + Column)) ` with specified 2-dimensional blocks at the + given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in BSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial blocks for the tensor. Can be a list, + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. 
+ device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> ccol_indices = [0, 1, 2] + >>> row_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 1, 2]), + row_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsc) + """ + +def sparse_bsr_tensor( + crow_indices: Tensor | list, + col_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) + ` with specified 2-dimensional blocks at the given + :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix + multiplication operations in BSR format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + blocks in a given row. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. 
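Editor's sketch (not part of the original stub) of the bookkeeping described in the Args above, assuming plain BSR pointer semantics: successive differences of ``crow_indices`` give the number of blocks in each row, and the block size is read from ``values.shape[1:3]``.

    >>> import torch
    >>> crow_indices = torch.tensor([0, 2, 3])
    >>> crow_indices[1:] - crow_indices[:-1]   # blocks in row 0 and row 1
    tensor([2, 1])
    >>> values = torch.zeros(3, 2, 2)          # nnz = 3 blocks, each 2x2
    >>> tuple(values.shape[1:3])               # the implied blocksize
    (2, 2)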
+ + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> crow_indices = [0, 1, 2] + >>> col_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsr) + """ + +def sparse_compressed_tensor( + compressed_indices: Tensor | list, + plain_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, *, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, + CSC, BSR, or BSC - ` with specified values at + the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse + matrix multiplication operations in Compressed Sparse format are + typically faster than that for sparse tensors in COO format. Make you + have a look at :ref:`the note on the data type of the indices + `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. + plain_indices (array_like): Plain dimension (column or row) + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. + + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types. 
that + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + layout (:class:`torch.layout`, required): the desired layout of + returned tensor: :attr:`torch.sparse_csr`, + :attr:`torch.sparse_csc`, :attr:`torch.sparse_bsr`, or + :attr:`torch.sparse_bsc`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> compressed_indices = [0, 2, 4] + >>> plain_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_compressed_tensor(torch.tensor(compressed_indices, dtype=torch.int64), + ... torch.tensor(plain_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double, layout=torch.sparse_csr) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + +def sparse_coo_tensor( + indices: Tensor, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, + is_coalesced: _bool | None = None, +) -> Tensor: + r""" + sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor + + Constructs a :ref:`sparse tensor in COO(rdinate) format + ` with specified values at the given + :attr:`indices`. + + .. note:: + + This function returns an :ref:`uncoalesced tensor + ` when :attr:`is_coalesced` is + unspecified or ``None``. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + indices (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor` + internally. 
The indices are the coordinates of the non-zero values in the matrix, and thus + should be two-dimensional where the first dimension is the number of tensor dimensions and + the second dimension is the number of non-zero values. + values (array_like): Initial values for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not + provided the size will be inferred as the minimum size big enough to hold all non-zero + elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if None, infers data type from :attr:`values`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + is_coalesced (bool, optional): When ``True``, the caller is + responsible for providing tensor indices that correspond to a + coalesced tensor. If the :attr:`check_invariants` flag is + False, no error will be raised if the prerequisites are not + met and this will lead to silently incorrect results. To force + coalescence please use :meth:`coalesce` on the resulting + Tensor. + Default: None: except for trivial cases (e.g. nnz < 2) the + resulting Tensor has is_coalesced set to ``False``. + + Example:: + + >>> i = torch.tensor([[0, 1, 1], + ... [2, 0, 2]]) + >>> v = torch.tensor([3, 4, 5], dtype=torch.float32) + >>> torch.sparse_coo_tensor(i, v, [2, 4]) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 4), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v) # Shape inference + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v, [2, 4], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64, + layout=torch.sparse_coo) + + # Create an empty sparse tensor with the following invariants: + # 1. sparse_dim + dense_dim = len(SparseTensor.shape) + # 2. SparseTensor._indices().shape = (sparse_dim, nnz) + # 3.
SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) + # + # For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and + # sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0)) + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0,)), + size=(1,), nnz=0, layout=torch.sparse_coo) + + # and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and + # sparse_dim = 1 + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0, 2)), + size=(1, 2), nnz=0, layout=torch.sparse_coo) + + .. _torch.sparse: https://pytorch.org/docs/stable/sparse.html + """ + +def sparse_csc_tensor( + ccol_indices: Tensor | list, + row_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) + ` with specified values at the given + :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in CSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. + row_indices (array_like): Row co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> ccol_indices = [0, 2, 4] + >>> row_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) + """ + +def sparse_csr_tensor( + crow_indices: Tensor | list, + col_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified + values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations + in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look + at :ref:`the note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. + col_indices (array_like): Column co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> crow_indices = [0, 2, 4] + >>> col_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + +def split_copy( + input: Tensor, + split_size: _int | SymInt, + dim: _int = 0, + *, + out: tuple[Tensor, ...] | list[Tensor] | None = None, +) -> None: + r""" + Performs the same operation as :func:`torch.split`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def split_with_sizes( + input: Tensor, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, +) -> tuple[Tensor, ...]: ... +def split_with_sizes_copy( + input: Tensor, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, + *, + out: tuple[Tensor, ...] | list[Tensor] | None = None, +) -> None: + r""" + Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def spmm(input: Tensor, mat2: Tensor) -> Tensor: ... +def sqrt(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the square-root of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.sqrt(a) + tensor([ nan, 1.0112, 0.2883, 0.6933]) + """ + +def sqrt_(input: Tensor) -> Tensor: ... +def square(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + square(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the square of the elements of :attr:`input`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.square(a) + tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) + """ + +def square_(input: Tensor) -> Tensor: ... +@overload +def squeeze(input: Tensor) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. 
warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze(input: Tensor, dim: _int) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze(input: Tensor, dim: _size) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. 
Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze(input: Tensor, dim: str | EllipsisType | None) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def squeeze_copy( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def squeeze_copy( + input: Tensor, + dim: _size, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def sspaddmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. 
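Editor's sketch (not part of the original stub): none of the ``sspaddmm`` overload docstrings carry an example, so the following illustrates the call shape, assuming the sparse COO ``input``/``mat1`` and dense ``mat2`` that :func:`torch.sspaddmm` expects; values are illustrative.

    >>> import torch
    >>> i = torch.tensor([[0, 1], [1, 0]])
    >>> v = torch.tensor([1., 2.])
    >>> sp = torch.sparse_coo_tensor(i, v, (2, 2))   # sparse `input` and `mat1`
    >>> dense = torch.eye(2)                         # dense `mat2`
    >>> out = torch.sspaddmm(sp, sp, dense)          # sparse result of input + mat1 @ mat2
    >>> out.to_dense()                               # expected dense equivalent: [[0., 2.], [4., 0.]]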
+ + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + +@overload +def sspaddmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + +@overload +def sspaddmm( + beta: Number | _complex, + self: Tensor, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + +def stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + stack(tensors, dim=0, *, out=None) -> Tensor + + Concatenates a sequence of tensors along a new dimension. + + All tensors need to be of the same size. + + .. seealso:: + + :func:`torch.cat` concatenates the given sequence along an existing dimension. + + Arguments: + tensors (sequence of Tensors): sequence of tensors to concatenate + dim (int, optional): dimension to insert. Has to be between 0 and the number + of dimensions of concatenated tensors (inclusive). Default: 0 + + Keyword args: + out (Tensor, optional): the output tensor. 
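Editor's sketch (not part of the original stub) of the :func:`torch.cat` contrast mentioned in the seealso above: ``stack`` inserts a new dimension, while ``cat`` grows an existing one.

    >>> import torch
    >>> a = torch.zeros(2, 3)
    >>> torch.stack((a, a)).shape   # a new leading dimension is inserted
    torch.Size([2, 2, 3])
    >>> torch.cat((a, a)).shape     # existing dim 0 simply grows
    torch.Size([4, 3])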
+ + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]) + >>> torch.stack((x, x)) # same as torch.stack((x, x), dim=0) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]], + + [[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]]) + >>> torch.stack((x, x)).size() + torch.Size([2, 2, 3]) + >>> torch.stack((x, x), dim=1) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.3367, 0.1288, 0.2345]], + + [[ 0.2303, -1.1229, -0.1863], + [ 0.2303, -1.1229, -0.1863]]]) + >>> torch.stack((x, x), dim=2) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + >>> torch.stack((x, x), dim=-1) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + """ + +@overload +def std( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. 
+ + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... 
[ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). 
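Editor's sketch (not part of the original stub): the examples above always use the default ``correction=1``; the numerical effect of switching to ``correction=0`` (population standard deviation) looks like this, assuming the keyword form introduced in 2.0.

    >>> import torch
    >>> a = torch.tensor([1., 2., 3., 4.])
    >>> torch.std(a)                   # correction=1 (Bessel's correction), divides by N-1
    tensor(1.2910)
    >>> torch.std(a, correction=0)     # population standard deviation, divides by N
    tensor(1.1180)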
+ + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + unbiased: _bool = True, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. 
+ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def sub( + input: Tensor | Number | _complex, + other: Tensor | Number | _complex, + *, + alpha: Number | _complex | None = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + +@overload +def sub(self: Tensor, alpha: Number | _complex, other: Tensor) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. 
math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + +@overload +def sub( + self: Tensor, + alpha: Number | _complex, + other: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + +@overload +def subtract( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + +@overload +def subtract( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + +@overload +def sum(input: Tensor, *, dtype: _dtype | None = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + +@overload +def sum( + input: Tensor, + dim: _int | _size | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + +@overload +def sum( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + +def svd( + input: Tensor, + some: _bool = True, + compute_uv: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.svd: + r""" + svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Computes the singular value decomposition of either a matrix or batch of + matrices :attr:`input`. The singular value decomposition is represented as a + namedtuple `(U, S, V)`, such that :attr:`input` :math:`= U \text{diag}(S) V^{\text{H}}`. + where :math:`V^{\text{H}}` is the transpose of `V` for real inputs, + and the conjugate transpose of `V` for complex inputs. + If :attr:`input` is a batch of matrices, then `U`, `S`, and `V` are also + batched with the same batch dimensions as :attr:`input`. + + If :attr:`some` is `True` (default), the method returns the reduced singular + value decomposition. In this case, if the last two dimensions of :attr:`input` are + `m` and `n`, then the returned `U` and `V` matrices will contain only + `min(n, m)` orthonormal columns. + + If :attr:`compute_uv` is `False`, the returned `U` and `V` will be + zero-filled matrices of shape `(m, m)` and `(n, n)` + respectively, and the same device as :attr:`input`. The argument :attr:`some` + has no effect when :attr:`compute_uv` is `False`. + + Supports :attr:`input` of float, double, cfloat and cdouble data types. + The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will + always be real-valued, even if :attr:`input` is complex. + + .. 
warning:: + + :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd` + and will be removed in a future PyTorch release. + + ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with + + .. code:: python + + U, S, Vh = torch.linalg.svd(A, full_matrices=not some) + V = Vh.mH + + ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with + + .. code:: python + + S = torch.linalg.svdvals(A) + + .. note:: Differences with :func:`torch.linalg.svd`: + + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that + default value for both is `True`, so the default behavior is + effectively the opposite. + * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns + `Vh`, that is, :math:`V^{\text{H}}`. + * If :attr:`compute_uv` is `False`, :func:`torch.svd` returns zero-filled + tensors for `U` and `Vh`, whereas :func:`torch.linalg.svd` returns + empty tensors. + + .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch are returned in descending order. + + .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is `True`. + + .. note:: When :attr:`some` is `False`, the gradients on `U[..., :, min(m, n):]` + and `V[..., :, min(m, n):]` will be ignored in the backward pass, as those vectors + can be arbitrary bases of the corresponding subspaces. + + .. note:: The implementation of :func:`torch.linalg.svd` on CPU uses LAPACK's routine `?gesdd` + (a divide-and-conquer algorithm) instead of `?gesvd` for speed. Analogously, + on GPU, it uses cuSOLVER's routines `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 + and later, and MAGMA's routine `gesdd` on earlier versions of CUDA. + + .. note:: The returned `U` will not be contiguous. The matrix (or batch of matrices) will + be represented as a column-major matrix (i.e. Fortran-contiguous). + + .. warning:: The gradients with respect to `U` and `V` will only be finite when the input does not + have zero nor repeated singular values. + + .. warning:: If the distance between any two singular values is close to zero, the gradients with respect to + `U` and `V` will be numerically unstable, as they depends on + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix + has small singular values, as these gradients also depend on `S^{-1}`. + + .. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique, + as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column. + The same happens when :attr:`input` has repeated singular values, where one may multiply + the columns of the spanning subspace in `U` and `V` by a rotation matrix + and `the resulting vectors will span the same subspace`_. + Different platforms, like NumPy, or inputs on different device types, + may produce different `U` and `V` tensors. + + Args: + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m, n)` matrices. + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently, the shape of returned `U` and `V`. Default: `True`. + compute_uv (bool, optional): controls whether to compute `U` and `V`. Default: `True`. 
+ + Keyword args: + out (tuple, optional): the output tuple of tensors + + Example:: + + >>> a = torch.randn(5, 3) + >>> a + tensor([[ 0.2364, -0.7752, 0.6372], + [ 1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [ 0.3550, -0.4022, 1.5569], + [ 0.2445, -0.0158, 1.1414]]) + >>> u, s, v = torch.svd(a) + >>> u + tensor([[ 0.4027, 0.0287, 0.5434], + [-0.1946, 0.8833, 0.3679], + [ 0.4296, -0.2890, 0.5261], + [ 0.6604, 0.2717, -0.2618], + [ 0.4234, 0.2481, -0.4733]]) + >>> s + tensor([2.3289, 2.0315, 0.7806]) + >>> v + tensor([[-0.0199, 0.8766, 0.4809], + [-0.5080, 0.4054, -0.7600], + [ 0.8611, 0.2594, -0.4373]]) + >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) + tensor(8.6531e-07) + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, v = torch.svd(a_big) + >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) + tensor(2.6503e-06) + + .. _the resulting vectors will span the same subspace: + (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) + """ + +def swapaxes(input: Tensor, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes(input, axis0, axis1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. + + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + +def swapdims(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims(input, dim0, dim1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. + + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + +def sym_constrain_range( + size: Number | _complex, + *, + min: _int | None = None, + max: _int | None = None, +) -> None: ... +def sym_constrain_range_for_size( + size: Number | _complex, + *, + min: _int | None = None, + max: _int | None = None, +) -> None: ... +def t(input: Tensor) -> Tensor: + r""" + t(input) -> Tensor + + Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 + and 1. + + 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this + is equivalent to ``transpose(input, 0, 1)``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.randn(()) + >>> x + tensor(0.1995) + >>> torch.t(x) + tensor(0.1995) + >>> x = torch.randn(3) + >>> x + tensor([ 2.4320, -0.4608, 0.7702]) + >>> torch.t(x) + tensor([ 2.4320, -0.4608, 0.7702]) + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.4875, 0.9158, -0.5872], + [ 0.3938, -0.6929, 0.6932]]) + >>> torch.t(x) + tensor([[ 0.4875, 0.3938], + [ 0.9158, -0.6929], + [-0.5872, 0.6932]]) + + See also :func:`torch.transpose`. + """ + +def t_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.t`, but all output tensors + are freshly created instead of aliasing the input. 
+ """ + +def take( + input: Tensor, + index: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + take(input, index) -> Tensor + + Returns a new tensor with the elements of :attr:`input` at the given indices. + The input tensor is treated as if it were viewed as a 1-D tensor. The result + takes the same shape as the indices. + + Args: + input (Tensor): the input tensor. + index (LongTensor): the indices into tensor + + Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> torch.take(src, torch.tensor([0, 2, 5])) + tensor([ 4, 5, 8]) + """ + +def take_along_dim( + input: Tensor, + indices: Tensor, + dim: _int | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + take_along_dim(input, indices, dim=None, *, out=None) -> Tensor + + Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. + + If :attr:`dim` is None, the input array is treated as if it has been flattened to 1d. + + Functions that return indices along a dimension, like :func:`torch.argmax` and :func:`torch.argsort`, + are designed to work with this function. See the examples below. + + .. note:: + This function is similar to NumPy's `take_along_axis`. + See also :func:`torch.gather`. + + Args: + input (Tensor): the input tensor. + indices (LongTensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. Default: 0 + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) + >>> max_idx = torch.argmax(t) + >>> torch.take_along_dim(t, max_idx) + tensor([60]) + >>> sorted_idx = torch.argsort(t, dim=1) + >>> torch.take_along_dim(t, sorted_idx, dim=1) + tensor([[10, 20, 30], + [40, 50, 60]]) + """ + +def tan(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + tan(input, *, out=None) -> Tensor + + Returns a new tensor with the tangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.2027, -1.7687, 0.4412, -1.3856]) + >>> torch.tan(a) + tensor([-2.5930, 4.9859, 0.4722, -5.3366]) + """ + +def tan_(input: Tensor) -> Tensor: ... +def tanh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + tanh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic tangent of the elements + of :attr:`input`. + + .. math:: + \text{out}_{i} = \tanh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) + >>> torch.tanh(a) + tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) + """ + +def tanh_(input: Tensor) -> Tensor: ... +def tensor( + data: Any, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. + + .. warning:: + + When working with tensors prefer using :func:`torch.Tensor.clone`, + :func:`torch.Tensor.detach`, and :func:`torch.Tensor.requires_grad_` for + readability. 
Letting `t` be a tensor, ``torch.tensor(t)`` is equivalent to + ``t.detach().clone()``, and ``torch.tensor(t, requires_grad=True)`` + is equivalent to ``t.detach().clone().requires_grad_(True)``. + + .. seealso:: + + :func:`torch.as_tensor` preserves autograd history and avoids copies where possible. + :func:`torch.from_numpy` creates a tensor that shares storage with a NumPy array. + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + + Example:: + + >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + tensor([[ 0.1000, 1.2000], + [ 2.2000, 3.1000], + [ 4.9000, 5.2000]]) + + >>> torch.tensor([0, 1]) # Type inference on data + tensor([ 0, 1]) + + >>> torch.tensor([[0.11111, 0.222222, 0.3333333]], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) # creates a double tensor on a CUDA device + tensor([[ 0.1111, 0.2222, 0.3333]], dtype=torch.float64, device='cuda:0') + + >>> torch.tensor(3.14159) # Create a zero-dimensional (scalar) tensor + tensor(3.1416) + + >>> torch.tensor([]) # Create an empty tensor (of size (0,)) + tensor([]) + """ + +@overload +def tensor_split( + input: Tensor, + tensor_indices_or_sections: Tensor, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + +@overload +def tensor_split( + input: Tensor, + sections: _int | SymInt, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + +@overload +def tensor_split( + input: Tensor, + indices: Sequence[_int | SymInt], + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + +def threshold( + input: Tensor, + threshold: Number | _complex, + value: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: ... +def threshold_( + input: Tensor, + threshold: Number | _complex, + value: Number | _complex, +) -> Tensor: ... 
+def tile(input: Tensor, dims: Sequence[_int | SymInt]) -> Tensor: + r""" + tile(input, dims) -> Tensor + + Constructs a tensor by repeating the elements of :attr:`input`. + The :attr:`dims` argument specifies the number of repetitions + in each dimension. + + If :attr:`dims` specifies fewer dimensions than :attr:`input` has, then + ones are prepended to :attr:`dims` until all dimensions are specified. + For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`dims` + is (2, 2), then :attr:`dims` is treated as (1, 1, 2, 2). + + Analogously, if :attr:`input` has fewer dimensions than :attr:`dims` + specifies, then :attr:`input` is treated as if it were unsqueezed at + dimension zero until it has as many dimensions as :attr:`dims` specifies. + For example, if :attr:`input` has shape (4, 2) and :attr:`dims` + is (3, 3, 2, 2), then :attr:`input` is treated as if it had the + shape (1, 1, 4, 2). + + .. note:: + + This function is similar to NumPy's tile function. + + Args: + input (Tensor): the tensor whose elements to repeat. + dims (tuple): the number of repetitions per dimension. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + """ + +def topk( + input: Tensor, + k: _int | SymInt, + dim: _int = -1, + largest: _bool = True, + sorted: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.topk: + r""" + topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) + + Returns the :attr:`k` largest elements of the given :attr:`input` tensor along + a given dimension. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`largest` is ``False`` then the `k` smallest elements are returned. + + A namedtuple of `(values, indices)` is returned with the `values` and + `indices` of the largest `k` elements of each row of the `input` tensor in the + given dimension `dim`. + + The boolean option :attr:`sorted` if ``True``, will make sure that the returned + `k` elements are themselves sorted + + .. note:: + When using `torch.topk`, the indices of tied elements are not guaranteed to be stable + and may vary across different invocations. + + Args: + input (Tensor): the input tensor. + k (int): the k in "top-k" + dim (int, optional): the dimension to sort along + largest (bool, optional): controls whether to return largest or + smallest elements + sorted (bool, optional): controls whether to return the elements + in sorted order + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be + optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.topk(x, 3) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) + """ + +def trace(input: Tensor) -> Tensor: + r""" + trace(input) -> Tensor + + Returns the sum of the elements of the diagonal of the input 2-D matrix. + + Example:: + + >>> x = torch.arange(1., 10.).view(3, 3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]]) + >>> torch.trace(x) + tensor(15.) + """ + +@overload +def transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. 
+ The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + +@overload +def transpose( + input: Tensor, + dim0: str | EllipsisType | None, + dim1: str | EllipsisType | None, +) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. + The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + +def transpose_copy( + input: Tensor, + dim0: _int, + dim1: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.transpose`, but all output tensors + are freshly created instead of aliasing the input. 
+ """ + +@overload +def trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. Only one of :attr:`x` or :attr:`dx` should be specified. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. 
+ + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + +@overload +def trapezoid( + y: Tensor, + *, + dx: Number | _complex = 1, + dim: _int = -1, +) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. Only one of :attr:`x` or :attr:`dx` should be specified. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. 
The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + +@overload +def trapz(y: Tensor, *, dx: _float = 1, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + +@overload +def trapz(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + +def triangular_solve( + input: Tensor, + A: Tensor, + upper: _bool = True, + transpose: _bool = False, + unitriangular: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.triangular_solve: + r""" + triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) + + Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` + and multiple right-hand sides :math:`b`. + + In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular + (or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal. + + `torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are + batches of 2D matrices. If the inputs are batches, then returns + batched outputs `X` + + If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and + :attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, + the result may contain `NaN` s. 
+ + Supports input of float, double, cfloat and cdouble data types. + + .. warning:: + + :func:`torch.triangular_solve` is deprecated in favor of :func:`torch.linalg.solve_triangular` + and will be removed in a future PyTorch release. + :func:`torch.linalg.solve_triangular` has its arguments reversed and does not return a + copy of one of the inputs. + + ``X = torch.triangular_solve(B, A).solution`` should be replaced with + + .. code:: python + + X = torch.linalg.solve_triangular(A, B) + + Args: + b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where + :math:`*` is zero of more batch dimensions + A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` + where :math:`*` is zero or more batch dimensions + upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``. + transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``, + and `op(A) = A` if it is ``False``. Default: ``False``. + unitriangular (bool, optional): whether :math:`A` is unit triangular. + If True, the diagonal elements of :math:`A` are assumed to be + 1 and not referenced from :math:`A`. Default: ``False``. + + Keyword args: + out ((Tensor, Tensor), optional): tuple of two tensors to write + the output to. Ignored if `None`. Default: `None`. + + Returns: + A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient` + is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b` + (or whatever variant of the system of equations, depending on the keyword arguments.) + + Examples:: + + >>> A = torch.randn(2, 2).triu() + >>> A + tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]]) + >>> b = torch.randn(2, 3) + >>> b + tensor([[-0.0210, 2.3513, -1.5492], + [ 1.5429, 0.7403, -1.0243]]) + >>> torch.triangular_solve(b, A) + torch.return_types.triangular_solve( + solution=tensor([[ 1.7841, 2.9046, -2.5405], + [ 1.9320, 0.9270, -1.2826]]), + cloned_coefficient=tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]])) + """ + +def tril( + input: Tensor, + diagonal: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + tril(input, diagonal=0, *, out=None) -> Tensor + + Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0813, -0.8619, 0.7105], + [ 0.0935, 0.1380, 2.2112], + [-0.3409, -0.9828, 0.0289]]) + >>> torch.tril(a) + tensor([[-1.0813, 0.0000, 0.0000], + [ 0.0935, 0.1380, 0.0000], + [-0.3409, -0.9828, 0.0289]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], + [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) + >>> torch.tril(b, diagonal=1) + tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) + >>> torch.tril(b, diagonal=-1) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) + """ + +def tril_indices( + row: _int, + col: _int, + offset: _int = 0, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the lower triangular part of a :attr:`row`-by- + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row contains column coordinates. + Indices are ordered based on rows and then columns. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. 
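+
+    The returned index pair is typically used with advanced indexing to gather the
+    lower-triangular entries of a matrix (illustrative sketch)::
+
+        >>> m = torch.arange(9).reshape(3, 3)
+        >>> idx = torch.tril_indices(3, 3)
+        >>> m[idx[0], idx[1]]
+        tensor([0, 3, 4, 6, 7, 8])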
+ + Example:: + + >>> a = torch.tril_indices(3, 3) + >>> a + tensor([[0, 1, 1, 2, 2, 2], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, -1) + >>> a + tensor([[1, 2, 2, 3, 3, 3], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) + """ + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: _float = 1.0, + p: _float = 2, + eps: _float = 1e-06, + swap: _bool = False, + reduction: _int = 1, +) -> Tensor: ... +def triu( + input: Tensor, + diagonal: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + triu(input, diagonal=0, *, out=None) -> Tensor + + Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.3480, -0.5211, -0.4573]]) + >>> torch.triu(a) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.0000, -1.0680, 0.6602], + [ 0.0000, 0.0000, -0.4573]]) + >>> torch.triu(a, diagonal=1) + tensor([[ 0.0000, 0.5207, 2.0049], + [ 0.0000, 0.0000, 0.6602], + [ 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(a, diagonal=-1) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.0000, -0.5211, -0.4573]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) + """ + +def triu_indices( + row: _int, + col: _int, + offset: _int = 0, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the upper triangular part of a :attr:`row` by + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row 
contains column coordinates. + Indices are ordered based on rows and then columns. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + + Example:: + + >>> a = torch.triu_indices(3, 3) + >>> a + tensor([[0, 0, 0, 1, 1, 2], + [0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, -1) + >>> a + tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], + [0, 1, 2, 0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1], + [1, 2, 2]]) + """ + +def true_divide( + input: Tensor | Number, + other: Tensor | Number, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + true_divide(dividend, divisor, *, out) -> Tensor + + Alias for :func:`torch.div` with ``rounding_mode=None``. + """ + +def trunc(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + trunc(input, *, out=None) -> Tensor + + Returns a new tensor with the truncated integer values of + the elements of :attr:`input`. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) + >>> torch.trunc(a) + tensor([ 3., 0., -0., -0.]) + """ + +def trunc_(input: Tensor) -> Tensor: ... +@overload +def unbind(input: Tensor, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + +@overload +def unbind( + input: Tensor, + dim: str | EllipsisType | None, +) -> tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. 
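+
+    For instance (illustrative sketch), unbinding along ``dim=1`` yields the columns::
+
+        >>> torch.unbind(torch.tensor([[1, 2, 3],
+        ...                            [4, 5, 6]]), dim=1)
+        (tensor([1, 4]), tensor([2, 5]), tensor([3, 6]))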
+ + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + +def unbind_copy( + input: Tensor, + dim: _int = 0, + *, + out: tuple[Tensor, ...] | list[Tensor] | None = None, +) -> None: + r""" + Performs the same operation as :func:`torch.unbind`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def unflatten( + input: Tensor, + dim: str | EllipsisType | None, + sizes: Sequence[_int | SymInt], + names: Sequence[str | EllipsisType | None], +) -> Tensor: + r""" + unflatten(input, dim, sizes) -> Tensor + + Expands a dimension of the input tensor over multiple dimensions. + + .. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + + Args: + input (Tensor): the input tensor. + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + + Returns: + A View of input with the specified dimension unflattened. + + Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) + """ + +@overload +def unflatten( + input: Tensor, + dim: _int, + sizes: Sequence[_int | SymInt], +) -> Tensor: + r""" + unflatten(input, dim, sizes) -> Tensor + + Expands a dimension of the input tensor over multiple dimensions. + + .. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + + Args: + input (Tensor): the input tensor. + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + + Returns: + A View of input with the specified dimension unflattened. + + Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) + """ + +def unfold_copy( + input: Tensor, + dimension: _int, + size: _int, + step: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.unfold`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def unique_dim( + input: Tensor, + dim: _int, + sorted: _bool = True, + return_inverse: _bool = False, + return_counts: _bool = False, +) -> tuple[Tensor, Tensor, Tensor]: ... +def unsafe_chunk( + input: Tensor, + chunks: _int, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + unsafe_chunk(input, chunks, dim=0) -> List of Tensors + + Works like :func:`torch.chunk` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. 
warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. + """ + +def unsafe_split( + input: Tensor, + split_size: _int | SymInt, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors + + Works like :func:`torch.split` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. + """ + +def unsafe_split_with_sizes( + input: Tensor, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, +) -> tuple[Tensor, ...]: ... +def unsqueeze(input: Tensor, dim: _int) -> Tensor: + r""" + unsqueeze(input, dim) -> Tensor + + Returns a new tensor with a dimension of size one inserted at the + specified position. + + The returned tensor shares the same underlying data with this tensor. + + A :attr:`dim` value within the range ``[-input.dim() - 1, input.dim() + 1)`` + can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` + applied at :attr:`dim` = ``dim + input.dim() + 1``. + + Args: + input (Tensor): the input tensor. + dim (int): the index at which to insert the singleton dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4]) + >>> torch.unsqueeze(x, 0) + tensor([[ 1, 2, 3, 4]]) + >>> torch.unsqueeze(x, 1) + tensor([[ 1], + [ 2], + [ 3], + [ 4]]) + """ + +def unsqueeze_copy( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.unsqueeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def values_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.values`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def vander( + x: Tensor, + N: _int | None = None, + increasing: _bool = False, +) -> Tensor: + r""" + vander(x, N=None, increasing=False) -> Tensor + + Generates a Vandermonde matrix. + + The columns of the output matrix are elementwise powers of the input vector :math:`x^{(N-1)}, x^{(N-2)}, ..., x^0`. + If increasing is True, the order of the columns is reversed :math:`x^0, x^1, ..., x^{(N-1)}`. Such a + matrix with a geometric progression in each row is named for Alexandre-Theophile Vandermonde. + + Arguments: + x (Tensor): 1-D input tensor. + N (int, optional): Number of columns in the output. If N is not specified, + a square array is returned :math:`(N = len(x))`. + increasing (bool, optional): Order of the powers of the columns. If True, + the powers increase from left to right, if False (the default) they are reversed. + + Returns: + Tensor: Vandermonde matrix. If increasing is False, the first column is :math:`x^{(N-1)}`, + the second :math:`x^{(N-2)}` and so forth. If increasing is True, the columns + are :math:`x^0, x^1, ..., x^{(N-1)}`. 
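+
+    With ``increasing=True`` the result is the usual design matrix for polynomial
+    least-squares fitting; an illustrative sketch with hand-picked points::
+
+        >>> torch.vander(torch.tensor([1., 2., 3.]), N=3, increasing=True)
+        tensor([[1., 1., 1.],
+                [1., 2., 4.],
+                [1., 3., 9.]])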
+ + Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> torch.vander(x) + tensor([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + >>> torch.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 4, 2, 1], + [ 9, 3, 1], + [25, 5, 1]]) + >>> torch.vander(x, N=3, increasing=True) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) + """ + +@overload +def var( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. 
+ Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
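+
+    A small hand-checkable sketch of the effect of :attr:`correction` (illustrative
+    values; the mean is 2.5 and the summed squared deviations are 5.0)::
+
+        >>> a = torch.tensor([1., 2., 3., 4.])
+        >>> torch.var(a, correction=0)  # population variance: 5.0 / 4
+        tensor(1.2500)
+        >>> torch.var(a)                # Bessel-corrected (default): 5.0 / 3
+        tensor(1.6667)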
+ + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. 
+ :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + unbiased: _bool = True, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. 
Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. 
math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +def vdot( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + vdot(input, other, *, out=None) -> Tensor + + Computes the dot product of two 1D vectors along a dimension. + + In symbols, this function computes + + .. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + + where :math:`\overline{x_i}` denotes the conjugate for complex + vectors, and it is the identity for real vectors. + + .. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + .. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + + .. note:: out (Tensor, optional): the output tensor. + + + Example:: + + >>> torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + >>> a = torch.tensor((1 +2j, 3 - 1j)) + >>> b = torch.tensor((2 +1j, 4 - 0j)) + >>> torch.vdot(a, b) + tensor([16.+1.j]) + >>> torch.vdot(b, a) + tensor([16.-1.j]) + """ + +def view_as_complex(input: Tensor) -> Tensor: + r""" + view_as_complex(input) -> Tensor + + Returns a view of :attr:`input` as a complex tensor. For an input complex + tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a + new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last + dimension of the input tensor is expected to represent the real and imaginary + components of complex numbers. + + .. warning:: + :func:`view_as_complex` is only supported for tensors with + :class:`torch.dtype` ``torch.float64`` and ``torch.float32``. 
The input is + expected to have the last dimension of :attr:`size` 2. In addition, the + tensor must have a `stride` of 1 for its last dimension. The strides of all + other dimensions must be even numbers. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, 2) + >>> x + tensor([[ 1.6116, -0.5772], + [-1.4606, -0.9120], + [ 0.0786, -1.7497], + [-0.6561, -1.6623]]) + >>> torch.view_as_complex(x) + tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) + """ + +def view_as_complex_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_complex`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def view_as_real(input: Tensor) -> Tensor: + r""" + view_as_real(input) -> Tensor + + Returns a view of :attr:`input` as a real tensor. For an input complex tensor of + :attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new + real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2 + represents the real and imaginary components of complex numbers. + + .. warning:: + :func:`view_as_real` is only supported for tensors with ``complex dtypes``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)]) + >>> torch.view_as_real(x) + tensor([[ 0.4737, -0.3839], + [-0.2098, -0.6699], + [ 0.3470, -0.9451], + [-0.5174, -1.3136]]) + """ + +def view_as_real_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_real`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def view_copy( + input: Tensor, + dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def view_copy( + input: Tensor, + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def vsplit(input: Tensor, sections: _int) -> tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. 
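+
+    Each returned piece is a view, so in-place edits propagate back to :attr:`input`
+    (illustrative sketch)::
+
+        >>> t = torch.zeros(4, 2)
+        >>> top, bottom = torch.vsplit(t, 2)
+        >>> top += 1
+        >>> t
+        tensor([[1., 1.],
+                [1., 1.],
+                [0., 0.],
+                [0., 0.]])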
+ + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + +@overload +def vsplit(input: Tensor, indices: _size) -> tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + +def vstack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + vstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence vertically (row wise). + + This is equivalent to concatenation along the first axis after all 1-D tensors have been reshaped by :func:`torch.atleast_2d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.vstack((a,b)) + tensor([[1, 2, 3], + [4, 5, 6]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.vstack((a,b)) + tensor([[1], + [2], + [3], + [4], + [5], + [6]]) + """ + +@overload +def where(condition: Tensor) -> tuple[Tensor, ...]: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. 
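+
+    The single-argument form documented below returns index tensors rather than values
+    (illustrative sketch)::
+
+        >>> torch.where(torch.tensor([False, True, True]))
+        (tensor([1, 2]),)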
+ + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + self: Number | _complex, + other: Tensor, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. 
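+
+    For example (illustrative sketch), a 1-D condition broadcasts against 2-D inputs::
+
+        >>> torch.where(torch.tensor([True, False]), torch.zeros(2, 2), torch.ones(2, 2))
+        tensor([[0., 1.],
+                [0., 1.]])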
+ + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + input: Tensor, + other: Number | _complex, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + self: Number | _complex, + other: Number | _complex, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. 
+ + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def xlogy( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + +@overload +def xlogy( + self: Number | _complex, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + +@overload +def xlogy( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + +@overload +def xlogy_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def xlogy_(input: Tensor, other: Number | _complex) -> Tensor: ... +def zero_(input: Tensor) -> Tensor: ... +@overload +def zeros( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +@overload +def zeros( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +@overload +def zeros( + size: _size, + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +@overload +def zeros( + *size: _int, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +def zeros_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the same size as + :attr:`input`. ``torch.zeros_like(input)`` is equivalent to + ``torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.zeros_like(input, out=output)`` is equivalent to + ``torch.zeros(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. 
+ + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.zeros_like(input) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + """ diff --git a/phivenv/Lib/site-packages/torch/__config__.py b/phivenv/Lib/site-packages/torch/__config__.py new file mode 100644 index 0000000000000000000000000000000000000000..9027c86a57fb7f25ebf33918509d5ffbb60b68c8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/__config__.py @@ -0,0 +1,22 @@ +import torch + + +def show() -> str: + """ + Return a human-readable string with descriptions of the + configuration of PyTorch. + """ + return torch._C._show_config() + + +# TODO: In principle, we could provide more structured version/config +# information here. For now only CXX_FLAGS is exposed, as Timer +# uses them. +def _cxx_flags() -> str: + """Returns the CXX_FLAGS used when building PyTorch.""" + return torch._C._cxx_flags() + + +def parallel_info() -> str: + r"""Returns detailed string with parallelization settings""" + return torch._C._parallel_info() diff --git a/phivenv/Lib/site-packages/torch/__future__.py b/phivenv/Lib/site-packages/torch/__future__.py new file mode 100644 index 0000000000000000000000000000000000000000..950bdaaff336117b80390aa776d93e5ae72f7c24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/__future__.py @@ -0,0 +1,75 @@ +_overwrite_module_params_on_conversion: bool = False +_swap_module_params_on_conversion: bool = False + + +def set_overwrite_module_params_on_conversion(value: bool) -> None: + """ + Sets whether to assign new tensors to the parameters instead of changing the + existing parameters in-place when converting an ``nn.Module``. + + When enabled, the following methods will assign new parameters to the module: + + #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices + #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype + #. :meth:`nn.Module.to` + #. :meth:`nn.Module.to_empty` + + Args: + value (bool): Whether to assign new tensors or not. + + """ + global _overwrite_module_params_on_conversion + _overwrite_module_params_on_conversion = value + + +def get_overwrite_module_params_on_conversion() -> bool: + """ + Returns whether to assign new tensors to the parameters instead of changing the + existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``. + + See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information. + """ + return _overwrite_module_params_on_conversion + + +def set_swap_module_params_on_conversion(value: bool) -> None: + """ + Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to + change the existing parameters in-place when converting an ``nn.Module`` and instead + of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``. + + .. note:: + This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion` + + When enabled, the following methods will swap the existing parameters in-place: + + #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices + #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype + #. :meth:`nn.Module.to` + #. :meth:`nn.Module.to_empty` + #. :meth:`nn.Module.load_state_dict` + + The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows: + + #. 
For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via + :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``) + #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter` + #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors` + with ``res`` + + Args: + value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not. + + """ + global _swap_module_params_on_conversion + _swap_module_params_on_conversion = value + + +def get_swap_module_params_on_conversion() -> bool: + """ + Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to + change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``. + + See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information. + """ + return _swap_module_params_on_conversion diff --git a/phivenv/Lib/site-packages/torch/__init__.py b/phivenv/Lib/site-packages/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7571d79293ce67567f92bec26dcc9fa517783045 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/__init__.py @@ -0,0 +1,2866 @@ +""" +The torch package contains data structures for multi-dimensional +tensors and defines mathematical operations over these tensors. +Additionally, it provides many utilities for efficient serialization of +Tensors and arbitrary types, and other useful utilities. + +It has a CUDA counterpart, that enables you to run your tensor computations +on an NVIDIA GPU with compute capability >= 3.0. +""" + +# mypy: allow-untyped-defs + +import builtins +import ctypes +import functools +import glob +import importlib +import inspect +import math +import os +import platform +import sys +import textwrap +import threading +from typing import ( + Any as _Any, + Callable as _Callable, + get_origin as _get_origin, + Optional as _Optional, + overload as _overload, + TYPE_CHECKING, + TypeVar as _TypeVar, + Union as _Union, +) +from typing_extensions import ParamSpec as _ParamSpec + + +if TYPE_CHECKING: + from .types import Device, IntLikeType + + +# multipy/deploy is setting this import before importing torch, this is the most +# reliable way we have to detect if we're running within deploy. 
+# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 +def _running_with_deploy() -> builtins.bool: + return sys.modules.get("torch._meta_registrations", None) is object + + +from torch._utils import ( + _functionalize_sync as _sync, + _import_dotted_name, + classproperty, +) +from torch._utils_internal import ( + get_file_path, + prepare_multiprocessing_environment, + profiler_allow_cudagraph_cupti_lazy_reinit_cuda12, + USE_GLOBAL_DEPS, + USE_RTLD_GLOBAL_WITH_LIBTORCH, +) + + +# TODO(torch_deploy) figure out how to freeze version.py in fbcode build +if _running_with_deploy(): + __version__ = "torch-deploy-1.8" + # TODO: Remove this ugly hack when deploy typing extensions are updated to 4.10+ + if not TYPE_CHECKING: + import typing_extensions + + _TypeIs = typing_extensions.TypeGuard + typing_extensions.TypeIs = _TypeIs +else: + from typing_extensions import TypeIs as _TypeIs + + from torch.torch_version import __version__ as __version__ + +__all__ = [ + "BoolStorage", + "BoolTensor", + "ByteStorage", + "ByteTensor", + "CharStorage", + "CharTensor", + "DoubleStorage", + "DoubleTensor", + "FloatStorage", + "FloatTensor", + "GradScaler", + "IntStorage", + "IntTensor", + "LongStorage", + "LongTensor", + "ShortStorage", + "ShortTensor", + "SymBool", + "SymFloat", + "SymInt", + "Tensor", + "TypedStorage", + "UntypedStorage", + "are_deterministic_algorithms_enabled", + "autocast", + "chunk", + "compile", + "cond", + "enable_grad", + "export", + "get_default_device", + "get_deterministic_debug_mode", + "get_device_module", + "get_float32_matmul_precision", + "get_rng_state", + "inference_mode", + "initial_seed", + "is_deterministic_algorithms_warn_only_enabled", + "is_storage", + "is_tensor", + "is_warn_always_enabled", + "load", + "lobpcg", + "manual_seed", + "matmul", + "no_grad", + "rand", + "randn", + "save", + "seed", + "set_default_device", + "set_default_tensor_type", + "set_deterministic_debug_mode", + "set_float32_matmul_precision", + "set_printoptions", + "set_rng_state", + "set_warn_always", + "split", + "stack", + "sym_float", + "sym_fresh_size", + "sym_int", + "sym_ite", + "sym_max", + "sym_min", + "sym_not", + "sym_sum", + "typename", + "unravel_index", + "use_deterministic_algorithms", + "vmap", +] + +# Please keep this list sorted +assert __all__ == sorted(__all__) + +################################################################################ +# Load the extension module +################################################################################ + +# If PyTorch was built against the ROCm runtime wheels, then there will be +# a _rocm_init module and it will define an initialize() function which can +# prepare ROCm for use. See general documentation on ROCm runtime wheels: +# https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md +# Since this module is only ever added to the wheel if built for such a +# deployment, it is always safe to attempt. +try: + from . 
import _rocm_init # type: ignore[attr-defined] +except ImportError: + pass +else: + _rocm_init.initialize() + del _rocm_init + + +if sys.platform == "win32": + + def _load_dll_libraries() -> None: + import sysconfig + + from torch.version import cuda as cuda_version + + pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files") + py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin") + th_dll_path = os.path.join(os.path.dirname(__file__), "lib") + usebase_path = os.path.join( + sysconfig.get_config_var("userbase"), "Library", "bin" + ) + py_root_bin_path = os.path.join(sys.exec_prefix, "bin") + + # When users create a virtualenv that inherits the base environment, + # we will need to add the corresponding library directory into + # DLL search directories. Otherwise, it will rely on `PATH` which + # is dependent on user settings. + if sys.exec_prefix != sys.base_exec_prefix: + base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin") + else: + base_py_dll_path = "" + + dll_paths = [ + p + for p in ( + th_dll_path, + py_dll_path, + base_py_dll_path, + usebase_path, + py_root_bin_path, + ) + if os.path.exists(p) + ] + + if cuda_version and builtins.all( + not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths + ): + cuda_version_1 = cuda_version.replace(".", "_") + cuda_path_var = "CUDA_PATH_V" + cuda_version_1 + default_path = os.path.join( + pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}" + ) + cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin") + else: + cuda_path = "" + + dll_paths.extend(p for p in (cuda_path,) if os.path.exists(p)) + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + kernel32.LoadLibraryW.restype = ctypes.c_void_p + if with_load_library_flags: + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + for dll_path in dll_paths: + os.add_dll_directory(dll_path) + + try: + ctypes.CDLL("vcruntime140.dll") + ctypes.CDLL("msvcp140.dll") + if platform.machine() != "ARM64": + ctypes.CDLL("vcruntime140_1.dll") + except OSError: + print( + textwrap.dedent( + """ + Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. + It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe + """ + ).strip() + ) + + dlls = glob.glob(os.path.join(th_dll_path, "*.dll")) + path_patched = False + for dll in dlls: + is_loaded = False + if with_load_library_flags: + res = kernel32.LoadLibraryExW(dll, None, 0x00001100) + last_error = ctypes.get_last_error() + if res is None and last_error != 126: + err = ctypes.WinError(last_error) + err.strerror += ( + f' Error loading "{dll}" or one of its dependencies.' + ) + raise err + elif res is not None: + is_loaded = True + if not is_loaded: + if not path_patched: + os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]]) + path_patched = True + res = kernel32.LoadLibraryW(dll) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += ( + f' Error loading "{dll}" or one of its dependencies.' 
+ ) + raise err + + kernel32.SetErrorMode(prev_error_mode) + + _load_dll_libraries() + del _load_dll_libraries + + +def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]: + # Libraries can either be in path/nvidia/lib_folder/lib or path/lib_folder/lib + nvidia_lib_paths = glob.glob( + os.path.join(path, "nvidia", lib_folder, "lib", lib_name) + ) + lib_paths = glob.glob(os.path.join(path, lib_folder, "lib", lib_name)) + + return nvidia_lib_paths + lib_paths + + +def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: + """Preloads cuda deps if they could not be found otherwise.""" + # Should only be called on Linux if default path resolution have failed + assert platform.system() == "Linux", "Should only be called on Linux" + + lib_path = None + for path in sys.path: + candidate_lib_paths = _get_cuda_dep_paths(path, lib_folder, lib_name) + if candidate_lib_paths: + lib_path = candidate_lib_paths[0] + break + if not lib_path: + raise ValueError(f"{lib_name} not found in the system path {sys.path}") + ctypes.CDLL(lib_path) + + +# See Note [Global dependencies] +def _load_global_deps() -> None: + if _running_with_deploy() or platform.system() == "Windows": + return + + # Determine the file extension based on the platform + lib_ext = ".dylib" if platform.system() == "Darwin" else ".so" + lib_name = f"libtorch_global_deps{lib_ext}" + here = os.path.abspath(__file__) + global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name) + + try: + ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) + # Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink + # and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when + # shipped as wheel, which results in OS picking wrong/older version of nvjitlink library + # if `LD_LIBRARY_PATH` is defined, see https://github.com/pytorch/pytorch/issues/138460 + # Similar issue exist in cudnn that dynamically loads nvrtc, unaware of its relative path. 
+ # See https://github.com/pytorch/pytorch/issues/145580 + try: + with open("/proc/self/maps") as f: + _maps = f.read() + # libtorch_global_deps.so always depends in cudart, check if its installed via wheel + if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps: + return + # If all above-mentioned conditions are met, preload nvrtc and nvjitlink + # Please note that order are important for CUDA-11.8 , as nvjitlink does not exist there + _preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]") + _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]") + except Exception: + pass + + except OSError as err: + # Can only happen for wheel with cuda libs as PYPI deps + # As PyTorch is not purelib, but nvidia-*-cu12 is + from torch.version import cuda as cuda_version + + cuda_libs: dict[str, str] = { + "cublas": "libcublas.so.*[0-9]", + "cudnn": "libcudnn.so.*[0-9]", + "cuda_nvrtc": "libnvrtc.so.*[0-9]", + "cuda_runtime": "libcudart.so.*[0-9]", + "cuda_cupti": "libcupti.so.*[0-9]", + "cufft": "libcufft.so.*[0-9]", + "curand": "libcurand.so.*[0-9]", + "nvjitlink": "libnvJitLink.so.*[0-9]", + "cusparse": "libcusparse.so.*[0-9]", + "cusparselt": "libcusparseLt.so.*[0-9]", + "cusolver": "libcusolver.so.*[0-9]", + "nccl": "libnccl.so.*[0-9]", + } + # cufiile is only available on cuda 12+ + # TODO: Remove once CUDA 11.8 binaries are deprecated + if cuda_version is not None: + t_version = cuda_version.split(".") + t_major = int(t_version[0]) # type: ignore[operator] + if t_major >= 12: + cuda_libs["cufile"] = "libcufile.so.*[0-9]" + + is_cuda_lib_err = [ + lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] + ] + if not is_cuda_lib_err: + raise err + for lib_folder, lib_name in cuda_libs.items(): + _preload_cuda_deps(lib_folder, lib_name) + ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) + + +if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( + _running_with_deploy() or platform.system() != "Windows" +): + # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a + # few circumstances: + # + # 1. You're in a build environment (e.g., fbcode) where + # libtorch_global_deps is not available, but you still need + # to get mkl to link in with RTLD_GLOBAL or it will just + # not work. + # + # 2. You're trying to run PyTorch under UBSAN and you need + # to ensure that only one copy of libtorch is loaded, so + # vptr checks work properly + # + # If you're using this setting, you must verify that all the libraries + # you load consistently use the same libstdc++, or you may have + # mysterious segfaults. + # + old_flags = sys.getdlopenflags() + sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) + + from torch._C import * # noqa: F403 + + sys.setdlopenflags(old_flags) + del old_flags + +else: + # Easy way. You want this most of the time, because it will prevent + # C++ symbols from libtorch clobbering C++ symbols from other + # libraries, leading to mysterious segfaults. + # + # If building in an environment where libtorch_global_deps isn't available + # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will + # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False + # + # See Note [Global dependencies] + if USE_GLOBAL_DEPS: + _load_global_deps() + from torch._C import * # noqa: F403 + + +class SymInt: + """ + Like an int (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. 
+ """ + + def __init__(self, node): + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + def __bool__(self): + return builtins.bool(self != 0) + + def __int__(self): + return self.node.int_() + + def __index__(self): + return self.node.int_() + + # Magic methods installed by torch.fx.experimental.sym_node + + def __round__(self, ndigits=None): + return self + + def __truediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__float_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_truediv__(other) + + def __rtruediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rfloat_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_truediv__(other) + + def __floordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(math.floor(sym_float(self) / other)) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_floordiv__(other) + + def __rfloordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(math.floor(other / sym_float(self))) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_floordiv__(other) + + # nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + def __pow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__pow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + # Guards! This guard is necessary because we need to know it to + # determine the output type of this operation + if other >= 0: + return self.__pow_by_natural__(other) + else: + # Mercifully, when the exponent is negative, Python just promotes + # to doubles and does a float pow: + # + # if (Py_SIZE(b) < 0 && c == NULL) { + # /* if exponent is negative and there's no modulus: + # return a float. This works because we know + # that this calls float_pow() which converts its + # arguments to double. 
*/ + # Py_DECREF(a); + # Py_DECREF(b); + # return PyFloat_Type.tp_as_number->nb_power(v, w, x); + # } + return sym_float(self).__pow__(sym_float(other)) + + def __rpow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rpow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + if self >= 0: # self is exponent + return self.__rpow_by_natural__(other) + else: + return sym_float(self).__rpow__(sym_float(other)) + + def __eq__(self, other: object) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __lt__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __gt__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __le__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __ge__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __add__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __radd__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __rmul__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __mod__(self, other: "IntLikeType") -> "SymInt": + raise TypeError("type stub not overridden") + + def __mul__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __pow_by_natural__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __rpow_by_natural__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __int_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rint_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __int_floordiv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rint_floordiv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __sym_max__(self, other): + raise TypeError("type stub not overridden") + + def __sym_min__(self, other): + raise TypeError("type stub not overridden") + + def __sym_float__(self): + raise TypeError("type stub not overridden") + + def __neg__(self): + raise TypeError("type stub not overridden") + + def __sub__(self, other: "IntLikeType") -> "SymInt": + raise TypeError("type stub not overridden") + + def __rsub__(self, other: "IntLikeType") -> "SymInt": + raise TypeError("type stub not overridden") + + def __and__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __or__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __repr__(self): + return self.node._graph_repr() + + def _sympy_(self): + return self.node.expr + + def __hash__(self) -> builtins.int: + if self.node.is_nested_int(): + return hash(self.node.nested_int()) + else: + # We could support constant SymInts as well, but not doing it for now + raise TypeError("unhashable type: non-nested SymInt") + # TODO: Force specialization + # This can't be done because the TypeError here is load bearing + # for einops + # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237 + # return hash(builtins.int(self)) + + def as_integer_ratio(self) -> tuple["SymInt", builtins.int]: + """Represent this int as an exact integer ratio""" + return self, 1 + + def bit_length(self) -> builtins.int: + # TODO: A more relaxed guard is possible here, where you guard to + # 
allow all integer quantities which would result in the same bit + # length. We can also just make a dedicated Sympy function for + # computing this quantity and represent it symbolically. + return builtins.int(self).bit_length() + + def conjugate(self) -> "SymInt": + return self + + +class SymFloat: + """ + Like a float (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + """ + + def __init__(self, node): + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + def __truediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__float_truediv__(sym_float(other)) + + def __rtruediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__rfloat_truediv__(sym_float(other)) + + def __floordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return sym_float(math.floor(self / sym_float(other))) + + def __rfloordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return sym_float(math.floor(sym_float(other) / self)) + + def __bool__(self): + return self.node.bool_() + + def __float__(self): + return self.node.guard_float("", 0) + + # Symbolic power does NOT work with negative base, this is to avoid + # potential complex outputs + def __pow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(self >= 0) + return self.__float_pow__(other) + + def __rpow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(other >= 0) + return self.__rfloat_pow__(other) + + # Magic methods installed by torch.fx.experimental.sym_node + + def __eq__(self, other: object) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __lt__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __gt__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __le__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __ge__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __float_pow__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rfloat_pow__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __float_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rfloat_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __trunc__(self): + raise TypeError("type stub not overridden") + + def __sym_max__(self, other): + raise TypeError("type stub not overridden") + + def __sym_min__(self, other): + raise TypeError("type stub not overridden") + + def __sym_int__(self): + raise TypeError("type stub not overridden") + + def is_integer(self): + """Return True if the float is an integer.""" + raise TypeError("type stub not overridden") + + def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]: + """Represent this float as an exact integer ratio""" + return 
builtins.float(self).as_integer_ratio() + + def __repr__(self): + return self.node._graph_repr() + + def _sympy_(self): + return self.node.expr + + def __hash__(self): + return hash(builtins.float(self)) + + def conjugate(self) -> "SymFloat": + """Returns the complex conjugate of the float.""" + return self + + def hex(self) -> str: + """Returns the hexadecimal representation of the float.""" + return self.node.guard_float("", 0).hex() + + +class SymBool: + """ + Like a bool (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + + Unlike regular bools, regular boolean operators will force extra guards instead + of symbolically evaluate. Use the bitwise operators instead to handle this. + """ + + def __init__(self, node): + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + def __bool__(self): + return self.node.bool_() + + def __int__(self): + return builtins.int(self.node.bool_()) + + # Magic methods installed by torch.fx.experimental.sym_node + def __and__(self, other) -> "SymBool": + raise TypeError("type stub not overridden") + + def __or__(self, other) -> "SymBool": + raise TypeError("type stub not overridden") + + # We very carefully define __sym_not__, and not a number of other + # plausible alternatives: + # + # - We do not override __not__ because this is not a real magic + # method; you cannot override the meaning of the not builtin in + # Python. We use the name 'sym_not' to clarify that in user code you + # cannot use the builtin not or operator.not_ or operator.__not__ and + # hit this magic method; you must use our custom sym_not operator. + # + # - We do not override the __invert__ method because SymBool is + # meant to be usable in situations where bool is expected. However, + # bitwise negation ~a does the wrong thing with booleans (because + # bool is a subclass of int, so ~1 = -2 which is not falseish.) + # This would be a giant footgun, so we get around it by defining + # our own operator. Note that bitwise and/or do the right thing, + # so we reuse the conventional operators there for readability. + # + def __sym_not__(self) -> "SymBool": + raise TypeError("type stub not overridden") + + def __sym_ite__(self, then_val, else_val): + raise TypeError("type stub not overridden") + + def __eq__(self, other) -> builtins.bool: + raise TypeError("type stub not overridden") + + def __repr__(self): + return self.node._graph_repr() + + def _sympy_(self): + return self.node.expr + + def __hash__(self): + if self.node.is_constant(): + return hash(self.node.bool_()) + else: + # Force specialization + return hash(builtins.bool(self)) + + +def sym_not(a): + r"""SymInt-aware utility for logical negation. + + Args: + a (SymBool or bool): Object to negate + """ + import sympy + + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_not, (a,), a) + if hasattr(a, "__sym_not__"): + return a.__sym_not__() + if isinstance(a, sympy.Basic): + return ~a # type: ignore[operator] + return not a + + +def sym_float(a): + r"""SymInt-aware utility for float casting. 
+ + Args: + a (SymInt, SymFloat, or object): Object to cast + """ + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_float, (a,), a) + if isinstance(a, SymFloat): + return a + elif hasattr(a, "__sym_float__"): + return a.__sym_float__() + return builtins.float(a) # type: ignore[operator] + + +def sym_int(a): + r"""SymInt-aware utility for int casting. + + Args: + a (SymInt, SymFloat, or object): Object to cast + """ + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_int, (a,), a) + if isinstance(a, SymInt): + return a + elif isinstance(a, SymFloat): + return math.trunc(a) + return builtins.int(a) # type: ignore[operator] + + +def sym_max(a, b): + """ + SymInt-aware utility for max which avoids branching on a < b. + Unlike builtins.max(), this only works for int/float, and it always + promotes to float if any argument is float (unlike builtins.max, which + will faithfully preserve the type of the input argument). + """ + if overrides.has_torch_function((a, b)): + return overrides.handle_torch_function(sym_max, (a, b), a, b) + if isinstance(a, (SymInt, SymFloat)): + return a.__sym_max__(b) + elif isinstance(b, (SymInt, SymFloat)): + # Due to promotion semantics, this is operator is commutative: + # max(1, 1.0) === max(1.0, 1) === 1.0 + return b.__sym_max__(a) + # TODO: Probably can make bool work too, just lazy + + all_types, float_types = __all_and_float_types() + + assert isinstance(a, all_types), type(a) + assert isinstance(b, all_types), type(b) + if isinstance(a, float_types) or isinstance(b, float_types): + return builtins.float(builtins.max(a, b)) # type: ignore[call-overload] + else: + return builtins.max(a, b) # type: ignore[call-overload] + + +def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]: + try: + import numpy as np + + all_types: tuple[type, ...] = ( + np.integer, + np.floating, + builtins.int, + builtins.float, + ) + float_types: tuple[type, ...] = (np.floating, builtins.float) + except ModuleNotFoundError: + all_types = (builtins.int, builtins.float) + float_types = (builtins.float,) + + return all_types, float_types + + +def sym_min(a, b): + """SymInt-aware utility for min().""" + if overrides.has_torch_function((a, b)): + return overrides.handle_torch_function(sym_min, (a, b), a, b) + if isinstance(a, (SymInt, SymFloat)): + return a.__sym_min__(b) + elif isinstance(b, (SymInt, SymFloat)): + return b.__sym_min__(a) + + all_types, float_types = __all_and_float_types() + + assert isinstance(a, all_types), type(a) + assert isinstance(b, all_types), type(b) + if isinstance(a, float_types) or isinstance(b, float_types): + return builtins.float(builtins.min(a, b)) # type: ignore[call-overload] + else: + return builtins.min(a, b) # type: ignore[call-overload] + + +def sym_sum(args): + """ + N-ary add which is faster to compute for long lists than iterated binary + addition. Only does something special for integers. 
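As a quick illustration of the promotion rule described in the sym_max/sym_min/sym_sum docstrings above (a sketch, assuming the torch package vendored in this diff is importable): with plain Python numbers these helpers fall back to the builtins, but sym_max and sym_min still promote to float whenever either argument is a float.

import torch

print(torch.sym_max(2, 2.0))      # 2.0  (promoted to float, per the docstring)
print(max(2, 2.0))                # 2    (an int: builtins.max does not promote)
print(torch.sym_min(3, 1.5))      # 1.5
print(torch.sym_sum([1, 2, 3]))   # 6    (plain ints fall through to builtins.sum)
print(torch.sym_ite(True, 1, 0))  # 1    (a plain bool behaves like a ternary)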
+ """ + if overrides.has_torch_function(args): + return overrides.handle_torch_function(sym_sum, args, args) + + found = None + for a in args: + if not isinstance(a, (SymInt, builtins.int)): + return builtins.sum(args) + if isinstance(a, SymInt): + found = a.node + if found is None: + return builtins.sum(args) + + from torch.fx.experimental.sym_node import to_node, wrap_node + + return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args))) + + +# Drop in replacement for math.sqrt, math.sin, math.cos etc +def _get_sym_math_fn(name): + def fn(a): + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(fn, (a,), a) + if isinstance(a, SymInt): + a = torch.sym_float(a) + if hasattr(a, f"__sym_{name}__"): + return getattr(a, f"__sym_{name}__")() + return getattr(math, name)(a) + + return fn + + +__fn, __name, __sym_name = None, "", "" +for __name in ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", + "log2", +): + __sym_name = f"_sym_{__name}" + __fn = _get_sym_math_fn(__name) + __fn.__qualname__ = __fn.__name__ = __sym_name + globals()[__sym_name] = __fn + + +del __fn, __name, __sym_name, _get_sym_math_fn + +# Adding temporary shortcut +sym_sqrt = globals()["_sym_sqrt"] +__all__.append("sym_sqrt") + + +def sym_ite(b, t, f): + """SymInt-aware utility for ternary operator (``t if b else f``.)""" + if overrides.has_torch_function((b, t, f)): + return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f) + assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f) + if isinstance(b, SymBool): + return b.__sym_ite__(t, f) + return t if b else f + + +# Create a fresh unbacked int, from an (possibly unbacked int) expression. +def sym_fresh_size(expr): + return torch.tensor(expr).item() + + +# Check to see if we can load C extensions, and if not provide some guidance +# on what the problem might be. +try: + # _initExtension is chosen (arbitrarily) as a sentinel. + from torch._C import _initExtension +except ImportError: + import torch._C as _C_for_compiled_check + + # The __file__ check only works for Python 3.7 and above. + if _C_for_compiled_check.__file__ is None: + raise ImportError( + textwrap.dedent( + """ + Failed to load PyTorch C extensions: + It appears that PyTorch has loaded the `torch/_C` folder + of the PyTorch repository rather than the C extensions which + are expected in the `torch._C` namespace. This can occur when + using the `install` workflow. e.g. + $ python setup.py install && python -c "import torch" + + This error can generally be solved using the `develop` workflow + $ python setup.py develop && python -c "import torch" # This should succeed + or by running Python from a different directory. + """ + ).strip() + ) from None + raise # If __file__ is not None the cause is unknown, so just re-raise. + +# The torch._C submodule is already loaded via `from torch._C import *` above +# Make an explicit reference to the _C submodule to appease linters +from torch import _C as _C + + +__name, __obj = "", None +for __name in dir(_C): + if __name[0] != "_" and not __name.endswith("Base"): + __all__.append(__name) + __obj = getattr(_C, __name) + if callable(__obj) or inspect.isclass(__obj): + if __obj.__module__ != __name__: # "torch" + # TODO: fix their module from C++ side + if __name not in { + "DisableTorchFunctionSubclass", + "DisableTorchFunction", + "Generator", + }: + __obj.__module__ = __name__ # "torch" + elif __name == "TensorBase": + # issue 109438 / pr 109940. 
Prevent TensorBase from being copied into torch. + delattr(sys.modules[__name__], __name) + +del __name, __obj + +if not TYPE_CHECKING: + # issue 38137 and python issue 43367. Submodules of a C extension are + # non-standard, and attributes of those submodules cannot be pickled since + # pickle expect to be able to import them as "from _C.sub import attr" + # which fails with "_C is not a package + def _import_extension_to_sys_modules(module, memo=None): + if memo is None: + memo = set() + if module in memo: + return + memo.add(module) + module_name = module.__name__ + for name in dir(module): + member = getattr(module, name) + member_name = getattr(member, "__name__", "") + if inspect.ismodule(member) and member_name.startswith(module_name): + sys.modules.setdefault(member_name, member) + # Recurse for submodules (e.g., `_C._dynamo.eval_frame`) + _import_extension_to_sys_modules(member, memo) + + _import_extension_to_sys_modules(_C) + del _import_extension_to_sys_modules + +################################################################################ +# Define basic utilities +################################################################################ + + +def typename(obj: _Any, /) -> str: + """ + String representation of the type of an object. + + This function returns a fully qualified string representation of an object's type. + Args: + obj (object): The object whose type to represent + Returns: + str: the type of the object `o` + Example: + >>> x = torch.tensor([1, 2, 3]) + >>> torch.typename(x) + 'torch.LongTensor' + >>> torch.typename(torch.nn.Parameter) + 'torch.nn.parameter.Parameter' + """ + if isinstance(obj, torch.Tensor): + return obj.type() + + module = getattr(obj, "__module__", "") or "" + qualname = "" + + if hasattr(obj, "__qualname__"): + qualname = obj.__qualname__ + elif hasattr(obj, "__name__"): + qualname = obj.__name__ + else: + module = obj.__class__.__module__ or "" + qualname = obj.__class__.__qualname__ + + if module in {"", "builtins"}: + return qualname + return f"{module}.{qualname}" + + +def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]: + r"""Returns True if `obj` is a PyTorch tensor. + + Note that this function is simply doing ``isinstance(obj, Tensor)``. + Using that ``isinstance`` check is better for typechecking with mypy, + and more explicit - so it's recommended to use that instead of + ``is_tensor``. + + Args: + obj (object): Object to test + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.is_tensor(x) + True + + """ + return isinstance(obj, torch.Tensor) + + +def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]: + r"""Returns True if `obj` is a PyTorch storage object. + + Args: + obj (Object): Object to test + """ + return type(obj) in _storage_classes + + +_GLOBAL_DEVICE_CONTEXT = threading.local() + + +def get_default_device() -> "torch.device": + r"""Gets the default ``torch.Tensor`` to be allocated on ``device``""" + global _GLOBAL_DEVICE_CONTEXT + + from torch.overrides import _get_current_function_mode_stack + from torch.utils._device import DeviceContext + + def _get_device_with_index(device): + if device.index is not None: + return device + else: + # TODO: Call like get_device_index() method corresponding to + # each device type + return torch.tensor([]).device + + # Get device from any active DeviceContext. 
+ device_mode = next( + filter( + lambda mode: isinstance(mode, DeviceContext), + reversed(_get_current_function_mode_stack()), + ), + None, + ) + if device_mode: + device = device_mode.device + return _get_device_with_index(device) + + if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): + device = _GLOBAL_DEVICE_CONTEXT.device_context.device + return _get_device_with_index(device) + else: + return torch.device("cpu") + + +def set_default_device(device: "Device") -> None: + """Sets the default ``torch.Tensor`` to be allocated on ``device``. This + does not affect factory function calls which are called with an explicit + ``device`` argument. Factory calls will be performed as if they + were passed ``device`` as an argument. + + To only temporarily change the default device instead of setting it + globally, use ``with torch.device(device):`` instead. + + The default device is initially ``cpu``. If you set the default tensor + device to another device (e.g., ``cuda``) without a device index, tensors + will be allocated on whatever the current device for the device type, + even after :func:`torch.cuda.set_device` is called. + + .. warning:: + + This function imposes a slight performance cost on every Python + call to the torch API (not just factory functions). If this + is causing problems for you, please comment on + https://github.com/pytorch/pytorch/issues/92701 + + .. note:: + + This doesn't affect functions that create tensors that share the same memory as the input, like: + :func:`torch.from_numpy` and :func:`torch.frombuffer` + + Args: + device (device or string): the device to set as default + + Example:: + + >>> # xdoctest: +SKIP("requires cuda, changes global state") + >>> torch.get_default_device() + device(type='cpu') + >>> torch.set_default_device('cuda') # current device is 0 + >>> torch.get_default_device() + device(type='cuda', index=0) + >>> torch.set_default_device('cuda') + >>> torch.cuda.set_device('cuda:1') # current device is 1 + >>> torch.get_default_device() + device(type='cuda', index=1) + >>> torch.set_default_device('cuda:1') + >>> torch.get_default_device() + device(type='cuda', index=1) + + """ + global _GLOBAL_DEVICE_CONTEXT + if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): + device_context = _GLOBAL_DEVICE_CONTEXT.device_context + if device_context is not None: + device_context.__exit__(None, None, None) + + if device is None: + device_context = None + else: + from torch.utils._device import DeviceContext + + device_context = DeviceContext(device) + device_context.__enter__() + _GLOBAL_DEVICE_CONTEXT.device_context = device_context + + +def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None: + r""" + .. warning:: + + This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and + :func:`torch.set_default_device()` as alternatives. + + Sets the default ``torch.Tensor`` type to floating point tensor type + ``t``. This type will also be used as default floating point type for + type inference in :func:`torch.tensor`. + + The default floating point tensor type is initially ``torch.FloatTensor``. + + Args: + t (type or string): the floating point tensor type or its name + + Example:: + + >>> # xdoctest: +SKIP("Other tests may have changed the default type. 
Can we reset it?") + >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_tensor_type(torch.DoubleTensor) + >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor + torch.float64 + + """ + if isinstance(t, str): + t = _import_dotted_name(t) + _C._set_default_tensor_type(t) + + +def set_default_dtype(d: "torch.dtype", /) -> None: + r""" + + Sets the default floating point dtype to :attr:`d`. Supports floating point dtype + as inputs. Other dtypes will cause torch to raise an exception. + + When PyTorch is initialized its default floating point dtype is torch.float32, + and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like + type inference. The default floating point dtype is used to: + + 1. Implicitly determine the default complex dtype. When the default floating type is float16, + the default complex dtype is complex32. For float32, the default complex dtype is complex64. + For float64, it is complex128. For bfloat16, an exception will be raised because + there is no corresponding complex type for bfloat16. + 2. Infer the dtype for tensors constructed using Python floats or complex Python + numbers. See examples below. + 3. Determine the result of type promotion between bool and integer tensors and + Python floats and complex Python numbers. + + Args: + d (:class:`torch.dtype`): the floating point dtype to make the default. + + Example: + >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") + >>> # initial default for floating point is torch.float32 + >>> # Python floats are interpreted as float32 + >>> torch.tensor([1.2, 3]).dtype + torch.float32 + >>> # initial default for floating point is torch.complex64 + >>> # Complex Python numbers are interpreted as complex64 + >>> torch.tensor([1.2, 3j]).dtype + torch.complex64 + + >>> torch.set_default_dtype(torch.float64) + >>> # Python floats are now interpreted as float64 + >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor + torch.float64 + >>> # Complex Python numbers are now interpreted as complex128 + >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor + torch.complex128 + + >>> torch.set_default_dtype(torch.float16) + >>> # Python floats are now interpreted as float16 + >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor + torch.float16 + >>> # Complex Python numbers are now interpreted as complex128 + >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor + torch.complex32 + + """ + _C._set_default_dtype(d) + + +def use_deterministic_algorithms( + mode: builtins.bool, + *, + warn_only: builtins.bool = False, +) -> None: + r"""Sets whether PyTorch operations must use "deterministic" + algorithms. That is, algorithms which, given the same input, and when + run on the same software and hardware, always produce the same output. + When enabled, operations will use deterministic algorithms when available, + and if only nondeterministic algorithms are available they will throw a + :class:`RuntimeError` when called. + + .. note:: This setting alone is not always enough to make an application + reproducible. Refer to :ref:`reproducibility` for more information. + + .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative + interface for this feature. 
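To make the relationship between the two interfaces mentioned in the note above concrete, here is a small sketch (assuming an importable torch build; the mapping follows the implementation of set_deterministic_debug_mode shown further below):

import torch

# warn_only=True corresponds to debug mode "warn" / 1
torch.use_deterministic_algorithms(True, warn_only=True)
assert torch.are_deterministic_algorithms_enabled()
assert torch.is_deterministic_algorithms_warn_only_enabled()
assert torch.get_deterministic_debug_mode() == 1

# "error" / 2 is the same as use_deterministic_algorithms(True)
torch.set_deterministic_debug_mode("error")
assert torch.get_deterministic_debug_mode() == 2

# "default" / 0 turns the flag off again
torch.set_deterministic_debug_mode("default")
assert not torch.are_deterministic_algorithms_enabled()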
+ + The following normally-nondeterministic operations will act + deterministically when ``mode=True``: + + * :class:`torch.nn.Conv1d` when called on CUDA tensor + * :class:`torch.nn.Conv2d` when called on CUDA tensor + * :class:`torch.nn.Conv3d` when called on CUDA tensor + * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor + * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor + * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor + * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor + * :func:`torch.bmm` when called on sparse-dense CUDA tensors + * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor + and the index is a list of tensors + * :func:`torch.Tensor.index_put` with ``accumulate=False`` + * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU + tensor + * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU + tensor + * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor + * :func:`torch.gather` when called on a CUDA tensor that requires grad + * :func:`torch.index_add` when called on CUDA tensor + * :func:`torch.index_select` when attempting to differentiate a CUDA tensor + * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor + * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor + * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor + * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor + + The following normally-nondeterministic operations will throw a + :class:`RuntimeError` when ``mode=True``: + + * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.MaxUnpool1d` + * :class:`torch.nn.MaxUnpool2d` + * :class:`torch.nn.MaxUnpool3d` + * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor + and one of the following modes is used: + + - ``linear`` + - ``bilinear`` + - ``bicubic`` + - ``trilinear`` + + * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.NLLLoss` when called on a CUDA tensor + * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor + * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when + ``mode='max'`` + * :func:`torch.Tensor.put_` when ``accumulate=False`` + * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor + * :func:`torch.histc` when called on a CUDA tensor + * :func:`torch.bincount` when called on a CUDA tensor and ``weights`` + tensor is 
given + * :func:`torch.kthvalue` with called on a CUDA tensor + * :func:`torch.median` with indices output when called on a CUDA tensor + * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor + * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex + * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor + * :func:`torch.Tensor.resize_` when called with a quantized tensor + + In addition, several operations fill uninitialized memory when this setting + is turned on and when + :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on. + See the documentation for that attribute for more information. + + A handful of CUDA operations are nondeterministic if the CUDA version is + 10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8`` + or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more + details: ``_ + If one of these environment variable configurations is not set, a :class:`RuntimeError` + will be raised from these operations when called with CUDA tensors: + + * :func:`torch.mm` + * :func:`torch.mv` + * :func:`torch.bmm` + + Note that deterministic operations tend to have worse performance than + nondeterministic operations. + + .. note:: + + This flag does not detect or prevent nondeterministic behavior caused + by calling an inplace operation on a tensor with an internal memory + overlap or by giving such a tensor as the :attr:`out` argument for an + operation. In these cases, multiple writes of different data may target + a single memory location, and the order of writes is not guaranteed. + + Args: + mode (:class:`bool`): If True, makes potentially nondeterministic + operations switch to a deterministic algorithm or throw a runtime + error. If False, allows nondeterministic operations. + + Keyword args: + warn_only (:class:`bool`, optional): If True, operations that do not + have a deterministic implementation will throw a warning instead of + an error. Default: ``False`` + + Example:: + + >>> # xdoctest: +SKIP + >>> torch.use_deterministic_algorithms(True) + + # Forward mode nondeterministic error + >>> torch.randn(10, device='cuda').kthvalue(1) + ... + RuntimeError: kthvalue CUDA does not have a deterministic implementation... + + # Backward mode nondeterministic error + >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward() + ... + RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation... + """ + _C._set_deterministic_algorithms(mode, warn_only=warn_only) + + +def are_deterministic_algorithms_enabled() -> builtins.bool: + r"""Returns True if the global deterministic flag is turned on. Refer to + :func:`torch.use_deterministic_algorithms` documentation for more details. + """ + return _C._get_deterministic_algorithms() + + +def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: + r"""Returns True if the global deterministic flag is set to warn only. + Refer to :func:`torch.use_deterministic_algorithms` documentation for more + details. + """ + return _C._get_deterministic_algorithms_warn_only() + + +def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None: + r"""Sets the debug mode for deterministic operations. + + .. note:: This is an alternative interface for + :func:`torch.use_deterministic_algorithms`. Refer to that function's + documentation for details about affected operations. 
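+
+    For example (the equivalence follows from the mapping implemented below),
+    ``torch.set_deterministic_debug_mode("warn")`` has the same effect as
+    ``torch.use_deterministic_algorithms(True, warn_only=True)``.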
+ + Args: + debug_mode(str or int): If "default" or 0, don't error or warn on + nondeterministic operations. If "warn" or 1, warn on + nondeterministic operations. If "error" or 2, error on + nondeterministic operations. + """ + + # NOTE: builtins.int is used here because int in this scope resolves + # to torch.int + if not isinstance(debug_mode, (builtins.int, str)): + raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}") + + if isinstance(debug_mode, str): + if debug_mode == "default": + debug_mode = 0 + elif debug_mode == "warn": + debug_mode = 1 + elif debug_mode == "error": + debug_mode = 2 + else: + raise RuntimeError( + "invalid value of debug_mode, expected one of `default`, " + f"`warn`, `error`, but got {debug_mode}" + ) + + if debug_mode == 0: + _C._set_deterministic_algorithms(False) + elif debug_mode == 1: + _C._set_deterministic_algorithms(True, warn_only=True) + elif debug_mode == 2: + _C._set_deterministic_algorithms(True) + else: + raise RuntimeError( + f"invalid value of debug_mode, expected 0, 1, or 2, but got {debug_mode}" + ) + + +def get_deterministic_debug_mode() -> builtins.int: + r"""Returns the current value of the debug mode for deterministic + operations. Refer to :func:`torch.set_deterministic_debug_mode` + documentation for more details. + """ + + if _C._get_deterministic_algorithms(): + if _C._get_deterministic_algorithms_warn_only(): + return 1 + else: + return 2 + else: + return 0 + + +def get_float32_matmul_precision() -> str: + r"""Returns the current value of float32 matrix multiplication precision. Refer to + :func:`torch.set_float32_matmul_precision` documentation for more details. + """ + return _C._get_float32_matmul_precision() + + +def set_float32_matmul_precision(precision: str) -> None: + r"""Sets the internal precision of float32 matrix multiplications. + + Running float32 matrix multiplications in lower precision may significantly increase + performance, and in some programs the loss of precision has a negligible impact. + + Supports three settings: + + * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa + bits with 23 bits explicitly stored) for internal computations. + * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10 + mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers + (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication + algorithms are available. Otherwise float32 matrix multiplications are computed + as if the precision is "highest". See below for more information on the bfloat16 + approach. + * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa + bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm + using that datatype internally is available. Otherwise float32 + matrix multiplications are computed as if the precision is "high". + + When using "high" precision, float32 multiplications may use a bfloat16-based algorithm + that is more complicated than simply truncating to some smaller number mantissa bits + (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete + description of this algorithm. 
To briefly explain here, the first step is to realize + that we can perfectly encode a single float32 number as the sum of three bfloat16 + numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the + same number of exponent bits). This means that the product of two float32 numbers can + be exactly given by the sum of nine products of bfloat16 numbers. We can then trade + accuracy for speed by dropping some of these products. The "high" precision algorithm + specifically keeps only the three most significant products, which conveniently excludes + all of the products involving the last 8 mantissa bits of either input. This means that + we can represent our inputs as the sum of two bfloat16 numbers rather than three. + Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than + float32 ones, it's faster to do three multiplications and 2 additions with bfloat16 + precision than it is to do a single multiplication with float32 precision. + + .. [Henry2019] http://arxiv.org/abs/1904.06376 + + .. note:: + + This does not change the output dtype of float32 matrix multiplications, + it controls how the internal computation of the matrix multiplication is performed. + + .. note:: + + This does not change the precision of convolution operations. Other flags, + like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution + operations. + + .. note:: + + This flag currently only affects one native device type: CUDA. + If "high" or "medium" are set then the TensorFloat32 datatype will be used + when computing float32 matrix multiplications, equivalent to setting + `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default) + is set then the float32 datatype is used for internal computations, equivalent + to setting `torch.backends.cuda.matmul.allow_tf32 = False`. + + Args: + precision(str): can be set to "highest" (default), "high", or "medium" (see above). + + """ + _C._set_float32_matmul_precision(precision) + + +def set_warn_always(b: builtins.bool, /) -> None: + r"""When this flag is False (default) then some PyTorch warnings may only + appear once per process. This helps avoid excessive warning information. + Setting it to True causes these warnings to always appear, which may be + helpful when debugging. + + Args: + b (:class:`bool`): If True, force warnings to always be emitted + If False, set to the default behaviour + """ + _C._set_warnAlways(b) + + +def is_warn_always_enabled() -> builtins.bool: + r"""Returns True if the global warn_always flag is turned on. Refer to + :func:`torch.set_warn_always` documentation for more details. + """ + return _C._get_warnAlways() + + +################################################################################ +# Define error checking functions +################################################################################ + +# These error checking functions must be kept consistent with their C++ +# equivalents. Their C++ equivalents are mentioned where applicable. 
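+# Rough numerical sketch (added for illustration; not part of the original file)
+# of the bfloat16 decomposition described in set_float32_matmul_precision above:
+#
+#     >>> import torch
+#     >>> x = torch.tensor(1.2345678, dtype=torch.float32)
+#     >>> hi = x.to(torch.bfloat16).to(torch.float32)         # top bfloat16 piece
+#     >>> lo = (x - hi).to(torch.bfloat16).to(torch.float32)  # next piece
+#     >>> torch.allclose(hi + lo, x, atol=1e-4)
+#     True
+#
+# Dropping the lowest-order pieces (and the matmul cross terms they contribute)
+# is the accuracy-for-speed trade-off behind the "high" and "medium" settings.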
+ + +def _check_with( + error_type, + cond: _Union[builtins.bool, SymBool], + message: _Callable[[], str], +): # noqa: F811 + if not isinstance(cond, (builtins.bool, SymBool)): + raise TypeError(f"cond must be a bool, but got {type(cond)}") + + from torch.fx.experimental.symbolic_shapes import expect_true + + if expect_true(cond): + return + + # error_type must be a subclass of Exception and not subclass of Warning + assert issubclass(error_type, Exception) and not issubclass(error_type, Warning) + + if message is None: + message_evaluated = ( + "Expected cond to be True, but got False. (Could this error " + "message be improved? If so, please report an enhancement request " + "to PyTorch.)" + ) + + else: + if not callable(message): + raise TypeError("message must be a callable") + + message_evaluated = str(message()) + + raise error_type(message_evaluated) + + +def _check(cond, message=None): # noqa: F811 + r"""Throws error containing an optional message if the specified condition + is False. + + Error type: ``RuntimeError`` + + C++ equivalent: ``TORCH_CHECK`` + + Args: + cond (:class:`bool`): If False, throw error + + message (Callable, optional): Callable that returns either a string or + an object that has a ``__str__()`` method to be used as the error + message. Default: ``None`` + """ + _check_with(RuntimeError, cond, message) + + +def _check_is_size(i, message=None, *, max=None): + """Checks that a given integer is a valid size (i.e., is non-negative). + You should use this over ``_check(i >= 0)`` because it can prevent + ``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate + semantics for ``guard_size_oblivious`` tests that treat values 0 and 1 + equivalently to all other values. + + When max is not None, this specifies an upper bound equivalent to + ``_check(i <= max)``. This bound is also subject to alternate semantics: + in ``guard_size_oblivious`` tests, we assume that a constant max bound is + treated equivalently to all other values. Symbolic max bounds are not yet + supported. + + NB: Do NOT use this in contexts where a -1 size would be valid (indicating + to infer the size from context, or if you should wrap-around or truncate). + Only use this if the only valid value is an honest to goodness size. + """ + # This is responsible for the expect_true + _check(i >= 0, message) + from torch.fx.experimental.symbolic_shapes import _advise_is_size + + _advise_is_size(i) + + if max is not None: + _check(i <= max, message) + + from torch.fx.experimental.symbolic_shapes import _advise_is_bounded + + _advise_is_bounded(i, max) + + +def _check_index(cond, message=None): # noqa: F811 + r"""Throws error containing an optional message if the specified condition + is False. + + Error type: ``IndexError`` + + C++ equivalent: ``TORCH_CHECK_INDEX`` + + Args: + cond (:class:`bool`): If False, throw error + + message (Callable, optional): Callable that returns either a string or + an object that has a ``__str__()`` method to be used as the error + message. Default: ``None`` + """ + _check_with(IndexError, cond, message) + + +def _check_value(cond, message=None): # noqa: F811 + r"""Throws error containing an optional message if the specified condition + is False. + + Error type: ``ValueError`` + + C++ equivalent: ``TORCH_CHECK_VALUE`` + + Args: + cond (:class:`bool`): If False, throw error + + message (Callable, optional): Callable that returns either a string or + an object that has a ``__str__()`` method to be used as the error + message. 
Default: ``None`` + """ + _check_with(ValueError, cond, message) + + +def _check_type(cond, message=None): # noqa: F811 + r"""Throws error containing an optional message if the specified condition + is False. + + Error type: ``TypeError`` + + C++ equivalent: ``TORCH_CHECK_TYPE`` + + Args: + cond (:class:`bool`): If False, throw error + + message (Callable, optional): Callable that returns either a string or + an object that has a ``__str__()`` method to be used as the error + message. Default: ``None`` + """ + _check_with(TypeError, cond, message) + + +def _check_not_implemented(cond, message=None): # noqa: F811 + r"""Throws error containing an optional message if the specified condition + is False. + + Error type: ``NotImplementedError`` + + C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED`` + + Args: + cond (:class:`bool`): If False, throw error + + message (Callable, optional): Callable that returns either a string or + an object that has a ``__str__()`` method to be used as the error + message. Default: ``None`` + """ + _check_with(NotImplementedError, cond, message) + + +def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 + if not is_tensor(cond): + raise TypeError(f"cond must be a tensor, but got {type(cond)}") + + if not cond.dtype == torch.bool: + raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}") + + _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type] + + +# C++ equivalent: `TORCH_CHECK_TENSOR_ALL` +def _check_tensor_all(cond, message=None): # noqa: F811 + r"""Throws error containing an optional message if the specified condition + is False. + + Error type: ``RuntimeError`` + + C++ equivalent: ``TORCH_CHECK_TENSOR_ALL`` + + Args: + cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any + element is ``False``, throw error + + message (Callable, optional): Callable that returns either a string or + an object that has a ``__str__()`` method to be used as the error + message. Default: ``None`` + """ + _check_tensor_all_with(RuntimeError, cond, message) + + +################################################################################ +# Define numeric constants +################################################################################ + +# For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and +# NumPy consistency (https://numpy.org/devdocs/reference/constants.html) +from math import e, inf, nan, pi + + +newaxis: None = None + +__all__.extend(["e", "pi", "nan", "inf", "newaxis"]) + +################################################################################ +# Define Storage and Tensor classes +################################################################################ + +from torch._tensor import Tensor # usort: skip + +# needs to be after torch.Tensor is defined to avoid circular dependencies +from torch import storage as storage # usort: skip +from torch.storage import ( + _LegacyStorage, + _StorageBase, + _warn_typed_storage_removal, + TypedStorage, + UntypedStorage, +) + + +# NOTE: New Storage classes should never be added. When adding a new +# dtype, use torch.storage.TypedStorage directly. 
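+# Clarifying note (added; not part of the original file): each legacy class
+# defined below is a thin _LegacyStorage subclass whose only job is to pin a
+# dtype. The public `dtype` classproperty emits a deprecation warning via
+# _warn_typed_storage_removal before returning the same value as `_dtype`;
+# new code should construct torch.storage.TypedStorage with an explicit dtype
+# instead, as the note above says.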
+class ByteStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.uint8 + + +class DoubleStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.double + + +class FloatStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.float + + +class HalfStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.half + + +class LongStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.long + + +class IntStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.int + + +class ShortStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.short + + +class CharStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.int8 + + +class BoolStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.bool + + +class BFloat16Storage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.bfloat16 + + +class ComplexDoubleStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.cdouble + + +class ComplexFloatStorage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.cfloat + + +class QUInt8Storage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.quint8 + + +class QInt8Storage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.qint8 + + +class QInt32Storage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.qint32 + + +class QUInt4x2Storage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.quint4x2 + + +class QUInt2x4Storage(_LegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal(stacklevel=3) + return self._dtype + + @classproperty + def _dtype(self): + return torch.quint2x4 + + +_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = { + UntypedStorage, + 
DoubleStorage, + FloatStorage, + LongStorage, + IntStorage, + ShortStorage, + CharStorage, + ByteStorage, + HalfStorage, + BoolStorage, + QUInt8Storage, + QInt8Storage, + QInt32Storage, + BFloat16Storage, + ComplexFloatStorage, + ComplexDoubleStorage, + QUInt4x2Storage, + QUInt2x4Storage, + TypedStorage, +} + +# The _tensor_classes set is initialized by the call to initialize_python_bindings. +_tensor_classes: set[type["torch.Tensor"]] = set() + +# If you edit these imports, please update torch/__init__.py.in as well +from torch import amp as amp, random as random, serialization as serialization +from torch._tensor_str import set_printoptions +from torch.amp import autocast, GradScaler +from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state +from torch.serialization import load, save + + +################################################################################ +# Initialize extension +################################################################################ + + +# Shared memory manager needs to know the exact location of manager executable +def _manager_path(): + if _running_with_deploy() or platform.system() == "Windows": + return b"" + path = get_file_path("torch", "bin", "torch_shm_manager") + prepare_multiprocessing_environment(get_file_path("torch")) + if not os.path.exists(path): + raise RuntimeError("Unable to find torch_shm_manager at " + path) + return path.encode("utf-8") + + +_C._initExtension(_manager_path()) + +del _manager_path + +# Appease the type checker: it can't deal with direct setting of globals(). +# Note that we will see "too many" functions when reexporting this way; there +# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions +# so that this import is good enough +if TYPE_CHECKING: + # Some type signatures pulled in from _VariableFunctions here clash with + # signatures already imported. For now these clashes are ignored; see + # PR #43339 for details. + from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403 + + # Fixup segment_reduce visibility + _segment_reduce = segment_reduce + del segment_reduce # noqa: F821 + +# Ops not to be exposed in `torch` namespace, +# mostly helper ops. 
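+# Sketch of what the loop below does (comment added for clarity; not part of
+# the original file): every attribute of torch._C._VariableFunctions whose name
+# is neither a dunder nor listed in PRIVATE_OPS is copied into this module's
+# globals with __module__ rewritten to "torch"; names that do not start with
+# "_" are also appended to __all__, while `segment_reduce` is temporarily kept
+# public but re-exposed under the private alias `_segment_reduce` (and so left
+# out of __all__).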
+PRIVATE_OPS = ("unique_dim",) + +__name, __obj = "", None +for __name in dir(_C._VariableFunctions): + if __name.startswith("__") or __name in PRIVATE_OPS: + continue + __obj = getattr(_C._VariableFunctions, __name) + __obj.__module__ = __name__ # "torch" + # Hide some APIs that should not be public + if __name == "segment_reduce": + # TODO: Once the undocumented FC window is passed, remove the line below + globals()[__name] = __obj + __name = "_" + __name + globals()[__name] = __obj + if not __name.startswith("_"): + __all__.append(__name) + +del __name, __obj + +################################################################################ +# Add torch.dtype instances to the public API +################################################################################ + +import torch + + +__all__.extend( + name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype) +) + +################################################################################ +# Import TorchDynamo's lazy APIs to avoid circular dependenices +################################################################################ + +# needs to be before from torch.functional import * to avoid circular dependencies +from torch._compile import _disable_dynamo # usort: skip + +################################################################################ +# Import interface functions defined in Python +################################################################################ + +# needs to be after the above ATen bindings so we can overwrite from Python side +from torch import _VF as _VF, functional as functional # usort: skip +from torch.functional import * # usort: skip # noqa: F403 + +################################################################################ +# Remove unnecessary members +################################################################################ + +del _StorageBase +del _LegacyStorage + +################################################################################ +# Define _assert +################################################################################ + + +# needs to be before the submodule imports to avoid circular dependencies +def _assert(condition, message): + r"""A wrapper around Python's assert which is symbolically traceable.""" + if type(condition) is not torch.Tensor and overrides.has_torch_function( + (condition,) + ): + return overrides.handle_torch_function( + _assert, (condition,), condition, message + ) + assert condition, message + + +################################################################################ +# Import most common subpackages +################################################################################ + +# Use the redundant form so that type checkers know that these are a part of +# the public API. The "regular" import lines are there solely for the runtime +# side effect of adding to the imported module's members for other users. 
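+# (Clarifying comment; not part of the original file.) "Redundant form" means
+# spelling an import as `from torch import cuda as cuda` rather than
+# `from torch import cuda`: the `name as name` pattern is the conventional way
+# to signal to type checkers that a re-export is intentional public API.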
+ +# needs to be before import torch.nn as nn to avoid circular dependencies +from torch.autograd import ( # usort: skip + enable_grad as enable_grad, + inference_mode as inference_mode, + no_grad as no_grad, + set_grad_enabled as set_grad_enabled, +) + +from torch import ( + __config__ as __config__, + __future__ as __future__, + _awaits as _awaits, + accelerator as accelerator, + autograd as autograd, + backends as backends, + cpu as cpu, + cuda as cuda, + distributed as distributed, + distributions as distributions, + fft as fft, + futures as futures, + hub as hub, + jit as jit, + linalg as linalg, + mps as mps, + mtia as mtia, + multiprocessing as multiprocessing, + nested as nested, + nn as nn, + optim as optim, + overrides as overrides, + profiler as profiler, + sparse as sparse, + special as special, + testing as testing, + types as types, + utils as utils, + xpu as xpu, +) +from torch.signal import windows as windows + + +# Quantized, sparse, AO, etc. should be last to get imported, as nothing +# is expected to depend on them. +from torch import ao as ao # usort: skip + +# nn.quant* depends on ao -- so should be after those. +import torch.nn.intrinsic +import torch.nn.qat +import torch.nn.quantizable +import torch.nn.quantized + + +_C._init_names(list(_storage_classes)) + +# attach docstrings to torch and tensor functions +from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs + + +del _torch_docs, _tensor_docs, _storage_docs, _size_docs + + +def compiled_with_cxx11_abi() -> builtins.bool: + r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1""" + return True + + +from torch import _library as _library, _ops as _ops + + +# Import the ops and classes "namespace" +from torch._ops import ops as ops # usort: skip +from torch._classes import classes as classes # usort: skip + +sys.modules.setdefault(f"{__name__}.ops", ops) +sys.modules.setdefault(f"{__name__}.classes", classes) + +# quantization depends on torch.fx and torch.ops +# Import quantization +from torch import quantization as quantization # usort: skip + +# Import the quasi random sampler +from torch import quasirandom as quasirandom # usort: skip + +# If you are seeing this, it means that this call site was not checked if +# the memory format could be preserved, and it was switched to old default +# behaviour of contiguous +legacy_contiguous_format = contiguous_format # defined by _C._initExtension() + +# Register fork handler to initialize OpenMP in child processes (see gh-28389) +from torch.multiprocessing._atfork import register_after_fork + + +register_after_fork(torch.get_num_threads) +del register_after_fork + +# Import tools that require fully imported torch (for applying +# torch.jit.script as a decorator, for instance): +from torch._lobpcg import lobpcg as lobpcg + + +# These were previously defined in native_functions.yaml and appeared on the +# `torch` namespace, but we moved them to c10 dispatch to facilitate custom +# class usage. We add these lines here to preserve backward compatibility. +quantized_lstm = ops.aten.quantized_lstm +quantized_gru = ops.aten.quantized_gru + +# Import experimental masked operations support. See +# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more +# information. 
+from torch import masked as masked + +# Import removed ops with error message about removal +from torch._linalg_utils import ( # type: ignore[misc] + _symeig as symeig, + eig, + lstsq, + matrix_rank, + solve, +) +from torch.utils.dlpack import from_dlpack, to_dlpack + + +class _TorchCompileInductorWrapper: + compiler_name = "inductor" + + def __init__(self, mode, options, dynamic): + from torch._inductor.compiler_bisector import CompilerBisector + + self.config: dict[str, _Any] = {} + self.dynamic = dynamic + self.apply_mode(mode) + self.apply_options(options) + self.apply_options(CompilerBisector.get_config_change("inductor")) + + cuda_version = None + if hasattr(torch, "version"): + from torch.torch_version import TorchVersion + + cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0")) + + if self.config.get("triton.cudagraphs", False) and ( + (cuda_version and cuda_version < "12.6") + or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12() + ): + os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" + # FIXME: CUDA Graph does not work well with CUPTI teardown. + # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) + # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) + # Workaround: turn off CUPTI teardown when using CUDA Graphs. + os.environ["TEARDOWN_CUPTI"] = "0" + + def __eq__(self, other): + return ( + isinstance(other, _TorchCompileInductorWrapper) + and self.config == other.config + and self.dynamic == other.dynamic + ) + + def apply_mode(self, mode: _Optional[str]): + if mode and mode != "default": + from torch._inductor import list_mode_options + + self.apply_options(list_mode_options(mode, self.dynamic)) + + def apply_options(self, options: _Optional[dict[str, _Any]]): + if not options: + return + + from torch._inductor import config + + current_config: dict[str, _Any] = config.get_config_copy() + + for key, val in options.items(): + attr_name = key.replace("-", "_") + if attr_name not in current_config: + raise RuntimeError( + f"Unexpected optimization option {key}, known options are {list(current_config.keys())}" + ) + attr_type = config.get_type(attr_name) # type: ignore[attr-defined] + # Subscriptable generic types don't support isinstance so skip the type + # check. There doesn't seem to be a good way of checking membership without + # 3rd party libraries. 
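+            # (Added clarification; not part of the original file.) For example,
+            # isinstance(val, dict[str, int]) raises TypeError at runtime, which
+            # is why only non-generic annotations (_get_origin(...) is None) get
+            # the isinstance check below.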
+ if _get_origin(attr_type) is None: + if not isinstance(val, attr_type): + val_type_str = type(val).__name__ + expected_type_str = type(current_config[attr_name]).__name__ + raise RuntimeError( + f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}" + ) + self.config[attr_name] = val + + def __call__(self, model_, inputs_): + from torch._inductor.compile_fx import compile_fx + + return compile_fx(model_, inputs_, config_patches=self.config) + + def get_compiler_config(self): + from torch._inductor.compile_fx import get_patched_config_dict + + return get_patched_config_dict(config_patches=self.config) + + def reset(self): + from torch._inductor import config + + if "triton.cudagraphs" in self.config or config.triton.cudagraphs: + if self.config.get("triton.cudagraphs", True): + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + + +class _TorchCompileWrapper: + def __init__(self, backend, mode, options, dynamic): + from torch._dynamo.backends.registry import lookup_backend + + if isinstance(backend, str): + self.compiler_name = backend + elif hasattr(backend, "__name__"): + self.compiler_name = backend.__name__ + else: + self.compiler_name = str(backend) + self.dynamic = dynamic + self.compiler_fn = lookup_backend(backend) + self.kwargs = {} + # only pass the args if they non-empty + if mode and mode != "default": + self.kwargs["mode"] = mode + if options: + self.kwargs["options"] = options + + def __eq__(self, other): + return ( + isinstance(other, _TorchCompileWrapper) + and self.compiler_fn == other.compiler_fn + and self.kwargs == other.kwargs + and self.dynamic == other.dynamic + ) + + def __call__(self, model_, inputs_): + return self.compiler_fn(model_, inputs_, **self.kwargs) + + def reset(self): + if hasattr(self.compiler_fn, "reset"): + self.compiler_fn.reset() + + +_InputT = _ParamSpec("_InputT") +_RetT = _TypeVar("_RetT") + + +@_overload +def compile( + model: _Callable[_InputT, _RetT], + *, + fullgraph: builtins.bool = False, + dynamic: _Optional[builtins.bool] = None, + backend: _Union[str, _Callable] = "inductor", + mode: _Union[str, None] = None, + options: _Optional[ + dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] + ] = None, + disable: builtins.bool = False, +) -> _Callable[_InputT, _RetT]: ... + + +@_overload +def compile( + model: None = None, + *, + fullgraph: builtins.bool = False, + dynamic: _Optional[builtins.bool] = None, + backend: _Union[str, _Callable] = "inductor", + mode: _Union[str, None] = None, + options: _Optional[ + dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] + ] = None, + disable: builtins.bool = False, +) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ... + + +def compile( + model: _Optional[_Callable[_InputT, _RetT]] = None, + *, + fullgraph: builtins.bool = False, + dynamic: _Optional[builtins.bool] = None, + backend: _Union[str, _Callable] = "inductor", + mode: _Union[str, None] = None, + options: _Optional[ + dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] + ] = None, + disable: builtins.bool = False, +) -> _Union[ + _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]], + _Callable[_InputT, _RetT], +]: + """ + Optimizes given model/function using TorchDynamo and specified backend. + If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile` + to compile the module inplace without changing its structure. 
+ + Concretely, for every frame executed within the compiled region, we will attempt + to compile it and cache the compiled result on the code object for future + use. A single frame may be compiled multiple times if previous compiled + results are not applicable for subsequent calls (this is called a "guard + failure), you can use TORCH_LOGS=guards to debug these situations. + Multiple compiled results can be associated with a frame up to + ``torch._dynamo.config.recompile_limit``, which defaults to 8; at which + point we will fall back to eager. Note that compile caches are per + *code object*, not frame; if you dynamically create multiple copies of a + function, they will all share the same code cache. + + Args: + model (Callable or None): Module/function to optimize + fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions + in the function that it will optimize. If True, then we require that the entire function be + capturable into a single graph. If this is not possible (that is, if there are graph breaks), + then this will raise an error. + dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt + to generate a kernel that is as dynamic as possible to avoid recompilations when + sizes change. This may not always work as some operations/optimizations will + force specialization; use TORCH_LOGS=dynamic to debug overspecialization. + When this is False, we will NEVER generate dynamic kernels, we will always specialize. + By default (None), we automatically detect if dynamism has occurred and compile a more + dynamic kernel upon recompile. + backend (str or Callable): backend to be used + + - "inductor" is the default backend, which is a good balance between performance and overhead + + - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()` + + - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)` + + - To register an out-of-tree custom backend: + https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends + mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs" + + - "default" is the default mode, which is a good balance between performance and overhead + + - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs, + useful for small batches. Reduction of overhead can come at the cost of more memory + usage, as we will cache the workspace memory required for the invocation so that we + do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed + to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs. + There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints + to debug. + + - "max-autotune" is a mode that leverages Triton or template based matrix multiplications + on supported devices and Triton based convolutions on GPU. + It enables CUDA graphs by default on GPU. + + - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs + + - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()` + + options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are + + - `epilogue_fusion` which fuses pointwise ops into templates. 
Requires `max_autotune` to also be set + + - `max_autotune` which will profile to pick the best matmul configuration + + - `fallback_random` which is useful when debugging accuracy issues + + - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores + + - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs + + - `trace.enabled` which is the most useful debugging flag to turn on + + - `trace.graph_diagram` which will show you a picture of your graph after fusion + + - `guard_filter_fn` that controls which dynamo guards are saved with compilations. + This is an unsafe feature and there is no backward compatibility guarantee provided + for dynamo guards as data types. + For stable helper functions to use, see the documentations in `torch.compiler`, for example: + - `torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe` + - `torch.compiler.skip_guard_on_all_nn_modules_unsafe` + - `torch.compiler.keep_tensor_guards_unsafe` + + - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()` + disable (bool): Turn torch.compile() into a no-op for testing + + Example:: + + @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) + def foo(x): + return torch.sin(x) + torch.cos(x) + + """ + import sysconfig + + _C._log_api_usage_once("torch.compile") + if sys.version_info >= (3, 14): + raise RuntimeError("torch.compile is not supported on Python 3.14+") + elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( + 3, + 13, + 3, + ): + raise RuntimeError( + "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. " + "Please use Python 3.13.3+." + ) + + # Decorator mode + if model is None: + + def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: + if model is None: + raise RuntimeError("Model can't be None") + return compile( + model, + fullgraph=fullgraph, + dynamic=dynamic, + backend=backend, + mode=mode, + options=options, + disable=disable, + ) + + return fn + + if mode is not None and options is not None: + raise RuntimeError( + "Either mode or options can be specified, but both can't be specified at the same time." + ) + if mode is None and options is None: + mode = "default" + + from torch._inductor.compiler_bisector import CompilerBisector + + if bisect_backend := CompilerBisector.get_backend(): + backend = bisect_backend + + guard_filter_fn = None + if options and isinstance(options, dict): + guard_filter_fn = options.pop("guard_filter_fn", None) + + if backend == "inductor": + backend = _TorchCompileInductorWrapper(mode, options, dynamic) + else: + backend = _TorchCompileWrapper(backend, mode, options, dynamic) + + return torch._dynamo.optimize( + backend=backend, + nopython=fullgraph, + dynamic=dynamic, + disable=disable, + guard_filter_fn=guard_filter_fn, + )(model) # type: ignore[return-value] + + +def _register_device_module(device_type, module): + r"""Register an external runtime module of the specific :attr:`device_type` + supported by torch. + + After the :attr:`module` is registered correctly, the user can refer + the external runtime module as part of torch with attribute torch.xxx. + """ + # Make sure the device_type represent a supported device type for torch. 
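+    # (Illustrative note; not part of the original file.) After a successful call
+    # such as _register_device_module("foo", foo_module) -- "foo" being a purely
+    # hypothetical device type -- both `torch.foo` and `sys.modules["torch.foo"]`
+    # resolve to foo_module, which is what the setattr/sys.modules assignments
+    # below implement.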
+ device_type = torch.device(device_type).type + m = sys.modules[__name__] + if hasattr(m, device_type): + raise RuntimeError( + f"The runtime module of '{device_type}' has already " + f"been registered with '{getattr(m, device_type)}'" + ) + setattr(m, device_type, module) + torch_module_name = ".".join([__name__, device_type]) + sys.modules[torch_module_name] = module + + +from torch import ( + export as export, + func as func, + library as library, + return_types as return_types, +) +from torch._higher_order_ops import cond as cond, while_loop as while_loop +from torch.func import vmap as vmap + + +if not TYPE_CHECKING: + from torch import _meta_registrations + +# Enable CUDA Sanitizer +if "TORCH_CUDA_SANITIZER" in os.environ: + import torch.cuda._sanitizer as csan + + csan.enable_cuda_sanitizer() + +# Populate magic methods on SymInt and SymFloat +import torch.fx.experimental.sym_node +from torch import fx as fx + + +# Register MPS specific decomps +torch.backends.mps._init() + +if not _running_with_deploy(): + from torch import compiler as compiler + + class _TritonLibrary: + lib = torch.library.Library("triton", "DEF") + ops_table: dict[tuple[str, str], _Callable] = {} + + @classmethod + def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): + if (op_key, dispatch_key) not in cls.ops_table: + cls.lib.define(full_schema) + cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) + cls.ops_table[(op_key, dispatch_key)] = op_impl + + return cls.ops_table[(op_key, dispatch_key)] + + +# Deprecated attributes +_deprecated_attrs = { + "has_mps": torch.backends.mps.is_built, + "has_cuda": torch.backends.cuda.is_built, + "has_cudnn": torch.backends.cudnn.is_available, + "has_mkldnn": torch.backends.mkldnn.is_available, +} + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. + from torch import ( + _dynamo as _dynamo, + _inductor as _inductor, + _subclasses as _subclasses, + onnx as onnx, + ) + +else: + _lazy_modules = { + "_dynamo", + "_inductor", + "_export", + # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit + "onnx", + } + + def __getattr__(name): + # Deprecated attrs + replacement = _deprecated_attrs.get(name) + if replacement is not None: + import warnings + + warnings.warn( + f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", + stacklevel=2, + ) + return replacement() + + # Lazy modules + if name in _lazy_modules: + return importlib.import_module(f".{name}", __name__) + + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +@functools.cache +def get_device_module(device: _Optional[_Union[torch.device, str]] = None): + """ + Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). + If no device is given, return the module for the current accelerator or CPU if none is present. + """ + if isinstance(device, torch.device): + device_module_name = device.type + elif isinstance(device, str): + device_module_name = torch.device(device).type + elif device is None: + # Using default accelerator type. If no accelerator is available, it automatically returns CPU device. 
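+        # (Added example; not part of the original file.) e.g. get_device_module("cuda")
+        # returns torch.cuda, get_device_module(torch.device("cpu")) returns torch.cpu,
+        # and get_device_module() falls back to the current accelerator (or CPU).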
+ device_module_name = torch._C._get_accelerator().type + else: + raise RuntimeError( + f"Invalid value of device '{device}', expect torch.device, str, or None" + ) + device_module = getattr(torch, device_module_name, None) + if device_module is None: + raise RuntimeError( + f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'." + ) + return device_module + + +def _constrain_as_size( + symbol, + min: _Optional[builtins.int] = None, + max: _Optional[builtins.int] = None, +): + """ + This indicates that a given int is size-like, and can be used in any context where a size is expected. + You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist() + which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve + GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts. + + This function has unusual semantics in some circumstances in framework + code, we will treat this int as >= 2 (when we do a size-oblivious guard). + This makes it easier to use the unbacked int in size contexts, + as we will often attempt to guard on a size being zero/one + (e.g., when computing the contiguity of a tensor, or testing if + broadcasting can occur), which will not work on unbacked SymInts. + However, if we conservatively assume that the size is not zero/one, we will + end up with a graph that will still work even if the size is zero/one. + + For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit + ``` + """ + torch.sym_constrain_range_for_size(symbol, min=min, max=max) + + +from torch import _logging + + +_logging._init_logs() + + +def _import_device_backends(): + """ + Leverage the Python plugin mechanism to load out-of-the-tree device extensions. + See this RFC: https://github.com/pytorch/pytorch/issues/122468 + """ + from importlib.metadata import entry_points + + group_name = "torch.backends" + if sys.version_info < (3, 10): + backend_extensions = entry_points().get(group_name, ()) + else: + backend_extensions = entry_points(group=group_name) + + for backend_extension in backend_extensions: + try: + # Load the extension + entrypoint = backend_extension.load() + # Call the entrypoint + entrypoint() + except Exception as err: + raise RuntimeError( + f"Failed to load the backend extension: {backend_extension.name}. " + f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0." + ) from err + + +def _is_device_backend_autoload_enabled() -> builtins.bool: + """ + Whether autoloading out-of-the-tree device extensions is enabled. + The switch depends on the value of the environment variable + `TORCH_DEVICE_BACKEND_AUTOLOAD`. + + Returns: + bool: Whether to enable autoloading the extensions. Enabled by default. + + Examples: + >>> torch._is_device_backend_autoload_enabled() + True + """ + # enabled by default + return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1" + + +def _as_tensor_fullprec(t): + """ + Like torch.as_tensor, but when given Python data types it will keep + them in full precision. Used for calling convention for Dynamo. 
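+
+    Example (illustrative; follows directly from the branches below)::
+
+        >>> _as_tensor_fullprec(1.5).dtype
+        torch.float64
+        >>> _as_tensor_fullprec(3).dtype
+        torch.int64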
+ """ + ty = type(t) + if ty is builtins.float: + return torch.as_tensor(t, dtype=torch.float64) + elif ty is builtins.int: + return torch.as_tensor(t, dtype=torch.int64) + else: + return torch.as_tensor(t) + + +# `_import_device_backends` should be kept at the end to ensure +# all the other functions in this module that may be accessed by +# an autoloaded backend are defined +if _is_device_backend_autoload_enabled(): + _import_device_backends() diff --git a/phivenv/Lib/site-packages/torch/_appdirs.py b/phivenv/Lib/site-packages/torch/_appdirs.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e652471d5582706c06b0e39212f716a95f95db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_appdirs.py @@ -0,0 +1,666 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2005-2010 ActiveState Software Inc. +# Copyright (c) 2013 Eddy Petrișor + +# flake8: noqa + +""" +This file is directly from +https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py + +The license of https://github.com/ActiveState/appdirs copied below: + + +# This is the MIT license + +Copyright (c) 2010 ActiveState Software Inc. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" + +"""Utilities for determining application-specific dirs. + +See for details and usage. +""" +# Dev Notes: +# - Windows "Known Folders": https://learn.microsoft.com/en-us/windows/win32/shell/csidl +# - macOS File System Programming Guide: https://developer.apple.com/library/archive/documentation/FileManagement/Conceptual/FileSystemProgrammingGuide/Introduction/Introduction.html +# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html + +__version__ = "1.4.4" +__version_info__ = tuple(int(segment) for segment in __version__.split(".")) + + +import os +import sys + + +unicode = str + +if sys.platform.startswith("java"): + import platform + + os_name = platform.java_ver()[3][0] + if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc. + system = "win32" + elif os_name.startswith("Mac"): # "Mac OS X", etc. + system = "darwin" + else: # "Linux", "SunOS", "FreeBSD", etc. + # Setting this to "linux2" is not ideal, but only Windows or Mac + # are actually checked for and the rest of the module expects + # *sys.platform* style strings. + system = "linux2" +else: + system = sys.platform + + +def user_data_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific data dir for this application. 
+ + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user data directories are: + Mac OS X: ~/Library/Application Support/ + Unix: ~/.local/share/ # or in $XDG_DATA_HOME, if defined + Win XP (not roaming): C:\Documents and Settings\\Application Data\\ + Win XP (roaming): C:\Documents and Settings\\Local Settings\Application Data\\ + Win 7 (not roaming): C:\Users\\AppData\Local\\ + Win 7 (roaming): C:\Users\\AppData\Roaming\\ + + For Unix, we follow the XDG spec and support $XDG_DATA_HOME. + That means, by default "~/.local/share/". + """ + if system == "win32": + if appauthor is None: + appauthor = appname + const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA" + path = os.path.normpath(_get_win_folder(const)) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + elif system == "darwin": + path = os.path.expanduser("~/Library/Application Support/") + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def site_data_dir(appname=None, appauthor=None, version=None, multipath=False): + r"""Return full path to the user-shared data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "multipath" is an optional parameter only applicable to *nix + which indicates that the entire list of data dirs should be + returned. By default, the first item from XDG_DATA_DIRS is + returned, or '/usr/local/share/', + if XDG_DATA_DIRS is not set + + Typical site data directories are: + Mac OS X: /Library/Application Support/ + Unix: /usr/local/share/ or /usr/share/ + Win XP: C:\Documents and Settings\All Users\Application Data\\ + Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) + Win 7: C:\ProgramData\\ # Hidden, but writeable on Win 7. + + For Unix, this is using the $XDG_DATA_DIRS[0] default. + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. 
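+
+    Example (illustrative; assumes a Linux system with XDG_DATA_DIRS unset and a
+    made-up application name "MyApp"):
+        site_data_dir("MyApp") -> '/usr/local/share/MyApp'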
+ """ + if system == "win32": + if appauthor is None: + appauthor = appname + path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA")) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + elif system == "darwin": + path = os.path.expanduser("/Library/Application Support") + if appname: + path = os.path.join(path, appname) + else: + # XDG default for $XDG_DATA_DIRS + # only first, if multipath is False + path = os.getenv( + "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"]) + ) + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + if appname and version: + path = os.path.join(path, version) + return path + + +def user_config_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific config dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user config directories are: + Mac OS X: ~/Library/Preferences/ + Unix: ~/.config/ # or in $XDG_CONFIG_HOME, if defined + Win *: same as user_data_dir + + For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME. + That means, by default "~/.config/". + """ + if system == "win32": + path = user_data_dir(appname, appauthor, None, roaming) + elif system == "darwin": + path = os.path.expanduser("~/Library/Preferences/") + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def site_config_dir(appname=None, appauthor=None, version=None, multipath=False): + r"""Return full path to the user-shared data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "multipath" is an optional parameter only applicable to *nix + which indicates that the entire list of config dirs should be + returned. 
By default, the first item from XDG_CONFIG_DIRS is + returned, or '/etc/xdg/', if XDG_CONFIG_DIRS is not set + + Typical site config directories are: + Mac OS X: same as site_data_dir + Unix: /etc/xdg/ or $XDG_CONFIG_DIRS[i]/ for each value in + $XDG_CONFIG_DIRS + Win *: same as site_data_dir + Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) + + For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. + """ + if system == "win32": + path = site_data_dir(appname, appauthor) + if appname and version: + path = os.path.join(path, version) + elif system == "darwin": + path = os.path.expanduser("/Library/Preferences") + if appname: + path = os.path.join(path, appname) + else: + # XDG default for $XDG_CONFIG_DIRS + # only first, if multipath is False + path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg") + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + +def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific cache dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Cache" to the base app data dir for Windows. See + discussion below. + + Typical user cache directories are: + Mac OS X: ~/Library/Caches/ + Unix: ~/.cache/ (XDG default) + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Cache + Vista: C:\Users\\AppData\Local\\\Cache + + On Windows the only suggestion in the MSDN docs is that local settings go in + the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming + app data dir (the default returned by `user_data_dir` above). Apps typically + put cache data somewhere *under* the given dir here. Some examples: + ...\Mozilla\Firefox\Profiles\\Cache + ...\Acme\SuperApp\Cache\1.0 + OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value. + This can be disabled with the `opinion=False` option. 
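+
+    Example (an illustrative sketch only; "MyApp"/"MyCompany" are sample names
+    and the exact result depends on the platform)::
+
+        user_cache_dir("MyApp", "MyCompany")
+        # e.g. "~/.cache/MyApp" (expanded) on Linux,
+        # "~/Library/Caches/MyApp" on macOS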
+ """ + if system == "win32": + if appauthor is None: + appauthor = appname + path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA")) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + if opinion: + path = os.path.join(path, "Cache") + elif system == "darwin": + path = os.path.expanduser("~/Library/Caches") + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def user_state_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific state dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user state directories are: + Mac OS X: same as user_data_dir + Unix: ~/.local/state/ # or in $XDG_STATE_HOME, if defined + Win *: same as user_data_dir + + For Unix, we follow this Debian proposal + to extend the XDG spec and support $XDG_STATE_HOME. + + That means, by default "~/.local/state/". + """ + if system in ["win32", "darwin"]: + path = user_data_dir(appname, appauthor, None, roaming) + else: + path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific log dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Logs" to the base app data dir for Windows, and "log" to the + base cache dir for Unix. See discussion below. + + Typical user log directories are: + Mac OS X: ~/Library/Logs/ + Unix: ~/.cache//log # or under $XDG_CACHE_HOME if defined + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Logs + Vista: C:\Users\\AppData\Local\\\Logs + + On Windows the only suggestion in the MSDN docs is that local settings + go in the `CSIDL_LOCAL_APPDATA` directory. 
(Note: I'm interested in + examples of what some windows apps use for a logs dir.) + + OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA` + value for Windows and appends "log" to the user cache dir for Unix. + This can be disabled with the `opinion=False` option. + """ + if system == "darwin": + path = os.path.join(os.path.expanduser("~/Library/Logs"), appname) + elif system == "win32": + path = user_data_dir(appname, appauthor, version) + version = False + if opinion: + path = os.path.join(path, "Logs") + else: + path = user_cache_dir(appname, appauthor, version) + version = False + if opinion: + path = os.path.join(path, "log") + if appname and version: + path = os.path.join(path, version) + return path + + +class AppDirs(object): + """Convenience wrapper for getting application dirs.""" + + def __init__( + self, appname=None, appauthor=None, version=None, roaming=False, multipath=False + ): + self.appname = appname + self.appauthor = appauthor + self.version = version + self.roaming = roaming + self.multipath = multipath + + @property + def user_data_dir(self): + return user_data_dir( + self.appname, self.appauthor, version=self.version, roaming=self.roaming + ) + + @property + def site_data_dir(self): + return site_data_dir( + self.appname, self.appauthor, version=self.version, multipath=self.multipath + ) + + @property + def user_config_dir(self): + return user_config_dir( + self.appname, self.appauthor, version=self.version, roaming=self.roaming + ) + + @property + def site_config_dir(self): + return site_config_dir( + self.appname, self.appauthor, version=self.version, multipath=self.multipath + ) + + @property + def user_cache_dir(self): + return user_cache_dir(self.appname, self.appauthor, version=self.version) + + @property + def user_state_dir(self): + return user_state_dir(self.appname, self.appauthor, version=self.version) + + @property + def user_log_dir(self): + return user_log_dir(self.appname, self.appauthor, version=self.version) + + +# ---- internal support stuff + + +def _get_win_folder_from_registry(csidl_name): + """This is a fallback technique at best. I'm not sure if using the + registry for this guarantees us the correct answer for all CSIDL_* + names. + """ + import winreg as _winreg + + shell_folder_name = { + "CSIDL_APPDATA": "AppData", + "CSIDL_COMMON_APPDATA": "Common AppData", + "CSIDL_LOCAL_APPDATA": "Local AppData", + }[csidl_name] + + key = _winreg.OpenKey( + _winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", + ) + dir, _type = _winreg.QueryValueEx(key, shell_folder_name) + return dir + + +def _get_win_folder_with_pywin32(csidl_name): + from win32com.shell import shell, shellcon + + dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0) + # Try to make this a unicode path because SHGetFolderPath does + # not return unicode strings when there is unicode data in the + # path. + try: + dir = unicode(dir) + + # Downgrade to short path name if have highbit chars. See + # . 
+ has_high_char = False + for c in dir: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + try: + import win32api + + dir = win32api.GetShortPathName(dir) + except ImportError: + pass + except UnicodeError: + pass + return dir + + +def _get_win_folder_with_ctypes(csidl_name): + import ctypes + + csidl_const = { + "CSIDL_APPDATA": 26, + "CSIDL_COMMON_APPDATA": 35, + "CSIDL_LOCAL_APPDATA": 28, + }[csidl_name] + + buf = ctypes.create_unicode_buffer(1024) + ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) + + # Downgrade to short path name if have highbit chars. See + # . + has_high_char = False + for c in buf: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + buf2 = ctypes.create_unicode_buffer(1024) + if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): + buf = buf2 + + return buf.value + + +def _get_win_folder_with_jna(csidl_name): + import array + + from com.sun import jna + from com.sun.jna.platform import win32 + + buf_size = win32.WinDef.MAX_PATH * 2 + buf = array.zeros("c", buf_size) + shell = win32.Shell32.INSTANCE + shell.SHGetFolderPath( + None, + getattr(win32.ShlObj, csidl_name), + None, + win32.ShlObj.SHGFP_TYPE_CURRENT, + buf, + ) + dir = jna.Native.toString(buf.tostring()).rstrip("\0") + + # Downgrade to short path name if have highbit chars. See + # . + has_high_char = False + for c in dir: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + buf = array.zeros("c", buf_size) + kernel = win32.Kernel32.INSTANCE + if kernel.GetShortPathName(dir, buf, buf_size): + dir = jna.Native.toString(buf.tostring()).rstrip("\0") + + return dir + + +if system == "win32": + try: + import win32com.shell + + _get_win_folder = _get_win_folder_with_pywin32 + except ImportError: + try: + from ctypes import windll + + _get_win_folder = _get_win_folder_with_ctypes + except ImportError: + try: + import com.sun.jna + + _get_win_folder = _get_win_folder_with_jna + except ImportError: + _get_win_folder = _get_win_folder_from_registry + + +# ---- self test code + +if __name__ == "__main__": + appname = "MyApp" + appauthor = "MyCompany" + + props = ( + "user_data_dir", + "user_config_dir", + "user_cache_dir", + "user_state_dir", + "user_log_dir", + "site_data_dir", + "site_config_dir", + ) + + print(f"-- app dirs {__version__} --") + + print("-- app dirs (with optional 'version')") + dirs = AppDirs(appname, appauthor, version="1.0") + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") + + print("\n-- app dirs (without optional 'version')") + dirs = AppDirs(appname, appauthor) + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") + + print("\n-- app dirs (without optional 'appauthor')") + dirs = AppDirs(appname) + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") + + print("\n-- app dirs (with disabled 'appauthor')") + dirs = AppDirs(appname, appauthor=False) + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") diff --git a/phivenv/Lib/site-packages/torch/_classes.py b/phivenv/Lib/site-packages/torch/_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..2048495cd21443e439948328d0605f1eb4140f12 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_classes.py @@ -0,0 +1,56 @@ +# mypy: allow-untyped-defs +import types + +import torch._C + + +class _ClassNamespace(types.ModuleType): + def __init__(self, name): + super().__init__("torch.classes" + name) + self.name = name + + def __getattr__(self, attr): + proxy = 
torch._C._get_custom_class_python_wrapper(self.name, attr) + if proxy is None: + raise RuntimeError(f"Class {self.name}.{attr} not registered!") + return proxy + + +class _Classes(types.ModuleType): + __file__ = "_classes.py" + + def __init__(self) -> None: + super().__init__("torch.classes") + + def __getattr__(self, name): + namespace = _ClassNamespace(name) + setattr(self, name, namespace) + return namespace + + @property + def loaded_libraries(self): + return torch.ops.loaded_libraries + + def load_library(self, path): + """ + Loads a shared library from the given path into the current process. + + The library being loaded may run global initialization code to register + custom classes with the PyTorch JIT runtime. This allows dynamically + loading custom classes. For this, you should compile your class + and the static registration code into a shared library object, and then + call ``torch.classes.load_library('path/to/libcustom.so')`` to load the + shared object. + + After the library is loaded, it is added to the + ``torch.classes.loaded_libraries`` attribute, a set that may be inspected + for the paths of all libraries loaded using this function. + + Args: + path (str): A path to a shared library to load. + """ + torch.ops.load_library(path) + + +# The classes "namespace" +classes = _Classes() diff --git a/phivenv/Lib/site-packages/torch/_compile.py b/phivenv/Lib/site-packages/torch/_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbf8f9a39e6bf63b9eb25159abe91f32c05691a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_compile.py @@ -0,0 +1,59 @@ +""" +APIs related to torch.compile which lazily import torch._dynamo to avoid +circular dependencies. +""" + +import functools +from typing import Callable, Literal, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +@overload +def _disable_dynamo( + fn: Callable[_P, _T], recursive: bool = True +) -> Callable[_P, _T]: ... + + +@overload +def _disable_dynamo( + fn: Literal[None] = None, recursive: bool = True +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + +def _disable_dynamo( + fn: Optional[Callable[_P, _T]] = None, recursive: bool = True +) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: + """ + This API should be only used inside torch, external users should still use + torch._dynamo.disable. The main goal of this API is to avoid circular + imports issues that is common while using _dynamo.disable inside torch + itself. + + This API avoids it by lazily importing torch._dynamo from the import time to + the invocation of the decorated function. + """ + if fn is not None: + + @functools.wraps(fn) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T: + # cache this on the first invocation to avoid adding too much overhead. + disable_fn = getattr(fn, "__dynamo_disable", None) + if disable_fn is None: + import torch._dynamo + + # We can safely turn off functools.wraps here because the inner + # already wraps fn in the outer scope. + disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False) + fn.__dynamo_disable = disable_fn # type: ignore[attr-defined] + + return disable_fn(*args, **kwargs) + + return inner + else: + # decorator usage like @_disable_dynamo(recursive=False). The resulting + # object expects the original decorated function as the arg. 
+ return functools.partial(_disable_dynamo, recursive=recursive) diff --git a/phivenv/Lib/site-packages/torch/_custom_ops.py b/phivenv/Lib/site-packages/torch/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..86f506e5a21c4b8ce8dff3595775822296cf5b0e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_custom_ops.py @@ -0,0 +1,326 @@ +# mypy: allow-untyped-defs +import inspect + +from torch._custom_op.impl import ( + _custom_op_with_schema, + _find_custom_op, + infer_schema, + parse_qualname, + validate_namespace, +) +from torch.library import get_ctx + + +__all__ = [ + "custom_op", + "impl", + "impl_abstract", + "get_ctx", + "impl_save_for_backward", + "impl_backward", +] + + +def custom_op(qualname, func_or_schema=None): + r"""Register a new custom operator + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define the op (by providing an operator name and schema) + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the custom operator (the first step) + you must then perform the second step by calling various + ``impl_*`` APIs. + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Arguments: + qualname (str): Should be a string that looks like + "namespace::operator_name". Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. + func_or_schema (Union[Callable, str]): Each PyTorch operator needs a + schema that tells PyTorch the types of the inputs/outputs. + If this is a Callable, we will automatically infer the schema from + the type annotations on the function (see examples). Otherwise, + if you don't want to use type annotations, you may provide us the + schema string. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the custom op. + >>> # We need to provide the API a "prototype function" + >>> # (a function that returns NotImplementedError), from which + >>> # we will infer the types of the inputs and outputs. 
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> raise NotImplementedError + >>> + >>> # The custom op is now accessible via the torch.ops module: + >>> torch.ops.mylibrary.numpy_sin + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu") + >>> def numpy_sin_impl_cpu(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda") + >>> def numpy_sin_impl_cuda(x): + >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda + + """ + ns, name = parse_qualname(qualname) + validate_namespace(ns) + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. " + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = infer_schema(func, mutates_args=()) + _custom_op_with_schema(qualname, schema) + return func + + if func_or_schema is None: + return inner + if isinstance(func_or_schema, str): + _custom_op_with_schema(qualname, func_or_schema) + else: + return inner(func_or_schema) + + +def impl(qualname, *, device_types=("cpu", "cuda"), func=None): + r"""Register an implementation for a device type for this custom op. + + If the op is passed multiple Tensor inputs with different device + types, it will dispatch to the registered implementation for the highest + priority device type among those present. + The supported device types, in order of priority, are {'cuda', 'cpu'}. + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Arguments: + device_types (str or Iterable[str]): the device type(s) to register the function for. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the custom op. + >>> # We need to provide the API a "prototype function" + >>> # (a function that returns NotImplementedError), from which + >>> # we will infer the types of the inputs and outputs. 
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos") + >>> def numpy_cos(x: Tensor) -> Tensor: + >>> raise NotImplementedError + >>> + >>> # The custom op is now accessible via the torch.ops module: + >>> torch.ops.mylibrary.numpy_cos + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu") + >>> def numpy_cos_impl_cpu(x): + >>> return torch.from_numpy(np.cos(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda") + >>> def numpy_cos_impl_cuda(x): + >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda + + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl(device_types, _stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def impl_abstract(qualname, *, func=None): + r"""Register an abstract implementation for this operator. + + An "abstract implementation" specifies the behavior of this operator on + Tensors that carry no data. Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + The abstract implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. To write an abstract + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The abstract implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Examples:: + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch._custom_ops.custom_op("mylibrary::custom_linear") + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> raise NotImplementedError + >>> + >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear") + >>> def custom_linear_abstract(x, weight): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero') + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> ... + >>> + >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero") + >>> def custom_nonzero_abstract(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. 
+ >>> ctx = torch._custom_ops.get_ctx() + >>> nnz = ctx.create_unbacked_symint() + >>> shape = [x.dim(), nnz] + >>> result = x.new_empty(shape, dtype=torch.long) + >>> return result + >>> + >>> @torch._custom_ops.impl("mylibrary::custom_nonzero") + >>> def custom_nonzero_impl(x): + >>> x_np = to_numpy(x) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> # unbacked symbolic ints in PyTorch must be >= 2, so we + >>> # constrain the range to at least 2 + >>> if res.shape[0] <= 1: + >>> raise RuntimeError("not supported") + >>> return torch.tensor(res, device=x.device) + + """ + import torch.library + + return torch.library.register_fake(qualname, func, _stacklevel=2) + + +def impl_save_for_backward(qualname, *, func=None): + r"""Register a function that tells us what to save for backward. + + Please see :func:`impl_backward` for more details. + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl_save_for_backward(_stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def impl_backward(qualname, output_differentiability=None, *, func=None): + r"""Registers a backward formula for an operator. + + In order for an operator to work with autograd, you need to register + a backward formula. There are two pieces to this: + 1. You must give us a function to specify what to save for backward. + Call this the "save for backward" function. + 2. You must give us a function that computes gradients. Call this the + "backward" function. + + Use `impl_save_for_backward` to define a "save for backward" function + that specifies what gets saved for backward. The function should accept + two arguments ``(inputs, output)`` and return the quantities to be saved + for backward. + + During runtime, when you call the operator in a forwards pass, PyTorch + will invoke the "save for backward" function with the inputs and output + of the operator. + + Use `impl_backward` to define the "backward" function. The backward + function must accept ``(ctx, saved, *grads)``: + - ``ctx`` is a context object where we may provide information + - ``saved`` is exactly what gets returned from the "save for backward" + function + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + + The backward function must return a dict that maps the name of + an input to the operator to its corresponding gradient. All inputs that + were declared to be Tensors in the operator definition must be accounted + for in the dict. The gradient may be a Tensor or None. + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl_backward(output_differentiability, _stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def _destroy(qualname): + """De-registers a custom op. 
For testing purposes only""" + custom_op = _find_custom_op(qualname) + custom_op._destroy() diff --git a/phivenv/Lib/site-packages/torch/_deploy.py b/phivenv/Lib/site-packages/torch/_deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..f280be9c780de737446fa7f3fec974b5b0c96309 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_deploy.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import io + +import torch +from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer +from torch.package._package_pickler import create_pickler +from torch.package._package_unpickler import PackageUnpickler +from torch.serialization import _maybe_decode_ascii + + +def _save_storages(importer, obj): + serialized_storages = [] + serialized_dtypes = [] + + importer = importer if isinstance(importer, torch.package.PackageImporter) else None + importers: Importer + if importer is not None: + importers = OrderedImporter(importer, sys_importer) + else: + importers = sys_importer + + def persistent_id(obj): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, we can + # remove this case + dtype = obj.dtype + else: + dtype = torch.uint8 + + serialized_storages.append(obj) + serialized_dtypes.append(dtype) + return ("storage", len(serialized_storages) - 1) + + if hasattr(obj, "__reduce_deploy__"): + if _serialized_reduces.get(id(obj)) is None: + _serialized_reduces[id(obj)] = ( + "reduce_deploy", + id(obj), + *obj.__reduce_deploy__(importers), + ) + return _serialized_reduces[id(obj)] + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = create_pickler(data_buf, importers) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + return ( + data_value, + serialized_storages, + serialized_dtypes, + importer.zip_reader if importer else None, + ) + + +def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == "storage": + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + storage = serialized_storages[data[0]] + dtype = serialized_dtypes[data[0]] + return torch.storage.TypedStorage( + wrap_storage=storage.untyped(), dtype=dtype + ) + + if typename == "reduce_deploy": + reduce_id, func, args = data + if reduce_id not in _loaded_reduces: + _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) + return _loaded_reduces[reduce_id] + + return None + + importer: Importer + if zip_reader is not None: + importer = OrderedImporter(_get_package(zip_reader), sys_importer) + else: + importer = sys_importer + + unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes)) + unpickler.persistent_load = persistent_load # type: ignore[method-assign] + result = _deploy_objects[id] = unpickler.load() + return result + + +def _get_package(zip_reader): + if zip_reader not in _raw_packages: + _raw_packages[zip_reader] = PackageImporter(zip_reader) + return _raw_packages[zip_reader] + + +_raw_packages: dict = {} +_deploy_objects: dict = {} +_serialized_reduces: dict = {} +_loaded_reduces: dict = {} diff --git a/phivenv/Lib/site-packages/torch/_environment.py b/phivenv/Lib/site-packages/torch/_environment.py new file mode 100644 index 
0000000000000000000000000000000000000000..78e2ff3177db024114a9340663c1273a15c39c9d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_environment.py @@ -0,0 +1,2 @@ +def is_fbcode() -> bool: + return False diff --git a/phivenv/Lib/site-packages/torch/_guards.py b/phivenv/Lib/site-packages/torch/_guards.py new file mode 100644 index 0000000000000000000000000000000000000000..31bcaf405547829ab03b8da2fc9c426be9bf454a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_guards.py @@ -0,0 +1,1141 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import dataclasses +import enum +import functools +import logging +import re +import threading +import traceback +import unittest.mock +import weakref +from abc import abstractmethod +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Generic, + NamedTuple, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) + +import torch +from torch.utils import _pytree as pytree +from torch.utils._backport_slots import dataclass_slots +from torch.utils._traceback import CapturedTraceback, format_frame +from torch.utils.weak import WeakTensorKeyDictionary + + +log = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from types import CodeType + + import sympy + + +""" +torch._guards is the definitional source of truth for general purpose guard structures. + +An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions, +and no guard installation notions here. +""" + +COMPILE_ID_PATTERN = re.compile(r"^(?P\d+)/(?P\d+)$") +CA_COMPILE_ID_PATTERN = re.compile( + r"^!(?P\d+)(?:/(?P\d+)/(?P\d+))?$" +) + +# [Note: Updating CompiledId] +# +# CompiledId represents a unique program-level identifier, and we want to keep that +# property as the codebase evolves. This property is relied on even outside of the pytorch +# repo, e.g. tlparse or other internal tooling. The in-memory format can be freely changed, +# as those dependencies only consume the string serialization. +# +# The string form should be: +# 1. Program-level uid: CompileId can uniquely identify a compiled graph. +# 2. Storage efficient: This object is logged in nearly every entry. We should elide symbols when possible. +# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay. 
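+# As a rough illustration of the two patterns above (hypothetical ids, not
+# taken from a real run): a regular compile id serializes as two integers
+# separated by a slash, e.g. "0/1", while a compiled-autograd id is prefixed
+# with "!" and may optionally carry the frame parts as well, e.g. "!2" or
+# "!2/0/1".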
+ + +# TODO: mark as kw_only=True once we drop support for bool: + return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) + + def is_specialized_nn_module(self) -> bool: + import torch._dynamo.config as config + + if config._unsafe_skip_fsdp_module_guards: + return ( + self + in ( + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + ) + or self.is_fsdp_module() + ) + return self in ( + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + ) + + def is_unspecialized_nn_module(self) -> bool: + return self in ( + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + ) + + def is_unspecialized_builtin_nn_module(self) -> bool: + return self in ( + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + ) + + def is_local(self): + return self in ( + GuardSource.LOCAL, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + ) + + +""" +Base class for a "GuardBuilder" role. + +The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little +confusing, as its not a builder, but for the sake of avoiding a lot of renames and keeping the original reference +to torchdynamo's GuardBuilder. + +Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based +on GuardSource's select function. + +There is value in keeping this GuardBuilderBase empty to keep layering clean. +""" + + +class GuardBuilderBase: + pass + + +@dataclasses.dataclass(frozen=True) +class SLoc: + framework_loc: Optional[Union[traceback.FrameSummary, str]] + maybe_user_loc: Optional[str] + + def __str__(self): + floc = ( + self.framework_loc + if isinstance(self.framework_loc, str) + else format_frame(self.framework_loc) + ) + if self.maybe_user_loc is not None: + return f"{self.maybe_user_loc} ({floc})" + else: + return f"({floc})" + + +class ShapeGuard(NamedTuple): + expr: sympy.logic.boolalg.Boolean + sloc: SLoc + size_oblivious: bool + + +@dataclass_slots +@dataclasses.dataclass +class Guard: + # originating_source is the source that called the make_guard method to + # construct this guard object. The property name specifies what exactly it + # is the guard is guarding on. The meaning of the name is dependent on the + # create_fn; you must look at the use-site inside create_fn to know what + # name means. + # + # That being said, although you might think this is just a "name", name is + # usually an arbitrary Python expression that will be evaluated with all + # globals (and locals, if you create a LOCAL guard) to extract the Python + # object that we want to perform guard tests on. This evaluation + # typically happens in GuardBuilder.eval. In these cases, name is + # typically produced by originating_source.name() (not to be confused with + # GuardSource - the property source). + # + # Occasionally, name is not a valid Python expression; sometimes + # it is meaningless. Example create_fns that are like this include + # GRAD_MODE and SHAPE_ENV. + originating_source: Source + create_fn: Callable[[GuardBuilderBase, Guard], None] + + # Export only. These values are written to at time of guard check_fn creation. 
+ guard_types: Optional[list[str]] = None + code_list: Optional[list[str]] = None + obj_weakref: Optional[object] = None + guarded_class_weakref: Optional[type] = None + + stack: Optional[CapturedTraceback] = None + user_stack: Optional[traceback.StackSummary] = None + _hash: Optional[int] = None + + def __hash__(self): + if self._hash is None: + self._hash = hash((self.name, self.source, id(self.create_fn))) + return self._hash + + def sort_key(self): + # Put the duplicate input guards at the end. The duplicate guards have + # two sources while guard.name only considers one source. + + is_duplicate_input = ( + isinstance(self.create_fn, functools.partial) + and self.create_fn.func is torch._dynamo.guards.GuardBuilder.DUPLICATE_INPUT + ) + return ( + is_duplicate_input, + self.source.value if self.source else -1, + len(self.name), + self.name, + self.inner_create_fn().__code__.co_firstlineno, + ) + + def __lt__(self, other): + return self.sort_key() < other.sort_key() + + def inner_create_fn(self): + if isinstance(self.create_fn, functools.partial): + return self.create_fn.func + else: + return self.create_fn + + @property + def name(self) -> str: + return self.originating_source.name() + + @property + def source(self) -> GuardSource: + return self.originating_source.guard_source() + + @staticmethod + def weakref_to_str(obj_weakref): + """ + This is a workaround of a Python weakref bug. + + `obj_weakref` is instance returned by `weakref.ref`, + `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g: + + class MyConfig(dict): + def __getattr__(self, x): + return self[x] + + obj = MyConfig(offset=5) + obj_weakref = weakref.ref(obj) + str(obj_weakref) # raise error: KeyError: '__name__' + """ + if isinstance(obj_weakref, weakref.ReferenceType): + obj = obj_weakref() + if obj is not None: + return f"" + else: + return f"" + else: + return str(obj_weakref) + + def __repr__(self): + s = f""" + {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__} + {{ + 'guard_types': {self.guard_types}, + 'code': {self.code_list}, + 'obj_weakref': {self.weakref_to_str(self.obj_weakref)} + 'guarded_class': {self.guarded_class_weakref} + }} + """ + return s + + def __str__(self): + output = f"Name: {repr(self.name)}\n" + source = self.source.name.lower() if self.source else "" + output += f" Source: {source}\n" + output += f" Create Function: {self.inner_create_fn().__name__}\n" + output += f" Guard Types: {self.guard_types}\n" + output += f" Code List: {self.code_list}\n" + output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n" + output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n" + return output + + def create(self, builder: GuardBuilderBase): + try: + return self.create_fn(builder, self) + except Exception: + log.exception("Error while creating guard:\n%s", str(self).rstrip()) + if self.stack: + log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip()) + raise + + def is_specialized_nn_module(self): + return self.source.is_specialized_nn_module() + + def is_fsdp_module(self): + return self.source.is_fsdp_module() + + def is_local(self): + return self.source.is_local() + + def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref): + if not self.guard_types: + self.guard_types = [] + + self.guard_types.append(guard_type) + + assert self.guarded_class_weakref in ( + guarded_class, + None, + ), "Guarded class id must be identical, or None" + self.guarded_class_weakref = 
guarded_class + + if not self.code_list: + self.code_list = code_list + else: + self.code_list.extend(code_list) + + # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have + # multiple guards on the same object, the weakref can die between the + # invocation of set_export_info calls. So a dead weakref is also + # acceptable. + assert ( + self.obj_weakref in (obj_weakref, None) + or callable(self.obj_weakref) + and self.obj_weakref() is None + ), "Guarded object must be identical, None or ephemeral (dead weakref)" + self.obj_weakref = obj_weakref + + +T = TypeVar("T") + +""" +Parent structure for guard env expressions. +A GuardEnvExpr can have any subtype. +Note: All subtypes must be handled exhaustively in +torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError. +""" + + +@dataclasses.dataclass(frozen=True) +class GuardEnvExpr: + pass + + +""" +A class representing a pair of duplicate inputs. +input_pos_a and input_pos_b are input positions we have deduped. +""" + + +@dataclasses.dataclass(frozen=True) +class DuplicateInputs(GuardEnvExpr): + input_source_a: Source + input_source_b: Source + + def __post_init__(self): + assert self.input_source_a != self.input_source_b + + +""" +A class representing storage overlap relations among inputs that aliases the same storage. + +Given that a set of tensors alias the same storage, this guard checks whether they actually +have overlapping storages. + +While non_overlapping_sources represent input tensors that definitely don't have any storage +overlapping with any other input, overlapping_sources represent tensors that either: + +1. Do overlap some other input tensor +2. Might not overlap some other input tensor, but we are not sure +""" + + +@dataclasses.dataclass(frozen=True) +class StorageOverlap(GuardEnvExpr): + overlapping_sources: list[Source] + non_overlapping_sources: list[Source] + + +""" +Checkpointable is an interface for driving state snapshotting, left purposely vague for now. + +copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that +can also be taken in at restore_graphstate(T) calls. + +When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable +does not provide any garuantees around consistency, idempotency, or safety of calling its APIs, yet. + +In the future, it will have a closer coupling to a generic Checkpoint management system. +""" + + +class Checkpointable(Generic[T]): + @abstractmethod + def copy_graphstate(self) -> T: ... + + @abstractmethod + def restore_graphstate(self, state: T): ... + + +class GuardsCheckpointState: + """ + The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext + """ + + dynamo_guards: set[Guard] = set() + + def __init__(self, dynamo_guards): + self.dynamo_guards = dynamo_guards + + def diff(self, other): + """ + Produces a delta against another GuardsCheckpointState. + + Returns None if no delta is found, otherwise, return a set() of mismatched + Guard type objects. + """ + r = self.dynamo_guards.difference(other.dynamo_guards) + if len(r) == 0: + return None + return r + + def __eq__(self, other): + return self.diff(other) is None + + +class ModuleContextCheckpointState: + nn_modules: dict[str, torch.nn.Module] = {} + + def __init__(self, nn_modules): + self.nn_modules = nn_modules + + def diff(self, other): + """ + Produces a delta against another ModuleContextCheckpointState. + + Returns None if no delta is found, otherwise, return a set() of mismatched + module key names. 
+ """ + r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys())) + if len(r) == 0: + return None + return r + + def __eq__(self, other): + return self.diff(other) is None + + +class ModuleContext(Checkpointable[ModuleContextCheckpointState]): + def __init__(self) -> None: + self.nn_modules: dict[str, Any] = {} + + def copy_graphstate(self): + return ModuleContextCheckpointState(dict(self.nn_modules)) + + def restore_graphstate(self, state): + assert isinstance(state, ModuleContextCheckpointState) + self.nn_modules = state.nn_modules + + +class GlobalContextCheckpointState: + global_state: dict[str, tuple[Callable, ...]] = {} + + def __init__(self, global_states): + self.global_state = global_states + + def diff(self, other): + """ + Produces a delta against another GlobalContextCheckpointState. + + Returns None if no delta is found, otherwise, return a set() of mismatched + global key names. + """ + r = set(self.global_state.keys()).difference(set(other.global_state.keys())) + if len(r) == 0: + return None + return r + + def __eq__(self, other): + return self.diff(other) is None + + +class GlobalContext(Checkpointable[GlobalContextCheckpointState]): + """ + This keeps track of the global torch state during tracing of a function. + For example, torch.is_grad_enabled. + """ + + _supported_global_states = { + "grad_enabled", + "autocast_enabled", + "autocast_cpu_enabled", + "autocast_gpu_dtype", + "autocast_cpu_dtype", + "autocast_cache_enabled", + } + + def __init__(self) -> None: + self.global_state: dict[str, tuple[Callable, ...]] = {} + + def copy_graphstate(self): + return GlobalContextCheckpointState(dict(self.global_state)) + + def restore_graphstate(self, state): + assert isinstance(state, GlobalContextCheckpointState) + self.global_state = state.global_state + assert ( + len(self.global_state) == len(self._supported_global_states) + and set(self.global_state.keys()) == self._supported_global_states + ), "Global state mismatch" + for func, args in self.global_state.values(): + func(args) + + +# Like a Set[Guard] but will record the user stack on all guards at the +# time they were installed at their destination +class GuardsSet: + def __init__(self, inner=None): + if inner is None: + inner = set() + self.inner = inner + + def __iter__(self): + return iter(self.inner) + + def __len__(self): + return len(self.inner) + + # Subtraction along with bool is typically used to determine the delta of + # added guards between checkpoints for higher order ops + def __sub__(self, other): + return GuardsSet(self.inner - other.inner) + + def __bool__(self): + return bool(self.inner) + + def add(self, guard: Guard, *, collect_debug_stack=True, skip=0): + if guard in self.inner: + return + if collect_debug_stack: + if guard.stack is None: + guard.stack = CapturedTraceback.extract(skip=1 + skip) + if guard.user_stack is None: + guard.user_stack = TracingContext.extract_stack() + self.inner.add(guard) + + def update(self, *others: set[Guard]): + for o in others: + for g in o: + self.add(g, skip=1) + + def remove_guards_with_source(self, source): + """Delete all guards that contains a given source""" + from ._dynamo.source import is_from_source + + self.inner = { + g for g in self.inner if not is_from_source(g.originating_source, source) + } + + +""" +A GuardsContext is a checkpointable representation of all the guards in the current tracing +context. It's lifecycle is bound 1:1 to the tracing context, and it should never be instantiated +directly outside of it. 
For passing around internal state representations of this object, +prefer to extract them with copy_graphstate to produce a GuardsCheckpointState. +""" + + +class GuardsContext(Checkpointable[GuardsCheckpointState]): + def __init__(self) -> None: + self.dynamo_guards: GuardsSet = GuardsSet() + self.aotautograd_guards: list[GuardEnvExpr] = [] + + def copy_graphstate(self): + return GuardsCheckpointState(set(self.dynamo_guards.inner)) + + def restore_graphstate(self, state): + # NB: "steals" the passed in state + assert isinstance(state, GuardsCheckpointState) + self.dynamo_guards = GuardsSet(state.dynamo_guards) + + +class HopSubgraphCache: + @abstractmethod + def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): ... + + @abstractmethod + def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... + + @abstractmethod + def add_autograd_key_entry(self, identifier: str, key: Callable): ... + + @abstractmethod + def get_autograd_key_entry(self, identifier: str): ... + + @abstractmethod + def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ... + + @abstractmethod + def get_proxy_dispatch_entry(self, identifier: str): ... + + @abstractmethod + def add_lazy_bwd_entry( + self, + identifier: str, + tangent_metadata: tuple[object], + gmod: torch.fx.GraphModule, + ): ... + + @abstractmethod + def get_lazy_bwd_entry( + self, identifier: str, tangent_metadata: tuple[object] + ) -> int: ... + + +class InvokeSubgraphCache(HopSubgraphCache): + def __init__(self) -> None: + self.autograd_cache: dict[str, Callable] = {} + self.proxy_dispatch_cache: dict[str, Callable] = {} + self.dynamo_installed_submodules: dict[int, list[str]] = defaultdict(list) + self.lazy_bwd_cache: dict[ + str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] + ] = defaultdict(dict) + + def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): + self.dynamo_installed_submodules[fn_id].append(identifier) + + def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: + return self.dynamo_installed_submodules.get(fn_id, []) + + def add_autograd_key_entry(self, identifier: str, key: Callable): + self.autograd_cache[identifier] = key + + def get_autograd_key_entry(self, identifier: str): + return self.autograd_cache.get(identifier, None) + + def add_proxy_dispatch_entry(self, identifier: str, key: Callable): + self.proxy_dispatch_cache[identifier] = key + + def get_proxy_dispatch_entry(self, identifier: str): + return self.proxy_dispatch_cache.get(identifier, None) + + def add_lazy_bwd_entry( + self, + identifier: str, + tangent_metadata: tuple[object], + gmod: torch.fx.GraphModule, + ): + # Save the number of existing graph modules in the dictionary to get the suffix + num_gmods = len(self.lazy_bwd_cache[identifier]) + self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods) + return num_gmods + + def get_lazy_bwd_entry(self, identifier: str, tangent_metadata: tuple[object]): + if identifier not in self.lazy_bwd_cache: + return (None, None) + + return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None)) + + +class HopDispatchSetCache: + def __init__(self) -> None: + # Delayed import to avoid circular dependency + from torch._higher_order_ops.invoke_subgraph import invoke_subgraph + + self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()} + + def get_cache( + self, op: torch._ops.HigherOrderOperator + ) -> Optional[HopSubgraphCache]: + if op not in self.hop_cache_map: + return None + return self.hop_cache_map[op] # type: ignore[index] 
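+
+# A minimal usage sketch (informal; `my_subgraph_fn` and `autograd_key` are
+# placeholder names):
+#
+#   from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
+#   cache = HopDispatchSetCache()
+#   sub = cache.get_cache(invoke_subgraph)  # InvokeSubgraphCache, or None
+#   if sub is not None:
+#       sub.add_dynamo_installed_submodule(id(my_subgraph_fn), "subgraph_0")
+#       sub.add_autograd_key_entry("subgraph_0", autograd_key)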
+ + +_TLS = threading.local() + +""" +TracingContext is the source of truth for all currently accumulated information +needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems +are open to managing their own TracingContext with that in mind. + +The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid +having to plumb complex subsystems across multiple verticals. + +Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor. +Accessing the current tracing context via +TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how +to plumb objects back up to where frame interpretation happened. + +Note that you can end up with multiple TracingContext for a single compilation +of a frame, as we reset the TracingContext whenever we restart analysis. +CompileContext is a more overarching context that encompasses multiple restarts. +""" + + +class CompileContext: + @staticmethod + def get() -> CompileContext: + assert _TLS.compile_context is not None + return _TLS.compile_context + + @staticmethod + def try_get() -> Optional[CompileContext]: + return getattr(_TLS, "compile_context", None) + + def __init__(self, compile_id): + assert compile_id is None or isinstance(compile_id, CompileId) + self.compile_id: Optional[CompileId] = compile_id + self.attempt = 0 + # Verbose ShapeEnv guards produced. + self.shape_env_guards: list[str] = [] + + @staticmethod + def current_compile_id(): + self = CompileContext.try_get() + if self is None: + return None + return self.compile_id + + @staticmethod + def current_trace_id(): + self = CompileContext.try_get() + if self is None: + return None + if self.compile_id is None: + return None + return TraceId(self.compile_id, self.attempt) + + +class TracingContext: + """ + Provides the currently installed TracingContext, or None. + + Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but + will return None. + """ + + @staticmethod + def try_get() -> Optional[TracingContext]: + return getattr(_TLS, "tracing_context", None) + + @staticmethod + def get() -> TracingContext: + if ctx := TracingContext.try_get(): + return ctx + raise RuntimeError( + "TracingContext.get() must be called within an ongoing trace." + ) + + def __init__(self, fake_mode): + self.guards_context = GuardsContext() + self.module_context = ModuleContext() + self.global_context = GlobalContext() + self.previously_inlined_functions = dict() + self.previously_cleaned_instructions = dict() + self.fake_mode = fake_mode + self.frame_summary_stack = [] + # This is morally part of frame_summary_stack, but it is kept separate + # for clarity. As we process a frame, this variable gets updated + # to keep track of what line we are in the function. We make a + # function call, this gets cleared and the frame location is pushed + # to frame_summary_stack (prepping this variable for the inner frame's + # progress) + self.loc_in_frame = None + # this is only set after aot_autograd + self.fw_metadata = None + # this is only set after aot_autograd + self.aot_graph_name = None + self.params_flat = None + self.params_flat_unwrap_subclasses = None + self.params_unwrapped_to_flat_index = None + # this is for extended return calling convention from backend + # compiler to aot_autograd + # Per output, what the compiler specified stride of the output is, + # or None if no stride is known. 
This is always the HINT, it + # is never a SymInt (it would be better if it was a SymInt, but + # I can't conveniently get this from Inductor atm. Also, be + # careful not to accidentally induce guards on the SymInt if + # you ever do change this in aot_autograd.py; you should check + # on permutations preferentially.) + self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None + # When this is True, whenever we encounter an int in Dynamo tracing, + # we will (1) force unspec it and (2) force it as a size-like unbacked + # integer. This is currently used when processing certain lists of + # ints that are known to be size-like and may have 0/1 entries that we + # must not specialize on. + self.force_unspec_int_unbacked_size_like = False + # See note [Tensor Fakification and Symbol Caching] + self.tensor_to_context = WeakTensorKeyDictionary() + + # If this true, Aot Autograd will return output Fake Tensors with appropiate + # meta on the first invocation + # see note: [Returning Fake Tensors on First AOT Autograd Call] + self.fakify_first_call = False + self.hop_dispatch_set_cache = HopDispatchSetCache() + # list of code objects for inlined functions + self.traced_code: list[CodeType] = [] + + def clear(self): + # Look at the note in output_graph.py in function `save_global_state` + # for the context on clearing global context. + self.global_context.global_state = {} + self.previously_inlined_functions.clear() + self.previously_cleaned_instructions.clear() + + @staticmethod + @contextmanager + def patch(**kwargs): + prior = {} + ctx = TracingContext.get() + + for key in kwargs.keys(): + # KeyError on invalid entry + prior[key] = getattr(ctx, key) + for key, val in kwargs.items(): + setattr(ctx, key, val) + try: + yield + finally: + for key, val in prior.items(): + setattr(ctx, key, val) + + @staticmethod + def extract_stack(): + self = TracingContext.try_get() + if self is None: + return traceback.StackSummary() + stack = self.frame_summary_stack + if self.loc_in_frame is not None: + stack = stack + [self._populate_loc_in_frame_summary()] + return traceback.StackSummary.from_list(stack) + + def _populate_loc_in_frame_summary(self): + assert self.loc_in_frame is not None + filename, lineno, frame_name = self.loc_in_frame + return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False) + + # Call this when you want to call into some code that isn't necessarily + # associated with the current frame state + @staticmethod + @contextlib.contextmanager + def clear_frame(): + tc = TracingContext.get() + with ( + unittest.mock.patch.object(tc, "frame_summary_stack", []), + unittest.mock.patch.object(tc, "loc_in_frame", None), + ): + try: + yield + except Exception as e: + # Prevent real_stack from getting attached + # + # The invariant is that if an Exception as real_stack, we've + # appropriately attached a user stack and we no longer need to + # attach anything. Because we cannot conveniently interpose + # when an exception is thrown, we instead interpose everywhere + # we set what the user stack is set (using the context + # manager). However, our compiler stack does "tail calls" + # (when it calls into user compiler), at which point the + # parent exception frames would incorrectly attach an + # incorrect frame. + # + # However, if, somehow, someone raised an exception with this + # scope that had a stack (for example, because they are + # restoring the user stack state appropriately as they process + # node by node), we should respect it. 
Thus, we cannot + # unconditionally set None. + if not hasattr(e, "real_stack"): + e.real_stack = None # type: ignore[attr-defined] + raise + + @staticmethod + @contextlib.contextmanager + def current_frame(frame_summary): + # frame_summary can be None to solely take advantage of real_stack + # attachment to thrown exceptions + tc = TracingContext.get() + if frame_summary is not None: + tc.frame_summary_stack.append(frame_summary) + old = tc.loc_in_frame + tc.loc_in_frame = None + try: + yield + except Exception as e: + if not hasattr(e, "real_stack"): + e.real_stack = tc.extract_stack() # type: ignore[attr-defined] + raise + finally: + if frame_summary is not None: + tc.frame_summary_stack.pop() + tc.loc_in_frame = old + + @staticmethod + @contextlib.contextmanager + def report_output_strides(): + tc = TracingContext.try_get() + if tc is None: + yield None + return + old_output_strides = tc.output_strides + tc.output_strides = [] + try: + yield tc.output_strides + finally: + tc.output_strides = old_output_strides + + @staticmethod + def set_current_loc(filename, lineno, frame_name): + # Save the current location in the frame. Lazily generate the + # framesummary. + TracingContext.get().loc_in_frame = (filename, lineno, frame_name) + + @staticmethod + def get_traced_code(): + tc = TracingContext.try_get() + if tc is None: + return None + return tc.traced_code + + +@contextmanager +def compile_context(context: Optional[CompileContext]): + old_context = getattr(_TLS, "compile_context", None) + _TLS.compile_context = context + try: + yield context + finally: + _TLS.compile_context = old_context + + +@contextmanager +def tracing(context: Optional[TracingContext]): + """ + This function installs the passed in tracing context as a dynamic scoped + global variable. + + Calls to TracingContext.get() while not under a `with tracing()` context + will return None. 
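    A minimal usage sketch (illustrative only; ``fake_mode`` stands in for a
    FakeTensorMode instance or None, and no outer tracing context is assumed
    to be installed)::

        ctx = TracingContext(fake_mode)
        with tracing(ctx):
            assert TracingContext.try_get() is ctx   # discoverable globally
        assert TracingContext.try_get() is None      # uninstalled again on exit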
+ """ + old_context = getattr(_TLS, "tracing_context", None) + _TLS.tracing_context = context + try: + yield context + except Exception as e: + if not hasattr(e, "real_stack") and context is not None: + e.real_stack = context.extract_stack() # type: ignore[attr-defined] + raise + finally: + if ( + context is not None + and context.fake_mode is not None + and context.fake_mode.shape_env is not None + ): + context.fake_mode.shape_env.cleanup() + _TLS.tracing_context = old_context + + +# Subclasses can be found in torch/_dynamo/source.py +# TODO(voz): Consider a toplevel torch/_source.py +@dataclasses.dataclass(frozen=True) +class Source: + def is_dict_key(self): + return False + + def is_ephemeral(self): + return False + + def reconstruct(self, codegen): + raise NotImplementedError + + def guard_source(self) -> GuardSource: + raise NotImplementedError + + def name(self) -> str: + raise NotImplementedError + + def make_guard(self, fn) -> Guard: + if self.guard_source() is GuardSource.CONSTANT: + raise NotImplementedError + return Guard(self, fn) + + def is_specialized_nn_module(self) -> bool: + return self.guard_source().is_specialized_nn_module() + + def subguards_allowed(self): + """True if you can guard on attributes of this""" + return self.guard_source() != GuardSource.SYNTHETIC_LOCAL + + +# Subclasses can be found in torch/_dynamo/source.py +@dataclasses.dataclass(frozen=True) +class ChainedSource(Source): + base: Source + + def is_dict_key(self): + # Recurse until you either hit a ConstDictKey or a Source + return self.base.is_dict_key() + + def is_ephemeral(self): + return self.base.is_ephemeral() + + def get_base(self) -> Source: + current: Source = self + while isinstance(current, ChainedSource): + current = current.base + return current + + +def detect_fake_mode(inputs: Any = None): + """ + Attempts to "detect" what the current fake mode is. If there is one ambiently + available from TracingContext, we preferentially use that. Otherwise, we + heuristically detect the fake mode via the following sources, in order of + priority: + + - Currently active fake mode on stack + - Fake mode associated with passed in tensors (inputs does not + have to be flattened) + """ + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + + fake_modes = [] + + if context := TracingContext.try_get(): + fake_mode = context.fake_mode + if fake_mode is not None: + fake_modes.append((fake_mode, "tracing context", 0)) + + from torch.utils._python_dispatch import _get_current_dispatch_mode_stack + + for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())): + if isinstance(m, FakeTensorMode): + fake_modes.append((m, "active fake mode", i)) + + flat_inputs = pytree.tree_leaves(inputs) + for i, flat_input in enumerate(flat_inputs): + if isinstance(flat_input, FakeTensor): + fake_modes.append((flat_input.fake_mode, "fake tensor input", i)) + + if fake_modes: + fake_mode, desc1, i1 = fake_modes[0] + for m, desc2, i2 in fake_modes[1:]: + assert fake_mode is m, ( + f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n" + f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n" + f"fake mode from {desc2} {i2} allocated at:\n{m.stack}" + ) + return fake_mode + else: + return None + + +def active_fake_mode(): + """ + Inspects the dispatch mode stack for an active fake mode and returns it. + Returns None if no fake mode is active. 
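    A minimal illustration (assuming no other fake mode is already active)::

        from torch._subclasses.fake_tensor import FakeTensorMode

        with FakeTensorMode() as mode:
            assert active_fake_mode() is mode   # mode sits on the dispatch stack
        assert active_fake_mode() is None       # popped again on exit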
+ """ + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils._python_dispatch import _get_current_dispatch_mode_stack + + for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())): + if isinstance(m, FakeTensorMode): + return m + + return None diff --git a/phivenv/Lib/site-packages/torch/_jit_internal.py b/phivenv/Lib/site-packages/torch/_jit_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..3e06e387c9649ea9c1c183d2b9de825c65023dd2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_jit_internal.py @@ -0,0 +1,1552 @@ +# mypy: allow-untyped-defs +""" +The weak_script annotation needs to be here instead of inside torch/jit/ so it +can be used in other places in torch/ (namely torch.nn) without running into +circular dependency problems +""" + +import ast +import builtins +import collections +import contextlib +import enum +import inspect +import io +import pickle +import sys +import textwrap +import threading +import types +import typing +import warnings +import weakref +from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations + Any, + Callable, + Dict, + Final, + ForwardRef, + get_args, + get_origin, + List, + Optional, + Tuple, + TypeVar, + Union, +) +from typing_extensions import ParamSpec + +import torch + +# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. +# Explicitly ask to import `torch.distributed.__init__` first. +# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. +import torch.distributed.rpc +import torch.package._mangling as package_mangling +from torch._awaits import _Await +from torch._C import _Await as CAwait, Future as CFuture +from torch._sources import fake_range, get_source_lines_and_file, parse_def +from torch.futures import Future + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10) + +BuiltinUnionType: Union[type, tuple[type, ...]] +if sys.version_info >= (3, 10): + # NOTE: IS_PY310_PLUS doesn't work with mypy. + # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks + BuiltinUnionType = types.UnionType +else: + BuiltinUnionType = () # trick: this makes isinstance short circuit. + +LockType: type +try: + import _thread + + LockType = _thread.LockType +except ImportError: + import _dummy_thread # type: ignore[import-not-found] + + LockType = _dummy_thread.LockType + +# Wrapper functions that can call either of 2 functions depending on a boolean +# argument +boolean_dispatched: "weakref.WeakKeyDictionary[Callable, dict[str, Callable]]" = ( + weakref.WeakKeyDictionary() +) # noqa: T484 + + +FAKE_FILENAME_PREFIX = "__torch_jit_dataclass" + + +def is_final(ann) -> bool: + return ( + hasattr(ann, "__module__") + and ann.__module__ in {"typing", "typing_extensions"} + and (get_origin(ann) is Final or isinstance(ann, type(Final))) + ) + + +# allows BroadcastingList instance to be subscriptable +class BroadcastingListCls: + def __getitem__(self, types): + return + + +# mypy doesn't support parameters on types, so we have to explicitly type each +# list size +BroadcastingList1 = BroadcastingListCls() +for i in range(2, 7): + globals()[f"BroadcastingList{i}"] = BroadcastingList1 + + +def is_scripting() -> bool: + r""" + Function that returns True when in compilation and False otherwise. 
This + is useful especially with the @unused decorator to leave code in your + model that is not yet TorchScript compatible. + .. testcode:: + + import torch + + @torch.jit.unused + def unsupported_linear_op(x): + return x + + def linear(x): + if torch.jit.is_scripting(): + return torch.linear(x) + else: + return unsupported_linear_op(x) + """ + return False + + +# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. +def _qualified_name(obj, mangle_name=True) -> str: + # This special case allows us to override the qualified name on a type. + # It's currently used in conjunction with tracing, where we create a + # fake module to filter only supported attributes. However, since this + # new type is defined as a local class, we need a mechanism to override + # its qualname so it appears correctly in the TorchScript system. This, + # we set '_jit_override_qualname' with the original traced module's + # qualified name, which is picked up here + if hasattr(obj, "_jit_override_qualname"): + return obj._jit_override_qualname + # short-circuit in cases where the object already has a known qualified name + if isinstance(obj, torch._C.ScriptFunction): + return obj.qualified_name + + if getattr(obj, "__name__", None): + name = obj.__name__ + # Enum classes do not have `__name__` attr, instead they have `name`. + elif isinstance(obj, enum.Enum): + name = obj.name + else: + raise RuntimeError("Could not get name of python class object") + + if name == "": + name = "_lambda" # make name a valid identifier + + module_name = obj.__module__ + + # If the module is actually a torchbind module, then we should short circuit + if module_name == "torch._classes": + return obj.qualified_name + + # The Python docs are very clear that `__module__` can be None, but I can't + # figure out when it actually would be. + if module_name is None: + raise RuntimeError( + f"Could not get qualified name for class '{name}': " + "__module__ can't be None." + ) + + # if getattr(sys.modules[module_name], name) is not obj: + # raise RuntimeError(f"Could not get qualified name for class '{name}': " + # f"the attr {name} on module {module_name} is not the class") + + # torch.package and TorchScript have separate mangling schemes to avoid + # name collisions from multiple packages. To avoid them interfering with + # each other, normalize the package manging here. + if package_mangling.is_mangled(module_name): + module_name = module_name.replace("<", "_") + module_name = module_name.replace(">", "_") + + # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h + # does not need mangle the python class name. + if mangle_name: + # __main__ is a builtin module, so rewrite it to "__torch__". + if module_name == "__main__": + module_name = "__torch__" + else: + # Everything else gets a "__torch__" prefix to avoid name collisions + # with the names of user values. + module_name = "__torch__." + module_name + + if "." in name: + raise RuntimeError( + f"Could not get qualified name for class '{name}': " + f"'{name}' is not a valid identifier" + ) + + return module_name + "." 
+ name + + +class SourceLoader: + def __init__(self): + self.content = {} + + def cache(self, fn, source): + self.content[fn] = source + + def get_source(self, fn): + return self.content.get(fn) + + +loader = SourceLoader() + + +def createResolutionCallbackFromEnv(lookup_base): + """ + Creates a resolution callback that will look up qualified names in an + environment, starting with `lookup_base` for the base of any qualified + names, then proceeding down the lookup chain with the resolved object. + + You should not use this directly, it should only be used from the other + createResolutionCallbackFrom* functions. + """ + + def lookupInModule(qualified_name, module): + if "." in qualified_name: + base, remaining_pieces = qualified_name.split(".", maxsplit=1) + module_value = getattr(module, base) + return lookupInModule(remaining_pieces, module_value) + else: + return getattr(module, qualified_name) + + def parseNestedExpr(expr, module) -> tuple[Any, int]: + i = 0 + while i < len(expr) and expr[i] not in (",", "[", "]"): + i += 1 + + # Special case logic for the empty Tuple as a subscript (used + # in the type annotation `Tuple[()]`) + if expr[:i] == "()": + return (), i + + base = lookupInModule(expr[:i].strip(), module) + assert base is not None, f"Unresolvable type {expr[:i]}" + if i == len(expr) or expr[i] != "[": + return base, i + + assert expr[i] == "[" + parts = [] + while expr[i] != "]": + part_len = 0 + i += 1 + part, part_len = parseNestedExpr(expr[i:], module) + parts.append(part) + i += part_len + if len(parts) > 1: + return base[tuple(parts)], i + 1 + else: + return base[parts[0]], i + 1 + + def parseExpr(expr, module): + try: + value, len_parsed = parseNestedExpr(expr, module) + assert len_parsed == len(expr), ( + "whole expression was not parsed, falling back to c++ parser" + ) + return value + except Exception: + """ + The python resolver fails in several cases in known unit tests, and is intended + to fall back gracefully to the c++ resolver in general. For example, python 2 style + annotations which are frequent in our unit tests often fail with types e.g. int not + resolvable from the calling frame. + """ + return None + + return lambda expr: parseExpr(expr, lookup_base) + + +def createResolutionCallbackFromFrame(frames_up: int = 0): + """ + Creates a function which, given a string variable name, + returns the value of the variable in the scope of the caller of + the function which called createResolutionCallbackFromFrame (by default). + + This is used to enable access in-scope Python variables inside + TorchScript fragments. + + frames_up is number of additional frames to go up on the stack. + The default value is 0, which correspond to the frame of the caller + of createResolutionCallbackFromFrame. Also for example, if frames_up is set + to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame + will be taken. 
+ + For example, the following program prints 2:: + + def bar(): + cb = createResolutionCallbackFromFrame(1) + print(cb("foo")) + + + def baz(): + foo = 2 + bar() + + + baz() + """ + frame = inspect.currentframe() + i = 0 + while i < frames_up + 1: + assert frame is not None + frame = frame.f_back + i += 1 + + assert frame is not None + f_locals = frame.f_locals + f_globals = frame.f_globals + + class env: + def __getattr__(self, key): + if key in f_locals: + return f_locals[key] + elif key in f_globals: + return f_globals[key] + elif key in dir(builtins): + return getattr(builtins, key) + + return createResolutionCallbackFromEnv(env()) + + +def get_closure(fn): + """ + Get a dictionary of closed over variables from a function + """ + captures = {} + captures.update(fn.__globals__) + + for index, captured_name in enumerate(fn.__code__.co_freevars): + captures[captured_name] = fn.__closure__[index].cell_contents + + return captures + + +# [local resolution in python] +# Depending on where a variable is defined, and where it is used, we may +# or may not be able to recover its value when recursively compiling a +# script function. Remember in the general case, a module or function is +# first defined and then later scripted. This means we do not have a +# chance to capture the active frames when the function is defined. Hence any +# name resolution has to happen later on the created closure. The way +# python captures type annotations restricts what we can recover. The +# follow example illustrates the different cases: +# +# class MyGlobalClass: +# ... +# def my_local_scope(): +# @torch.jit.script +# class MyClass: +# ... +# @torch.jit.script +# class MyClassUsedAsVar: +# ... +# def eg(x: MyClass, y: MyGlobalClass): +# a_local_capture : Foo +# return MyClassUsedAsVar(x) +# +# MyGlobalClass is defined in the __globals__ dictionary of function +# 'eg', so it is always recoverable. my_local_scope introduces a new local +# variable scope in the function. Classes defined here are only visible as +# local variables. For the case of MyClassUsedAsVar, it is captured +# because it is used as a variable inside the body of the function, and we +# can resolve it using the captures returned from `get_closure`. However, +# the type annotations are not captured by the closure. In Python +# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as +# annotations on `eg``, but starting in Python 4.0, they will represented as +# strings and no longer present. Furthermore, since the body of `eg` does +# not reference those names, they do not appear in the list of closed over +# variables. In Python 2.x, type annotations are in comments, leading to a +# similar situation where their definitions are not available. We anticipate +# that most users will not run into this issue because their modules and +# functions will be defined at a global scope like MyGlobalClass. In cases +# where they are not, it is possible to work around issues by declaring the +# values global in the function. +# In Python 3.9 declaring class as global will make it invisible to +# `inspect.getsource`, see https://bugs.python.org/issue42666 . +# This could be worked around by manualy adding it to `global()` dictionary. 
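# [illustrative sketch, not part of the original module]
# How `get_closure` recovers values for the resolution callback built below:
# for a hypothetical nested function
#
#   def outer():
#       scale = 3
#       def inner(x):
#           return x * scale
#       return inner
#
# `get_closure(outer())` returns a dict containing the module globals plus
# {"scale": 3}, read from `inner.__closure__` via `inner.__code__.co_freevars`.
# Type annotations on `inner` are *not* part of that closure; recovering them
# is the job of `get_type_hint_captures` further below.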
+ + +def createResolutionCallbackFromClosure(fn): + """ + Create a resolutionCallback by introspecting the function instead of + looking up the stack for the enclosing scope + """ + closure = get_closure(fn) + + class closure_lookup: + # This is a class since `closure` is a dict and it's easier in + # `env_helper` if everything just works with `getattr` calls + def __getattr__(self, key): + if key in closure: + return closure[key] + elif hasattr(typing, key): + return getattr(typing, key) + elif hasattr(builtins, key): + return getattr(builtins, key) + return None + + return createResolutionCallbackFromEnv(closure_lookup()) + + +def can_compile_class(cls) -> bool: + # If any of the functions on a type don't have a code object, this type can't + # be compiled and is probably a builtin / bound from C + if is_ignored_fn(cls): + return False + + # Ignore the following list of built-in classes. + ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception) + if issubclass(cls, ignored_builtin_classes): + return False + + names = cls.__dict__ + fns = [ + getattr(cls, name) + for name in names + if inspect.isroutine(getattr(cls, name, None)) + ] + has_code = [hasattr(fn, "__code__") for fn in fns] + return all(has_code) + + +def get_callable_argument_names(fn) -> list[str]: + """ + Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`. + Returns an empty list when other types of arguments are present. + + This is used by `torch.jit.trace` to assign meaningful argument names to + traced functions and modules. + + Args: + fn: A callable. + Returns: + Argument names: List[str] + """ + # inspect.signature may fail, give up in that case. + try: + callable_signature = inspect.signature(fn) + except Exception: + return [] + + argument_names = [] + for name, param in callable_signature.parameters.items(): + # All four other types of arguments do not map to individual values + # with a keyword as name. + if not param.kind == param.POSITIONAL_OR_KEYWORD: + continue + + argument_names.append(name) + + return argument_names + + +def get_annotation_str(annotation): + """ + Convert an AST node containing a type annotation to the string present in the source + that represents the same annotation. + """ + if isinstance(annotation, ast.Name): + return annotation.id + elif isinstance(annotation, ast.Attribute): + return ".".join([get_annotation_str(annotation.value), annotation.attr]) + elif isinstance(annotation, ast.Subscript): + # In Python3.9+ subscript indicies are not wrapped in ast.Index + subscript_slice = annotation.slice + return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]" + elif isinstance(annotation, ast.Tuple): + return ",".join([get_annotation_str(elt) for elt in annotation.elts]) + elif isinstance(annotation, ast.Constant): + return f"{annotation.value}" + + # If an AST node is not handled here, it's probably handled in ScriptTypeParser. + return None + + +def get_type_hint_captures(fn): + """ + Get a dictionary containing type resolution mappings necessary to resolve types + for the literal annotations on 'fn'. These are not considered to be closed-over by fn + and must be obtained separately (e.g. using this function). + + Args: + fn: A callable. + Returns: + A Dict[str, Any] containing a mapping from the literal annotations used on + fn to the Python objects they refer to. + """ + # First, try to get the source of the function. 
We'll need to parse it to find the actual string names + # that were used to annotate the types, since inspect.signature() will only return the class object that + # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict. + # This may happen in cases where the function is synthesized dynamically at runtime. + src = loader.get_source(fn) + if src is None: + try: + src = inspect.getsource(fn) + except OSError as e: + raise OSError( + f"Failed to get source for {fn} using inspect.getsource" + ) from e + + # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated + # types are strings. These are only understood by TorchScript in the context of a type annotation + # that refers to a class in its own definition, but trying to include a mapping for this in the result + # function would cause infinite recursion because the class is currently being compiled. + # In addition, there is logic in ScriptTypeParser to handle this. + signature = inspect.signature(fn) + name_to_type = { + name: parameter.annotation + for name, parameter in signature.parameters.items() + if parameter.annotation is not inspect.Parameter.empty + and not isinstance(parameter.annotation, str) + } + + # Then, get the literal type annotations from the function declaration + # by source inspection. This accounts for the case in which aliases are used + # to annotate the arguments (e.g device_t = torch.device, and then d: device_t). + # frontend.py cannot be used here because it includes _jit_internal, so use ast instead. + a = ast.parse(textwrap.dedent(src)) + if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef): + raise RuntimeError(f"Expected {fn} to be a function") + f = a.body[0] + + # Prepare a dictionary of source annotation -> type, which will be the final result of this function, + # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping + # them to the type object corresponding to the annotation via name_to_type using the parameter name. + annotation_to_type = {} + + for arg in f.args.args: + # Get the source type annotation string for this argument if possible. + arg_annotation_str = ( + get_annotation_str(arg.annotation) if arg.annotation else None + ) + + # If the argument has no annotation or get_annotation_str cannot convert it to a string, + # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle + # this in the latter case. + if arg_annotation_str is None: + continue + + # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not + # be present in name_to_type is that the annotation itself is a string and not a type object + # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this. + arg_name = arg.arg + if arg_name in name_to_type: + annotation_to_type[arg_annotation_str] = name_to_type[arg_name] + + # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations, + # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type + # of the annotation cannot be a string. 
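    # Illustrative sketch (hypothetical alias and function, not from the
    # original source): given
    #     device_t = torch.device
    #     def f(d: device_t) -> torch.Tensor: ...
    # the loop above records {"device_t": torch.device}, and the block below
    # then adds {"torch.Tensor": torch.Tensor} for the literal return annotation.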
+ literal_return_annotation = get_annotation_str(f.returns) + valid_literal_annotation = literal_return_annotation is not None + return_annotation = signature.return_annotation + valid_return_annotation_type = ( + return_annotation is not inspect.Parameter.empty + and not isinstance(return_annotation, str) + ) + if valid_literal_annotation and valid_return_annotation_type: + annotation_to_type[literal_return_annotation] = return_annotation + + return annotation_to_type + + +def createResolutionCallbackForClassMethods(cls): + """ + This looks at all the methods defined in a class and pulls their closed-over + variables into a dictionary and uses that to resolve variables. + """ + # cls is a type here, so `ismethod` is false since the methods on the type + # aren't bound to anything, so Python treats them as regular functions + fns = [ + getattr(cls, name) + for name in cls.__dict__ + if inspect.isroutine(getattr(cls, name)) + ] + # Skip built-ins, as they do not have global scope nor type hints + # Needed to support `enum.Enum` derived classes in Python-3.11 + # That adds `_new_member_` property which is an alias to `__new__` + fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")] + captures = {} + + for fn in fns: + captures.update(get_closure(fn)) + captures.update(get_type_hint_captures(fn)) + + def lookup_in_class(key): + if key in captures: + return captures[key] + else: + return getattr(builtins, key, None) + + return lookup_in_class + + +def boolean_dispatch( + arg_name, + arg_index, + default, + if_true, + if_false, + module_name, + func_name, +): + """ + Dispatches to either of 2 script functions based on a boolean argument. + In TorchScript, the boolean argument must be constant so that the correct + function to use can be determined at compile time. + """ + + def fn(*args, **kwargs): + dispatch_flag = default + if arg_name in kwargs: + dispatch_flag = kwargs[arg_name] + elif arg_index < len(args): + dispatch_flag = args[arg_index] + + if dispatch_flag: + return if_true(*args, **kwargs) + else: + return if_false(*args, **kwargs) + + if if_true.__doc__ is None and if_false.__doc__ is not None: + doc = if_false.__doc__ + if_true.__doc__ = doc + elif if_false.__doc__ is None and if_true.__doc__ is not None: + doc = if_true.__doc__ + if_false.__doc__ = doc + elif if_false.__doc__ is None and if_true.__doc__ is None: + # neither function has a docstring + doc = None + else: + raise RuntimeError("only one function can have a docstring") + fn.__doc__ = doc + + if module_name is not None: + fn.__module__ = module_name + if func_name is not None: + fn.__name__ = func_name + + boolean_dispatched[fn] = { + "if_true": if_true, + "if_false": if_false, + "index": arg_index, + "default": default, + "arg_name": arg_name, + } + return fn + + +class FunctionModifiers: + """ + Used to denote the behavior of a function in TorchScript. See export() and + ignore() for details. 
+ """ + + UNUSED = "unused (ignored and replaced with raising of an exception)" + IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" + EXPORT = "export (compile this function even if nothing calls it)" + DEFAULT = "default (compile if called from a exported function / forward)" + COPY_TO_SCRIPT_WRAPPER = ( + "if this method is not scripted, copy the python method onto the scripted model" + ) + _DROP = "_drop (function is fully ignored, declaration can be unscriptable)" + + +def export(fn: Callable[_P, _R]) -> Callable[_P, _R]: + """ + This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a + :class:`ScriptModule` and should be compiled. + + ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator. + Functions and methods called from ``forward`` are compiled as they are seen + by the compiler, so they do not need this decorator either. + + Example (using ``@torch.jit.export`` on a method): + + .. testcode:: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + def implicitly_compiled_method(self, x): + return x + 99 + + # `forward` is implicitly decorated with `@torch.jit.export`, + # so adding it here would have no effect + def forward(self, x): + return x + 10 + + @torch.jit.export + def another_forward(self, x): + # When the compiler sees this call, it will compile + # `implicitly_compiled_method` + return self.implicitly_compiled_method(x) + + def unused_method(self, x): + return x - 20 + + # `m` will contain compiled methods: + # `forward` + # `another_forward` + # `implicitly_compiled_method` + # `unused_method` will not be compiled since it was not called from + # any compiled methods and wasn't decorated with `@torch.jit.export` + m = torch.jit.script(MyModule()) + """ + fn._torchscript_modifier = FunctionModifiers.EXPORT # type:ignore[attr-defined] + return fn + + +def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]: + """ + This decorator indicates to the compiler that a function or method should + be ignored and replaced with the raising of an exception. This allows you + to leave code in your model that is not yet TorchScript compatible and still + export your model. 
+ + Example (using ``@torch.jit.unused`` on a method):: + + import torch + import torch.nn as nn + + + class MyModule(nn.Module): + def __init__(self, use_memory_efficient): + super().__init__() + self.use_memory_efficient = use_memory_efficient + + @torch.jit.unused + def memory_efficient(self, x): + import pdb + + pdb.set_trace() + return x + 10 + + def forward(self, x): + # Use not-yet-scriptable memory efficient mode + if self.use_memory_efficient: + return self.memory_efficient(x) + else: + return x + 10 + + + m = torch.jit.script(MyModule(use_memory_efficient=False)) + m.save("m.pt") + + m = torch.jit.script(MyModule(use_memory_efficient=True)) + # exception raised + m(torch.rand(100)) + """ + if isinstance(fn, property): + prop = fn + setattr( # noqa: B010 + prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED + ) + + if prop.fset: + setattr( # noqa: B010 + prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED + ) + + return prop + + fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined] + return fn + + +# No op context manager from python side +class _IgnoreContextManager(contextlib.AbstractContextManager): + def __init__(self, **kwargs): + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + pass + + +def ignore(drop=False, **kwargs): + """ + This decorator indicates to the compiler that a function or method should + be ignored and left as a Python function. This allows you to leave code in + your model that is not yet TorchScript compatible. If called from TorchScript, + ignored functions will dispatch the call to the Python interpreter. Models with ignored + functions cannot be exported; use :func:`@torch.jit.unused ` instead. + + Example (using ``@torch.jit.ignore`` on a method):: + + import torch + import torch.nn as nn + + + class MyModule(nn.Module): + @torch.jit.ignore + def debugger(self, x): + import pdb + + pdb.set_trace() + + def forward(self, x): + x += 10 + # The compiler would normally try to compile `debugger`, + # but since it is `@ignore`d, it will be left as a call + # to Python + self.debugger(x) + return x + + + m = torch.jit.script(MyModule()) + + # Error! The call `debugger` cannot be saved since it calls into Python + m.save("m.pt") + + Example (using ``@torch.jit.ignore(drop=True)`` on a method): + + .. testcode:: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + @torch.jit.ignore(drop=True) + def training_method(self, x): + import pdb + pdb.set_trace() + + def forward(self, x): + if self.training: + self.training_method(x) + return x + + m = torch.jit.script(MyModule()) + + # This is OK since `training_method` is not saved, the call is replaced + # with a `raise`. + m.save("m.pt") + + .. testcleanup:: + + import os + os.remove('m.pt') + """ + + if callable(drop): + # used without any args, so drop is actually a function + # @torch.jit.ignore + # def fn(...): + fn = drop + fn._torchscript_modifier = FunctionModifiers.IGNORE + return fn + + if not isinstance(drop, bool): + raise RuntimeError( + f"Argument to @torch.jit.ignore must be a bool or a function but got {drop}" + ) + + # for backwards compat + drop_on_export = kwargs.pop("drop_on_export", None) + if drop_on_export: + warnings.warn( + "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", + category=FutureWarning, + ) + + drop = drop_on_export + elif drop: + warnings.warn( + "ignore(True) has been deprecated. 
TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", + category=FutureWarning, + ) + + def decorator(fn): + if drop: + fn._torchscript_modifier = FunctionModifiers.UNUSED + else: + fn._torchscript_modifier = FunctionModifiers.IGNORE + return fn + + return decorator + + +def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]: + fn._torchscript_modifier = FunctionModifiers._DROP # type: ignore[attr-defined] + return fn + + +def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]: + fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER # type: ignore[attr-defined] + return fn + + +def module_has_exports(mod): + for name in dir(mod): + if hasattr(mod, name): + item = getattr(mod, name) + if callable(item): + if get_torchscript_modifier(item) is FunctionModifiers.EXPORT: + return True + return False + + +# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you +# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to +# allow JIT'd code to still be covered. +def should_drop(fn) -> bool: + attr = get_torchscript_modifier(fn) + if attr is None: + return False + return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP + + +def is_ignored_fn(fn) -> bool: + mod = get_torchscript_modifier(fn) + return ( + mod is FunctionModifiers.UNUSED + or mod is FunctionModifiers.IGNORE + or mod is FunctionModifiers._DROP + ) + + +def _is_drop_fn(fn) -> bool: + mod = get_torchscript_modifier(fn) + return mod is FunctionModifiers._DROP + + +def is_static_fn(cls, fn) -> bool: + return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod) + + +def get_static_fn(cls, fn): + return inspect.getattr_static(cls, fn).__func__ + + +def get_torchscript_modifier(fn): + if not callable(fn): + return None + if hasattr(fn, "__func__"): + fn = fn.__func__ + return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT) + + +def copy_torchscript_modifier(orig, new) -> None: + attr = get_torchscript_modifier(orig) + if attr is None: + return + new._torchscript_modifier = attr + + +# overloading registration +# overloads get registered in this file, and compiled in torch/jit/__init__.py +# so that they can be imported in nn/functional.py without an import cycle + +# qualified_name => list[overload_functions] +_overloaded_fns: dict[str, list[Callable]] = {} # noqa: T484 + + +_OVERLOAD_EXAMPLE = """ +Example usage of overload function: +@torch.jit._overload +def my_function(x: type0) -> type0: # decl 1 + pass + +@torch.jit._overload +def my_function(x: type1) -> type1: # decl 2 + pass + +def my_function(x): # implementation + if isinstance(x, type0): + return x + elif isinstance(x, type1): + return x +""" + + +def get_overload_no_implementation_error_message(kind, obj): + sourcelines, file_lineno, filename = get_source_lines_and_file(obj) + return ( + f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make ' + f"sure a definition is provided and defined after all overload declarations.\n" + f'File "{filename}", line {file_lineno}:\n' + + "".join(sourcelines) + + "\n" + + _OVERLOAD_EXAMPLE + ) + + +def _check_overload_body(func): + try: + parsed_def = parse_def(func) + except OSError: + # Parsing the function definition can raise an OSError if source is unavailable. + # Since this is just an initial check, just raise a warning if this is the case. 
+ warnings.warn( + f"Unable to retrieve source for @torch.jit._overload function: {func}." + ) + return + + body = parsed_def.ast.body[0].body + + def is_pass(x): + return isinstance(x, ast.Pass) + + def is_ellipsis(x): + return ( + isinstance(x, ast.Expr) + and isinstance(x.value, ast.Constant) + and x.value.value is Ellipsis + ) + + if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])): + msg = ( + "Only `pass` statement or `...` can be the body of overload declaration:\n" + ) + msg += "\n".join(parsed_def.source.split("\n")[:3]) + msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE + raise RuntimeError(msg) + + +def _overload(func): + _check_overload_body(func) + qual_name = _qualified_name(func) + global _overloaded_fns + fn_overload_list = _overloaded_fns.get(qual_name) + if fn_overload_list is None: + fn_overload_list = [] + _overloaded_fns[qual_name] = fn_overload_list + fn_overload_list.append(func) + return func + + +def _get_fn_overloads(qual_name): + return _overloaded_fns.get(qual_name) + + +def _clear_fn_overloads(qual_name) -> None: + del _overloaded_fns[qual_name] + + +def get_class_name_lineno(method) -> tuple[str, int]: + current_frame = inspect.currentframe() + + # one for the get_class_name call, one for _overload_method call + for _ in range(2): + assert ( + current_frame is not None + ) # assert current frame is not an Optional[FrameType] + current_frame = current_frame.f_back + + assert current_frame is not None # same here + class_name = current_frame.f_code.co_name + line_no = current_frame.f_code.co_firstlineno + return class_name, line_no + + +# At the point the decorator is applied to class methods the method +# has no reference to its owning class. _qualified_name would not include +# the class it is defined in, so any methods with the same name in the same file +# would have the same _qualified_name, even if they were defined in different +# classes. This problem only exists in python 2. 
+# We get around this problem by looking at the stack frame and identifying +# the class name, and throwing an error whenever overloads are used +# when modules of the same name are in the same file + +# qualified_name => class name => list[overload_functions] +_overloaded_methods: dict[str, dict[str, list[Callable]]] = {} # noqa: T484 + + +# (qualified_name, class name) => class_fileno +_overloaded_method_class_fileno: dict[tuple[str, str], int] = {} + + +def _overload_method(func): + _check_overload_body(func) + qual_name = _qualified_name(func) + global _overloaded_methods + class_name_map = _overloaded_methods.get(qual_name, None) + if class_name_map is None: + class_name_map = {} + _overloaded_methods[qual_name] = class_name_map + + class_name, line_no = get_class_name_lineno(func) + method_overloads = class_name_map.get(class_name, None) + if method_overloads is None: + method_overloads = [] + class_name_map[class_name] = method_overloads + _overloaded_method_class_fileno[(qual_name, class_name)] = line_no + else: + existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)] + if existing_lineno != line_no: + raise RuntimeError( + "Cannot currently overload the same method name in two different" + " classes with the same name in the same module" + ) + + method_overloads.append(func) + return func + + +def _get_overloaded_methods(method, mod_class): + # TODO: __name__ not set for submodules in recursive script + if not hasattr(method, "__name__"): + return None + qual_name = _qualified_name(method) + class_name_map = _overloaded_methods.get(qual_name, None) + if class_name_map is None: + return None + overloads = class_name_map.get(mod_class.__name__, None) + if overloads is None: + return None + + method_line_no = get_source_lines_and_file(method)[1] + mod_class_fileno = get_source_lines_and_file(mod_class)[1] + mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) + if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): + raise AssertionError( + "Overloads are not useable when a module is redeclared within the same file: " + + str(method) + ) + return overloads + + +def is_tuple(ann) -> bool: + # Check for typing.Tuple missing args (but `tuple` is fine) + if ann is typing.Tuple: # noqa: UP006 + raise_error_container_parameter_missing("Tuple") + + # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule + if not hasattr(ann, "__module__"): + return False + + ann_origin = get_origin(ann) + return ann.__module__ in ("builtins", "typing") and ann_origin is tuple + + +def is_list(ann) -> bool: + # Check for typing.List missing args (but `list` is fine) + if ann is typing.List: # noqa: UP006 + raise_error_container_parameter_missing("List") + + if not hasattr(ann, "__module__"): + return False + + ann_origin = get_origin(ann) + return ann.__module__ in ("builtins", "typing") and ann_origin is list + + +def is_dict(ann) -> bool: + # Check for typing.Dict missing args (but `dict` is fine) + if ann is typing.Dict: # noqa: UP006 + raise_error_container_parameter_missing("Dict") + + if not hasattr(ann, "__module__"): + return False + + ann_origin = get_origin(ann) + return ann.__module__ in ("builtins", "typing") and ann_origin is dict + + +def is_union(ann): + if ann is Union: + raise_error_container_parameter_missing("Union") + + return isinstance(ann, BuiltinUnionType) or ( + hasattr(ann, "__module__") + and ann.__module__ == "typing" + and (get_origin(ann) is Union) + ) + + +def is_optional(ann): + if 
ann is Optional: + raise_error_container_parameter_missing("Optional") + + def is_optional_as_optional(ann): + return ( + hasattr(ann, "__module__") + and ann.__module__ == "typing" + and (get_origin(ann) is Optional) + ) + + def is_union_as_optional(ann): + ann_args = get_args(ann) + return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args) + + return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann)) + + +def is_future(ann) -> bool: + if ann is Future: + raise RuntimeError( + "Attempted to use Future without a " + "contained type. Please add a contained type, e.g. " + "Future[int]" + ) + return get_origin(ann) is Future + + +def is_await(ann) -> bool: + if ann is _Await: + return True + return get_origin(ann) is _Await + + +if torch.distributed.rpc.is_available(): + from torch._C._distributed_rpc import PyRRef + from torch.distributed.rpc import RRef + + def is_rref(ann) -> bool: + if ann is RRef: + raise RuntimeError( + "Attempted to use RRef without a " + "contained type. Please add a contained type, e.g. " + "RRef[int]" + ) + return get_origin(ann) is RRef + + def is_rref_instance(obj) -> bool: + return isinstance(obj, PyRRef) + +else: + + def is_rref_instance(obj) -> bool: + # If the RPC module doesn't exist then RRefs don't exist either. + return False + + +def _try_get_dispatched_fn(fn): + if not callable(fn): + return None + return boolean_dispatched.get(fn) + + +def _get_named_tuple_properties( + obj, + loc: Optional[torch._C._jit_tree_views.SourceRange] = None, + rcb=None, +): + if loc is None: + loc = fake_range() + + assert issubclass(obj, tuple) and hasattr(obj, "_fields") + if hasattr(obj, "_field_defaults"): + defaults = [ + obj._field_defaults[field] + for field in obj._fields + if field in obj._field_defaults + ] + else: + defaults = [] + # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function + # Also, annotations from base class are not inherited so they need to be queried explicitly + if sys.version_info[:2] < (3, 10): + obj_annotations = getattr(obj, "__annotations__", {}) + else: + obj_annotations = inspect.get_annotations(obj) + if len(obj_annotations) == 0 and hasattr(obj, "__base__"): + obj_annotations = inspect.get_annotations(obj.__base__) + + annotations = [] + for field in obj._fields: + if field in obj_annotations: + field_type = obj_annotations[field] + # [Note: ForwardRef annotations in NamedTuple attributes] + # NamedTuple types are slightly different from normal types. + # + # Normally, annotations are evaluted like this (during jit.script): + # 1. Load strings of python code into c++ and parse. + # 2. Get annotations as strings + # 3. Use the PythonResolver's resolution callback (rcb) to convert + # the string into a python object + # 4. We call into annotations.py:ann_to_type to convert python obj + # from step 3 into a type that torchscript understands. + # + # NamedTuples are more complicated, because it has sub-types. + # Normally, once we have the NamedTuple type object from #3, + # we can just look at the annotation literal values and use + # ann_to_type directly on them. + # + # But sometimes, users will annotate with string literals, e.g. + # x: 'int' + # This also happens with PEP563 (from __forward__ import annotations) + # + # These annotations appear in the annotation dict as ForwardRef('int'). + # + # Then, we need to convert the string into a python object. This + # requires having local context for custom objects or imported types. 
+ # rcb() is what gives us this. So, we plumb rcb through the stack so + # it can be used in this context for the if block below. + # + # FAQ: + # - Why do we need this special handling for NamedTuple but string + # annotations work fine for normal types? Normally, we parse the + # string directly and then call rcb() directly from C++. + # - Why not use ForwardRef._evaluate? For that, we need globals() + # and locals() for the local context where the NamedTuple was defined. + # rcb is what lets us look up into these. So, basically rcb does the + # hard work for us. + if isinstance(field_type, ForwardRef) and rcb is not None: + rcb_type = rcb(field_type.__forward_arg__) + # rcb returns None if it can't find anything. + if rcb_type is None: + raise ValueError( + f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}." + f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858." + f" Issue occurred at {loc.highlight()}" + ) + field_type = rcb_type + the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb) + annotations.append(the_type) + else: + annotations.append(torch._C.TensorType.getInferred()) + return type(obj).__name__, obj._fields, annotations, defaults + + +def _create_named_tuple( + t, + unqual_name: str, + field_names: list[str], + defaults: tuple[Any, ...], +): + TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] + return TupleType(*t) + + +@contextlib.contextmanager +def _disable_emit_hooks(): + hooks = torch._C._jit_get_emit_hooks() + torch._C._jit_set_emit_hooks(None, None) + try: + yield + finally: + torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) + + +def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811 + # noqa: F841 + def __enter__(self) -> None: + self.hooks = torch._C._jit_get_emit_hooks() + torch._C._jit_set_emit_hooks(None, None) + + def __exit__(self, *args) -> None: + torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1]) + + +def _is_exception(obj) -> bool: + if not inspect.isclass(obj): + return False + return issubclass(obj, Exception) + + +def raise_error_container_parameter_missing(target_type) -> None: + if target_type.endswith("ict"): + raise RuntimeError( + f"Attempted to use {target_type} without " + "contained types. Please add contained type, e.g. " + f"{target_type}[int, int]" + ) + raise RuntimeError( + f"Attempted to use {target_type} without a " + "contained type. Please add a contained type, e.g. " + f"{target_type}[int]" + ) + + +_RAW_TYPE_NAME_MAPPING = { + dict: "dict", + list: "list", + tuple: "tuple", + typing.Dict: "Dict", # noqa: UP006 + typing.List: "List", # noqa: UP006 + typing.Optional: "Optional", + typing.Tuple: "Tuple", # noqa: UP006 +} + + +def check_args_exist(target_type) -> None: + if name := _RAW_TYPE_NAME_MAPPING.get(target_type): + raise_error_container_parameter_missing(name) + + +def check_empty_containers(obj) -> None: + if obj == [] or obj == {} or obj == (): + warnings.warn( + "The inner type of a container is lost when " + "calling torch.jit.isinstance in eager mode. For " + "example, List[int] would become list and " + "therefore falsely return True for List[float] or" + " List[str]." 
+ ) + + +# supports List/Dict/Tuple and Optional types +# TODO support future +def container_checker(obj, target_type) -> bool: + origin_type = get_origin(target_type) + check_args_exist(target_type) + if origin_type is None: + return False + elif origin_type is list or origin_type is typing.List: # noqa: UP006 + check_empty_containers(obj) + if not isinstance(obj, list): + return False + arg_type = get_args(target_type)[0] + arg_origin = get_origin(arg_type) + for el in obj: + # check if nested container, ex: List[List[str]] + if arg_origin: # processes nested container, ex: List[List[str]] + if not container_checker(el, arg_type): + return False + elif not isinstance(el, arg_type): + return False + return True + elif origin_type is typing.Dict or origin_type is dict: # noqa: UP006 + check_empty_containers(obj) + if not isinstance(obj, dict): + return False + key_type = get_args(target_type)[0] + val_type = get_args(target_type)[1] + for key, val in obj.items(): + # check if keys are of right type + if not isinstance(key, key_type): + return False + val_origin = get_origin(val_type) + if val_origin: + if not container_checker(val, val_type): + return False + elif not isinstance(val, val_type): + return False + return True + elif origin_type is typing.Tuple or origin_type is tuple: # noqa: UP006 + check_empty_containers(obj) + if not isinstance(obj, tuple): + return False + arg_types = get_args(target_type) + if len(obj) != len(arg_types): + return False + for el, el_type in zip(obj, arg_types): + el_origin = get_origin(el_type) + if el_origin: + if not container_checker(el, el_type): + return False + elif not isinstance(el, el_type): + return False + return True + elif origin_type is Union or issubclass( + origin_type, BuiltinUnionType + ): # also handles Optional + if obj is None: # check before recursion because None is always fine + return True + inner_types = get_args(target_type) + for t in inner_types: + t_origin = get_origin(t) + if t_origin: + return container_checker(obj, t) + elif isinstance(obj, t): + return True + return False + + +def _isinstance(obj, target_type) -> bool: + if isinstance(target_type, collections.abc.Container): + if not isinstance(target_type, tuple): + raise RuntimeError( + "The second argument to " + "`torch.jit.isinstance` must be a type " + "or a tuple of types" + ) + for t_type in target_type: + if _isinstance(obj, t_type): + return True + return False + + origin_type = get_origin(target_type) + if origin_type: + return container_checker(obj, target_type) + + # Check to handle non-typed optional origin returns as none instead + # of as optional in 3.7-3.8 + check_args_exist(target_type) + + # handle non-containers + return isinstance(obj, target_type) + + +class _TensorExtractor(pickle.Pickler): + def __init__(self, *args, tensors: list[torch.Tensor], **kwargs): + super().__init__(*args, **kwargs) + self.tensors = tensors + + def persistent_id(self, obj): + if isinstance(obj, torch.Tensor): + self.tensors.append(obj) + return "" + # Since we just want to extract tensors, we don't mind if an object is + # unpicklable if it doesn't contain tensors, as we can just ignore/skip + # it. To play it safe, we only do so for common objects that we're sure + # don't contain tensors. Feel free to add new types here. Note also that + # even if a type isn't listed here this won't block users, since thet + # can just add a __getstate__ or __reduce__ method to their class. 
+ if isinstance(obj, LockType): + return "" + # Futures and RRefs don't technically contain a value, they just offer + # the means to access a value. + if isinstance(obj, CFuture) or is_rref_instance(obj): + return "" + if isinstance(obj, CAwait): + return "" + if isinstance(obj, torch.cuda.Event): + return "" + if isinstance(obj, threading.Thread): + return "" + return None + + +def _extract_tensors(obj): + r""" + This function is exclusively called from C++. + See ``torch/csrc/jit/python/python_ivalue.h``. + + It extracts the tensors contained in the given object, through pickling. + """ + tensors: list[torch.Tensor] = [] + extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors) + extractor.dump(obj) + return tensors + + +def _get_model_id(obj) -> Optional[str]: + if isinstance(obj, torch.jit.ScriptModule): + return str(obj._c._type()) + elif isinstance(obj, torch.jit.ScriptFunction): + return obj.qualified_name + else: + return None + + +# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass +# that were previously dropped. To preserve the behavior, explicitly drop them there + +if sys.version_info >= (3, 11): + _drop(enum.Enum.__new__) + _drop(enum.Enum.__format__) + _drop(enum.Enum.__repr__) + _drop(enum.Enum.__str__) diff --git a/phivenv/Lib/site-packages/torch/_linalg_utils.py b/phivenv/Lib/site-packages/torch/_linalg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39ac80571a0c74f8a26012488df0fc82ab06a2db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_linalg_utils.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +"""Various linear algebra utility methods for internal use.""" + +from typing import Optional + +import torch +from torch import Tensor + + +def is_sparse(A): + """Check if tensor A is a sparse tensor""" + if isinstance(A, torch.Tensor): + return A.layout == torch.sparse_coo + + error_str = "expected Tensor" + if not torch.jit.is_scripting(): + error_str += f" but got {type(A)}" + raise TypeError(error_str) + + +def get_floating_dtype(A): + """Return the floating point dtype of tensor A. + + Integer types map to float32. + """ + dtype = A.dtype + if dtype in (torch.float16, torch.float32, torch.float64): + return dtype + return torch.float32 + + +def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: + """Multiply two matrices. + + If A is None, return B. A can be sparse or dense. B is always + dense. 
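    A quick sketch of the two code paths (illustrative only)::

        A = torch.eye(3).to_sparse()   # sparse COO -> torch.sparse.mm path
        B = torch.randn(3, 2)
        matmul(A, B)                   # same result as torch.matmul(torch.eye(3), B)
        matmul(None, B)                # returns B unchanged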
+ """ + if A is None: + return B + if is_sparse(A): + return torch.sparse.mm(A, B) + return torch.matmul(A, B) + + +def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: + """Return bilinear form of matrices: :math:`X^T A Y`.""" + return matmul(X.mT, matmul(A, Y)) + + +def qform(A: Optional[Tensor], S: Tensor): + """Return quadratic form :math:`S^T A S`.""" + return bform(S, A, S) + + +def basis(A): + """Return orthogonal basis of A columns.""" + return torch.linalg.qr(A).Q + + +def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]: + """Return eigenpairs of A with specified ordering.""" + if largest is None: + largest = False + E, Z = torch.linalg.eigh(A, UPLO="U") + # assuming that E is ordered + if largest: + E = torch.flip(E, dims=(-1,)) + Z = torch.flip(Z, dims=(-1,)) + return E, Z + + +# These functions were deprecated and removed +# This nice error message can be removed in version 1.13+ +def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed.\n" + "Please use the `torch.linalg.matrix_rank` function instead. " + "The parameter 'symmetric' was renamed in `torch.linalg.matrix_rank()` to 'hermitian'." + ) + + +def solve(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`torch.solve` is deprecated in favor of `torch.linalg.solve`. " + "`torch.linalg.solve` has its arguments reversed and does not return the LU factorization.\n\n" + "To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve` or `torch.lu_unpack`.\n" + "X = torch.solve(B, A).solution " + "should be replaced with:\n" + "X = torch.linalg.solve(A, B)" + ) + + +def lstsq(input: Tensor, A: Tensor, *, out=None) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n" + "`torch.linalg.lstsq` has reversed arguments and does not return the QR decomposition in " + "the returned tuple (although it returns other information about the problem).\n\n" + "To get the QR decomposition consider using `torch.linalg.qr`.\n\n" + "The returned solution in `torch.lstsq` stored the residuals of the solution in the " + "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, " + "the residuals are in the field 'residuals' of the returned named tuple.\n\n" + "The unpacking of the solution, as in\n" + "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n" + "should be replaced with:\n" + "X = torch.linalg.lstsq(A, B).solution" + ) + + +def _symeig( + input, + eigenvectors=False, + upper=True, + *, + out=None, +) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. 
" + "The default behavior has changed from using the upper triangular portion of the matrix by default " + "to using the lower triangular portion.\n\n" + "L, _ = torch.symeig(A, upper=upper) " + "should be replaced with:\n" + "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n\n" + "and\n\n" + "L, V = torch.symeig(A, eigenvectors=True) " + "should be replaced with:\n" + "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')" + ) + + +def eig( + self: Tensor, + eigenvectors: bool = False, + *, + e=None, + v=None, +) -> tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors " + "mimicking complex tensors.\n\n" + "L, _ = torch.eig(A) " + "should be replaced with:\n" + "L_complex = torch.linalg.eigvals(A)\n\n" + "and\n\n" + "L, V = torch.eig(A, eigenvectors=True) " + "should be replaced with:\n" + "L_complex, V_complex = torch.linalg.eig(A)" + ) diff --git a/phivenv/Lib/site-packages/torch/_lobpcg.py b/phivenv/Lib/site-packages/torch/_lobpcg.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd0b6523bf56c95fc7d55e04cd8ba7dadcae4a6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lobpcg.py @@ -0,0 +1,1154 @@ +# mypy: allow-untyped-defs +"""Locally Optimal Block Preconditioned Conjugate Gradient methods.""" +# Author: Pearu Peterson +# Created: February 2020 + +from typing import Optional + +import torch +from torch import _linalg_utils as _utils, Tensor +from torch.overrides import handle_torch_function, has_torch_function + + +__all__ = ["lobpcg"] + + +def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U): + # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0 + F = D.unsqueeze(-2) - D.unsqueeze(-1) + F.diagonal(dim1=-2, dim2=-1).fill_(float("inf")) + F.pow_(-1) + + # A.grad = U (D.grad + (U^T U.grad * F)) U^T + Ut = U.mT.contiguous() + res = torch.matmul( + U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut) + ) + + return res + + +def _polynomial_coefficients_given_roots(roots): + """ + Given the `roots` of a polynomial, find the polynomial's coefficients. + + If roots = (r_1, ..., r_n), then the method returns + coefficients (a_0, a_1, ..., a_n (== 1)) so that + p(x) = (x - r_1) * ... * (x - r_n) + = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0 + + Note: for better performance requires writing a low-level kernel + """ + poly_order = roots.shape[-1] + poly_coeffs_shape = list(roots.shape) + # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0, + # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)}, + # but we insert one extra coefficient to enable better vectorization below + poly_coeffs_shape[-1] += 2 + poly_coeffs = roots.new_zeros(poly_coeffs_shape) + poly_coeffs[..., 0] = 1 + poly_coeffs[..., -1] = 1 + + # perform the Horner's rule + for i in range(1, poly_order + 1): + # note that it is computationally hard to compute backward for this method, + # because then given the coefficients it would require finding the roots and/or + # calculating the sensitivity based on the Vieta's theorem. + # So the code below tries to circumvent the explicit root finding by series + # of operations on memory copies imitating the Horner's method. + # The memory copies are required to construct nodes in the computational graph + # by exploting the explicit (not in-place, separate node for each step) + # recursion of the Horner's method. 
+ # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity. + poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs + out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1) + out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow( + -1, poly_order - i + 1, i + 1 + ) + poly_coeffs = poly_coeffs_new + + return poly_coeffs.narrow(-1, 1, poly_order + 1) + + +def _polynomial_value(poly, x, zero_power, transition): + """ + A generic method for computing poly(x) using the Horner's rule. + + Args: + poly (Tensor): the (possibly batched) 1D Tensor representing + polynomial coefficients such that + poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and + poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n + + x (Tensor): the value (possible batched) to evalate the polynomial `poly` at. + + zero_power (Tensor): the representation of `x^0`. It is application-specific. + + transition (Callable): the function that accepts some intermediate result `int_val`, + the `x` and a specific polynomial coefficient + `poly[..., k]` for some iteration `k`. + It basically performs one iteration of the Horner's rule + defined as `x * int_val + poly[..., k] * zero_power`. + Note that `zero_power` is not a parameter, + because the step `+ poly[..., k] * zero_power` depends on `x`, + whether it is a vector, a matrix, or something else, so this + functionality is delegated to the user. + """ + + res = zero_power.clone() + for k in range(poly.size(-1) - 2, -1, -1): + res = transition(res, x, poly[..., k]) + return res + + +def _matrix_polynomial_value(poly, x, zero_power=None): + """ + Evaluates `poly(x)` for the (batched) matrix input `x`. + Check out `_polynomial_value` function for more details. + """ + + # matrix-aware Horner's rule iteration + def transition(curr_poly_val, x, poly_coeff): + res = x.matmul(curr_poly_val) + res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1)) + return res + + if zero_power is None: + zero_power = torch.eye( + x.size(-1), x.size(-1), dtype=x.dtype, device=x.device + ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1)) + + return _polynomial_value(poly, x, zero_power, transition) + + +def _vector_polynomial_value(poly, x, zero_power=None): + """ + Evaluates `poly(x)` for the (batched) vector input `x`. + Check out `_polynomial_value` function for more details. + """ + + # vector-aware Horner's rule iteration + def transition(curr_poly_val, x, poly_coeff): + res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val) + return res + + if zero_power is None: + zero_power = x.new_ones(1).expand(x.shape) + + return _polynomial_value(poly, x, zero_power, transition) + + +def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): + # compute a projection operator onto an orthogonal subspace spanned by the + # columns of U defined as (I - UU^T) + Ut = U.mT.contiguous() + proj_U_ortho = -U.matmul(Ut) + proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1) + + # compute U_ortho, a basis for the orthogonal complement to the span(U), + # by projecting a random [..., m, m - k] matrix onto the subspace spanned + # by the columns of U. + # + # fix generator for determinism + gen = torch.Generator(A.device) + + # orthogonal complement to the span(U) + U_ortho = proj_U_ortho.matmul( + torch.randn( + (*A.shape[:-1], A.size(-1) - D.size(-1)), + dtype=A.dtype, + device=A.device, + generator=gen, + ) + ) + U_ortho_t = U_ortho.mT.contiguous() + + # compute the coefficients of the characteristic polynomial of the tensor D. 
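+    # Editor's illustration (not upstream code): for two eigenvalues d1 and d2,
+    # i.e. D = diag(d1, d2), the helper below returns the coefficient vector
+    # (d1*d2, -(d1 + d2), 1), which are the coefficients of
+    # (x - d1)(x - d2) = x^2 - (d1 + d2)*x + d1*d2, ordered from the constant
+    # term up to the leading 1.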
+ # Note that D is diagonal, so the diagonal elements are exactly the roots + # of the characteristic polynomial. + chr_poly_D = _polynomial_coefficients_given_roots(D) + + # the code belows finds the explicit solution to the Sylvester equation + # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U + # and incorporates it into the whole gradient stored in the `res` variable. + # + # Equivalent to the following naive implementation: + # res = A.new_zeros(A.shape) + # p_res = A.new_zeros(*A.shape[:-1], D.size(-1)) + # for k in range(1, chr_poly_D.size(-1)): + # p_res.zero_() + # for i in range(0, k): + # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2) + # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t()) + # + # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity + # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g, + # and we need to compute g(U_grad, A, U, D) + # + # The naive implementation is based on the paper + # Hu, Qingxi, and Daizhan Cheng. + # "The polynomial solution to the Sylvester matrix equation." + # Applied mathematics letters 19.9 (2006): 859-864. + # + # We can modify the computation of `p_res` from above in a more efficient way + # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2) + # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2) + # + ... + # + A.matrix_power(k - 1) U_grad * chr_poly_D[k] + # Note that this saves us from redundant matrix products with A (elimination of matrix_power) + U_grad_projected = U_grad + series_acc = U_grad_projected.new_zeros(U_grad_projected.shape) + for k in range(1, chr_poly_D.size(-1)): + poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D) + series_acc += U_grad_projected * poly_D.unsqueeze(-2) + U_grad_projected = A.matmul(U_grad_projected) + + # compute chr_poly_D(A) which essentially is: + # + # chr_poly_D_at_A = A.new_zeros(A.shape) + # for k in range(chr_poly_D.size(-1)): + # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k) + # + # Note, however, for better performance we use the Horner's rule + chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A) + + # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t + chr_poly_D_at_A_to_U_ortho = torch.matmul( + U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho) + ) + # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its + # Cholesky decomposition and then use `torch.cholesky_solve` for better stability. + # Cholesky decomposition requires the input to be positive-definite. + # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if + # 1. `largest` == False, or + # 2. `largest` == True and `k` is even + # under the assumption that `A` has distinct eigenvalues. 
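+    # Editor's illustration (not upstream code): with a single requested
+    # eigenpair (k = 1) and D = [[d]], chr_poly_D(x) = x - d, so restricting
+    # A - d*I to span(U_ortho) gives eigenvalues lambda_j - d. These are all
+    # positive when d is the smallest eigenvalue (largest == False) and all
+    # negative when d is the largest one (largest == True, k odd), which is why
+    # the sign flip below is applied before the Cholesky factorization.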
+ # + # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite + chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1 + chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky( + chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho + ) + + # compute the gradient part in span(U) + res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) + + # incorporate the Sylvester equation solution into the full gradient + # it resides in span(U_ortho) + res -= U_ortho.matmul( + chr_poly_D_at_A_to_U_ortho_sign + * torch.cholesky_solve( + U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L + ) + ).matmul(Ut) + + return res + + +def _symeig_backward(D_grad, U_grad, A, D, U, largest): + # if `U` is square, then the columns of `U` is a complete eigenspace + if U.size(-1) == U.size(-2): + return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) + else: + return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest) + + +class LOBPCGAutogradFunction(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[dict[str, int]] = None, + ortho_fparams: Optional[dict[str, float]] = None, + ortho_bparams: Optional[dict[str, bool]] = None, + ) -> tuple[Tensor, Tensor]: + # makes sure that input is contiguous for efficiency. + # Note: autograd does not support dense gradients for sparse input yet. + A = A.contiguous() if (not A.is_sparse) else A + if B is not None: + B = B.contiguous() if (not B.is_sparse) else B + + D, U = _lobpcg( + A, + k, + B, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, + ) + + ctx.save_for_backward(A, B, D, U) + ctx.largest = largest + + return D, U + + @staticmethod + def backward(ctx, D_grad, U_grad): + A_grad = B_grad = None + grads = [None] * 14 + + A, B, D, U = ctx.saved_tensors + largest = ctx.largest + + # lobpcg.backward has some limitations. Checks for unsupported input + if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]): + raise ValueError( + "lobpcg.backward does not support sparse input yet." + "Note that lobpcg.forward does though." + ) + if ( + A.dtype in (torch.complex64, torch.complex128) + or B is not None + and B.dtype in (torch.complex64, torch.complex128) + ): + raise ValueError( + "lobpcg.backward does not support complex input yet." + "Note that lobpcg.forward does though." + ) + if B is not None: + raise ValueError( + "lobpcg.backward does not support backward with B != I yet." 
+ ) + + if largest is None: + largest = True + + # symeig backward + if B is None: + A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest) + + # A has index 0 + grads[0] = A_grad + # B has index 2 + grads[2] = B_grad + return tuple(grads) + + +def lobpcg( + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[dict[str, int]] = None, + ortho_fparams: Optional[dict[str, float]] = None, + ortho_bparams: Optional[dict[str, bool]] = None, +) -> tuple[Tensor, Tensor]: + """Find the k largest (or smallest) eigenvalues and the corresponding + eigenvectors of a symmetric positive definite generalized + eigenvalue problem using matrix-free LOBPCG methods. + + This function is a front-end to the following LOBPCG algorithms + selectable via `method` argument: + + `method="basic"` - the LOBPCG method introduced by Andrew + Knyazev, see [Knyazev2001]. A less robust method, may fail when + Cholesky is applied to singular input. + + `method="ortho"` - the LOBPCG method with orthogonal basis + selection [StathopoulosEtal2002]. A robust method. + + Supported inputs are dense, sparse, and batches of dense matrices. + + .. note:: In general, the basic method spends least time per + iteration. However, the robust methods converge much faster and + are more stable. So, the usage of the basic method is generally + not recommended but there exist cases where the usage of the + basic method may be preferred. + + .. warning:: The backward method does not support sparse and complex inputs. + It works only when `B` is not provided (i.e. `B == None`). + We are actively working on extensions, and the details of + the algorithms are going to be published promptly. + + .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not. + To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric + in first-order optimization routines, prior to running `lobpcg` + we do the following symmetrization map: `A -> (A + A.t()) / 2`. + The map is performed only when the `A` requires gradients. + + Args: + + A (Tensor): the input tensor of size :math:`(*, m, m)` + + B (Tensor, optional): the input tensor of size :math:`(*, m, + m)`. When not specified, `B` is interpreted as + identity matrix. + + X (tensor, optional): the input tensor of size :math:`(*, m, n)` + where `k <= n <= m`. When specified, it is used as + initial approximation of eigenvectors. X must be a + dense tensor. + + iK (tensor, optional): the input tensor of size :math:`(*, m, + m)`. When specified, it will be used as preconditioner. + + k (integer, optional): the number of requested + eigenpairs. Default is the number of :math:`X` + columns (when specified) or `1`. + + n (integer, optional): if :math:`X` is not specified then `n` + specifies the size of the generated random + approximation of eigenvectors. Default value for `n` + is `k`. If :math:`X` is specified, the value of `n` + (when specified) must be the number of :math:`X` + columns. + + tol (float, optional): residual tolerance for stopping + criterion. Default is `feps ** 0.5` where `feps` is + smallest non-zero floating-point number of the given + input tensor `A` data type. + + largest (bool, optional): when True, solve the eigenproblem for + the largest eigenvalues. 
Otherwise, solve the + eigenproblem for smallest eigenvalues. Default is + `True`. + + method (str, optional): select LOBPCG method. See the + description of the function above. Default is + "ortho". + + niter (int, optional): maximum number of iterations. When + reached, the iteration process is hard-stopped and + the current approximation of eigenpairs is returned. + For infinite iteration but until convergence criteria + is met, use `-1`. + + tracker (callable, optional) : a function for tracing the + iteration process. When specified, it is called at + each iteration step with LOBPCG instance as an + argument. The LOBPCG instance holds the full state of + the iteration process in the following attributes: + + `iparams`, `fparams`, `bparams` - dictionaries of + integer, float, and boolean valued input + parameters, respectively + + `ivars`, `fvars`, `bvars`, `tvars` - dictionaries + of integer, float, boolean, and Tensor valued + iteration variables, respectively. + + `A`, `B`, `iK` - input Tensor arguments. + + `E`, `X`, `S`, `R` - iteration Tensor variables. + + For instance: + + `ivars["istep"]` - the current iteration step + `X` - the current approximation of eigenvectors + `E` - the current approximation of eigenvalues + `R` - the current residual + `ivars["converged_count"]` - the current number of converged eigenpairs + `tvars["rerr"]` - the current state of convergence criteria + + Note that when `tracker` stores Tensor objects from + the LOBPCG instance, it must make copies of these. + + If `tracker` sets `bvars["force_stop"] = True`, the + iteration process will be hard-stopped. + + ortho_iparams, ortho_fparams, ortho_bparams (dict, optional): + various parameters to LOBPCG algorithm when using + `method="ortho"`. + + Returns: + + E (Tensor): tensor of eigenvalues of size :math:`(*, k)` + + X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)` + + References: + + [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal + Preconditioned Eigensolver: Locally Optimal Block Preconditioned + Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2), + 517-541. (25 pages) + https://epubs.siam.org/doi/abs/10.1137/S1064827500366124 + + [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng + Wu. (2002) A Block Orthogonalization Procedure with Constant + Synchronization Requirements. SIAM J. Sci. Comput., 23(6), + 2165-2182. (18 pages) + https://epubs.siam.org/doi/10.1137/S1064827500370883 + + [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming + Gu. (2018) A Robust and Efficient Implementation of LOBPCG. + SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages) + https://arxiv.org/abs/1704.07458 + + """ + + if not torch.jit.is_scripting(): + tensor_ops = (A, B, X, iK) + if not set(map(type, tensor_ops)).issubset( + (torch.Tensor, type(None)) + ) and has_torch_function(tensor_ops): + return handle_torch_function( + lobpcg, + tensor_ops, + A, + k=k, + B=B, + X=X, + n=n, + iK=iK, + niter=niter, + tol=tol, + largest=largest, + method=method, + tracker=tracker, + ortho_iparams=ortho_iparams, + ortho_fparams=ortho_fparams, + ortho_bparams=ortho_bparams, + ) + + if not torch._jit_internal.is_scripting(): + if A.requires_grad or (B is not None and B.requires_grad): + # While it is expected that `A` is symmetric, + # the `A_grad` might be not. Therefore we perform the trick below, + # so that `A_grad` becomes symmetric. + # The symmetrization is important for first-order optimization methods, + # so that (A - alpha * A_grad) is still a symmetric matrix. + # Same holds for `B`. 
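+            # Editor's illustration (not upstream code): (A + A.mT) / 2 is the
+            # orthogonal projection of A onto symmetric matrices, e.g. for
+            # A = [[1., 2.], [0., 3.]] it yields [[1., 1.], [1., 3.]], so
+            # torch.allclose(A_sym, A_sym.mT) holds after the line below.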
+ A_sym = (A + A.mT) / 2 + B_sym = (B + B.mT) / 2 if (B is not None) else None + + return LOBPCGAutogradFunction.apply( + A_sym, + k, + B_sym, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, + ) + else: + if A.requires_grad or (B is not None and B.requires_grad): + raise RuntimeError( + "Script and require grads is not supported atm." + "If you just want to do the forward, use .detach()" + "on A and B before calling into lobpcg" + ) + + return _lobpcg( + A, + k, + B, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, + ) + + +def _lobpcg( + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[dict[str, int]] = None, + ortho_fparams: Optional[dict[str, float]] = None, + ortho_bparams: Optional[dict[str, bool]] = None, +) -> tuple[Tensor, Tensor]: + # A must be square: + assert A.shape[-2] == A.shape[-1], A.shape + if B is not None: + # A and B must have the same shapes: + assert A.shape == B.shape, (A.shape, B.shape) + + dtype = _utils.get_floating_dtype(A) + device = A.device + if tol is None: + feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype] + tol = feps**0.5 + + m = A.shape[-1] + k = (1 if X is None else X.shape[-1]) if k is None else k + n = (k if n is None else n) if X is None else X.shape[-1] + + if m < 3 * n: + raise ValueError( + f"LPBPCG algorithm is not applicable when the number of A rows (={m})" + f" is smaller than 3 x the number of requested eigenpairs (={n})" + ) + + method = "ortho" if method is None else method + + iparams = { + "m": m, + "n": n, + "k": k, + "niter": 1000 if niter is None else niter, + } + + fparams = { + "tol": tol, + } + + bparams = {"largest": True if largest is None else largest} + + if method == "ortho": + if ortho_iparams is not None: + iparams.update(ortho_iparams) + if ortho_fparams is not None: + fparams.update(ortho_fparams) + if ortho_bparams is not None: + bparams.update(ortho_bparams) + iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3) + iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3) + fparams["ortho_tol"] = fparams.get("ortho_tol", tol) + fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol) + fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol) + bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False) + + if not torch.jit.is_scripting(): + LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[method-assign] + + if len(A.shape) > 2: + N = int(torch.prod(torch.tensor(A.shape[:-2]))) + bA = A.reshape((N,) + A.shape[-2:]) + bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None + bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None + bE = torch.empty((N, k), dtype=dtype, device=device) + bXret = torch.empty((N, m, k), dtype=dtype, device=device) + + for i in range(N): + A_ = bA[i] + B_ = bB[i] if bB is not None else None + X_ = ( + torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i] + ) + assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n)) + iparams["batch_index"] = i + worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker) + worker.run() + bE[i] = worker.E[:k] + bXret[i] = worker.X[:, :k] + + if not 
torch.jit.is_scripting(): + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign] + + return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k)) + + X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X + assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n)) + + worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker) + + worker.run() + + if not torch.jit.is_scripting(): + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign] + + return worker.E[:k], worker.X[:, :k] + + +class LOBPCG: + """Worker class of LOBPCG methods.""" + + def __init__( + self, + A: Optional[Tensor], + B: Optional[Tensor], + X: Tensor, + iK: Optional[Tensor], + iparams: dict[str, int], + fparams: dict[str, float], + bparams: dict[str, bool], + method: str, + tracker: None, + ) -> None: + # constant parameters + self.A = A + self.B = B + self.iK = iK + self.iparams = iparams + self.fparams = fparams + self.bparams = bparams + self.method = method + self.tracker = tracker + m = iparams["m"] + n = iparams["n"] + + # variable parameters + self.X = X + self.E = torch.zeros((n,), dtype=X.dtype, device=X.device) + self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device) + self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device) + self.tvars: dict[str, Tensor] = {} + self.ivars: dict[str, int] = {"istep": 0} + self.fvars: dict[str, float] = {"_": 0.0} + self.bvars: dict[str, bool] = {"_": False} + + def __str__(self): + lines = ["LOPBCG:"] + lines += [f" iparams={self.iparams}"] + lines += [f" fparams={self.fparams}"] + lines += [f" bparams={self.bparams}"] + lines += [f" ivars={self.ivars}"] + lines += [f" fvars={self.fvars}"] + lines += [f" bvars={self.bvars}"] + lines += [f" tvars={self.tvars}"] + lines += [f" A={self.A}"] + lines += [f" B={self.B}"] + lines += [f" iK={self.iK}"] + lines += [f" X={self.X}"] + lines += [f" E={self.E}"] + r = "" + for line in lines: + r += line + "\n" + return r + + def update(self): + """Set and update iteration variables.""" + if self.ivars["istep"] == 0: + X_norm = float(torch.norm(self.X)) + iX_norm = X_norm**-1 + A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm + B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm + self.fvars["X_norm"] = X_norm + self.fvars["A_norm"] = A_norm + self.fvars["B_norm"] = B_norm + self.ivars["iterations_left"] = self.iparams["niter"] + self.ivars["converged_count"] = 0 + self.ivars["converged_end"] = 0 + + if self.method == "ortho": + self._update_ortho() + else: + self._update_basic() + + self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1 + self.ivars["istep"] = self.ivars["istep"] + 1 + + def update_residual(self): + """Update residual R from A, B, X, E.""" + mm = _utils.matmul + self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E + + def update_converged_count(self): + """Determine the number of converged eigenpairs using backward stable + convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018]. + + Users may redefine this method for custom convergence criteria. + """ + # (...) 
-> int + prev_count = self.ivars["converged_count"] + tol = self.fparams["tol"] + A_norm = self.fvars["A_norm"] + B_norm = self.fvars["B_norm"] + E, X, R = self.E, self.X, self.R + rerr = torch.norm(R, 2, (0,)) / ( + torch.norm(X, 2, (0,)) * (A_norm + torch.abs(E[: X.shape[-1]]) * B_norm) + ) + converged = rerr < tol + count = 0 + for b in converged: + if not b: + # ignore convergence of following pairs to ensure + # strict ordering of eigenpairs + break + count += 1 + assert count >= prev_count, ( + f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease" + ) + self.ivars["converged_count"] = count + self.tvars["rerr"] = rerr + return count + + def stop_iteration(self): + """Return True to stop iterations. + + Note that tracker (if defined) can force-stop iterations by + setting ``worker.bvars['force_stop'] = True``. + """ + return ( + self.bvars.get("force_stop", False) + or self.ivars["iterations_left"] == 0 + or self.ivars["converged_count"] >= self.iparams["k"] + ) + + def run(self): + """Run LOBPCG iterations. + + Use this method as a template for implementing LOBPCG + iteration scheme with custom tracker that is compatible with + TorchScript. + """ + self.update() + + if not torch.jit.is_scripting() and self.tracker is not None: + self.call_tracker() + + while not self.stop_iteration(): + self.update() + + if not torch.jit.is_scripting() and self.tracker is not None: + self.call_tracker() + + @torch.jit.unused + def call_tracker(self): + """Interface for tracking iteration process in Python mode. + + Tracking the iteration process is disabled in TorchScript + mode. In fact, one should specify tracker=None when JIT + compiling functions using lobpcg. + """ + # do nothing when in TorchScript mode + + # Internal methods + + def _update_basic(self): + """ + Update or initialize iteration variables when `method == "basic"`. + """ + mm = torch.matmul + ns = self.ivars["converged_end"] + nc = self.ivars["converged_count"] + n = self.iparams["n"] + largest = self.bparams["largest"] + + if self.ivars["istep"] == 0: + Ri = self._get_rayleigh_ritz_transform(self.X) + M = _utils.qform(_utils.qform(self.A, self.X), Ri) + E, Z = _utils.symeig(M, largest) + self.X[:] = mm(self.X, mm(Ri, Z)) + self.E[:] = E + np = 0 + self.update_residual() + nc = self.update_converged_count() + self.S[..., :n] = self.X + + W = _utils.matmul(self.iK, self.R) + self.ivars["converged_end"] = ns = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + else: + S_ = self.S[:, nc:ns] + Ri = self._get_rayleigh_ritz_transform(S_) + M = _utils.qform(_utils.qform(self.A, S_), Ri) + E_, Z = _utils.symeig(M, largest) + self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc])) + self.E[nc:] = E_[: n - nc] + P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc])) + np = P.shape[-1] + + self.update_residual() + nc = self.update_converged_count() + self.S[..., :n] = self.X + self.S[:, n : n + np] = P + W = _utils.matmul(self.iK, self.R[:, nc:]) + + self.ivars["converged_end"] = ns = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + + def _update_ortho(self): + """ + Update or initialize iteration variables when `method == "ortho"`. 
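+
+        Editor's note (illustrative summary, not upstream documentation): on the
+        first step the Rayleigh-Ritz procedure is applied to ``X`` alone; on later
+        steps it is applied to the active block ``S[:, nc:ns]``. After each update
+        a new block ``W`` is formed by B-orthogonalizing the residual against the
+        current ``X``/``P`` blocks via ``_get_ortho`` and appended to the search
+        subspace ``S``.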
+ """ + mm = torch.matmul + ns = self.ivars["converged_end"] + nc = self.ivars["converged_count"] + n = self.iparams["n"] + largest = self.bparams["largest"] + + if self.ivars["istep"] == 0: + Ri = self._get_rayleigh_ritz_transform(self.X) + M = _utils.qform(_utils.qform(self.A, self.X), Ri) + _E, Z = _utils.symeig(M, largest) + self.X = mm(self.X, mm(Ri, Z)) + self.update_residual() + np = 0 + nc = self.update_converged_count() + self.S[:, :n] = self.X + W = self._get_ortho(self.R, self.X) + ns = self.ivars["converged_end"] = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + + else: + S_ = self.S[:, nc:ns] + # Rayleigh-Ritz procedure + E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest) + + # Update E, X, P + self.X[:, nc:] = mm(S_, Z[:, : n - nc]) + self.E[nc:] = E_[: n - nc] + P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT))) + np = P.shape[-1] + + # check convergence + self.update_residual() + nc = self.update_converged_count() + + # update S + self.S[:, :n] = self.X + self.S[:, n : n + np] = P + W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np]) + ns = self.ivars["converged_end"] = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + + def _get_rayleigh_ritz_transform(self, S): + """Return a transformation matrix that is used in Rayleigh-Ritz + procedure for reducing a general eigenvalue problem :math:`(S^TAS) + C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T + S^TAS Ri) Z = Z E` where `C = Ri Z`. + + .. note:: In the original Rayleight-Ritz procedure in + [DuerschEtal2018], the problem is formulated as follows:: + + SAS = S^T A S + SBS = S^T B S + D = () ** -1/2 + R^T R = Cholesky(D SBS D) + Ri = D R^-1 + solve symeig problem Ri^T SAS Ri Z = Theta Z + C = Ri Z + + To reduce the number of matrix products (denoted by empty + space between matrices), here we introduce element-wise + products (denoted by symbol `*`) so that the Rayleight-Ritz + procedure becomes:: + + SAS = S^T A S + SBS = S^T B S + d = () ** -1/2 # this is 1-d column vector + dd = d d^T # this is 2-d matrix + R^T R = Cholesky(dd * SBS) + Ri = R^-1 * d # broadcasting + solve symeig problem Ri^T SAS Ri Z = Theta Z + C = Ri Z + + where `dd` is 2-d matrix that replaces matrix products `D M + D` with one element-wise product `M * dd`; and `d` replaces + matrix product `D M` with element-wise product `M * + d`. Also, creating the diagonal matrix `D` is avoided. + + Args: + S (Tensor): the matrix basis for the search subspace, size is + :math:`(m, n)`. + + Returns: + Ri (tensor): upper-triangular transformation matrix of size + :math:`(n, n)`. + + """ + B = self.B + SBS = _utils.qform(B, S) + d_row = SBS.diagonal(0, -2, -1) ** -0.5 + d_col = d_row.reshape(d_row.shape[0], 1) + # TODO use torch.linalg.cholesky_solve once it is implemented + R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True) + return torch.linalg.solve_triangular( + R, d_row.diag_embed(), upper=True, left=False + ) + + def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor: + """Return B-orthonormal U. + + .. note:: When `drop` is `False` then `svqb` is based on the + Algorithm 4 from [DuerschPhD2015] that is a slight + modification of the corresponding algorithm + introduced in [StathopolousWu2002]. + + Args: + + U (Tensor) : initial approximation, size is (m, n) + drop (bool) : when True, drop columns that + contribution to the `span([U])` is small. 
+ tau (float) : positive tolerance + + Returns: + + U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size + is (m, n1), where `n1 = n` if `drop` is `False, + otherwise `n1 <= n`. + + """ + if torch.numel(U) == 0: + return U + UBU = _utils.qform(self.B, U) + d = UBU.diagonal(0, -2, -1) + + # Detect and drop exact zero columns from U. While the test + # `abs(d) == 0` is unlikely to be True for random data, it is + # possible to construct input data to lobpcg where it will be + # True leading to a failure (notice the `d ** -0.5` operation + # in the original algorithm). To prevent the failure, we drop + # the exact zero columns here and then continue with the + # original algorithm below. + nz = torch.where(abs(d) != 0.0) + assert len(nz) == 1, nz + if len(nz[0]) < len(d): + U = U[:, nz[0]] + if torch.numel(U) == 0: + return U + UBU = _utils.qform(self.B, U) + d = UBU.diagonal(0, -2, -1) + nz = torch.where(abs(d) != 0.0) + assert len(nz[0]) == len(d) + + # The original algorithm 4 from [DuerschPhD2015]. + d_col = (d**-0.5).reshape(d.shape[0], 1) + DUBUD = (UBU * d_col) * d_col.mT + E, Z = _utils.symeig(DUBUD) + t = tau * abs(E).max() + if drop: + keep = torch.where(E > t) + assert len(keep) == 1, keep + E = E[keep[0]] + Z = Z[:, keep[0]] + d_col = d_col[keep[0]] + else: + E[(torch.where(E < t))[0]] = t + + return torch.matmul(U * d_col.mT, Z * E**-0.5) + + def _get_ortho(self, U, V): + """Return B-orthonormal U with columns are B-orthogonal to V. + + .. note:: When `bparams["ortho_use_drop"] == False` then + `_get_ortho` is based on the Algorithm 3 from + [DuerschPhD2015] that is a slight modification of + the corresponding algorithm introduced in + [StathopolousWu2002]. Otherwise, the method + implements Algorithm 6 from [DuerschPhD2015] + + .. note:: If all U columns are B-collinear to V then the + returned tensor U will be empty. + + Args: + + U (Tensor) : initial approximation, size is (m, n) + V (Tensor) : B-orthogonal external basis, size is (m, k) + + Returns: + + U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`) + such that :math:`V^T B U=0`, size is (m, n1), + where `n1 = n` if `drop` is `False, otherwise + `n1 <= n`. + """ + mm = torch.matmul + mm_B = _utils.matmul + m = self.iparams["m"] + tau_ortho = self.fparams["ortho_tol"] + tau_drop = self.fparams["ortho_tol_drop"] + tau_replace = self.fparams["ortho_tol_replace"] + i_max = self.iparams["ortho_i_max"] + j_max = self.iparams["ortho_j_max"] + # when use_drop==True, enable dropping U columns that have + # small contribution to the `span([U, V])`. 
+ use_drop = self.bparams["ortho_use_drop"] + + # clean up variables from the previous call + for vkey in list(self.fvars.keys()): + if vkey.startswith("ortho_") and vkey.endswith("_rerr"): + self.fvars.pop(vkey) + self.ivars.pop("ortho_i", 0) + self.ivars.pop("ortho_j", 0) + + BV_norm = torch.norm(mm_B(self.B, V)) + BU = mm_B(self.B, U) + VBU = mm(V.mT, BU) + i = j = 0 + for i in range(i_max): + U = U - mm(V, VBU) + drop = False + tau_svqb = tau_drop + for j in range(j_max): + if use_drop: + U = self._get_svqb(U, drop, tau_svqb) + drop = True + tau_svqb = tau_replace + else: + U = self._get_svqb(U, False, tau_replace) + if torch.numel(U) == 0: + # all initial U columns are B-collinear to V + self.ivars["ortho_i"] = i + self.ivars["ortho_j"] = j + return U + BU = mm_B(self.B, U) + UBU = mm(U.mT, BU) + U_norm = torch.norm(U) + BU_norm = torch.norm(BU) + R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype) + R_norm = torch.norm(R) + # https://github.com/pytorch/pytorch/issues/33810 workaround: + rerr = float(R_norm) * float(BU_norm * U_norm) ** -1 + vkey = f"ortho_UBUmI_rerr[{i}, {j}]" + self.fvars[vkey] = rerr + if rerr < tau_ortho: + break + VBU = mm(V.mT, BU) + VBU_norm = torch.norm(VBU) + U_norm = torch.norm(U) + rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1 + vkey = f"ortho_VBU_rerr[{i}]" + self.fvars[vkey] = rerr + if rerr < tau_ortho: + break + if m < U.shape[-1] + V.shape[-1]: + # TorchScript needs the class var to be assigned to a local to + # do optional type refinement + B = self.B + assert B is not None + raise ValueError( + "Overdetermined shape of U:" + f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold" + ) + self.ivars["ortho_i"] = i + self.ivars["ortho_j"] = j + return U + + +# Calling tracker is separated from LOBPCG definitions because +# TorchScript does not support user-defined callback arguments: +LOBPCG_call_tracker_orig = LOBPCG.call_tracker + + +def LOBPCG_call_tracker(self): + self.tracker(self) diff --git a/phivenv/Lib/site-packages/torch/_lowrank.py b/phivenv/Lib/site-packages/torch/_lowrank.py new file mode 100644 index 0000000000000000000000000000000000000000..3887c984776fb6deb8d8933c76305c61768ca822 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lowrank.py @@ -0,0 +1,294 @@ +"""Implement various linear algebra algorithms for low rank matrices.""" + +__all__ = ["svd_lowrank", "pca_lowrank"] + +from typing import Optional + +import torch +from torch import _linalg_utils as _utils, Tensor +from torch.overrides import handle_torch_function, has_torch_function + + +def get_approximate_basis( + A: Tensor, + q: int, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> Tensor: + """Return tensor :math:`Q` with :math:`q` orthonormal columns such + that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is + specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` + approximates :math:`A - M`. without instantiating any tensors + of the size of :math:`A` or :math:`M`. + + .. note:: The implementation is based on the Algorithm 4.4 from + Halko et al., 2009. + + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. 
If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + Args:: + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int): the dimension of subspace spanned by :math:`Q` + columns. + + niter (int, optional): the number of subspace iterations to + conduct; ``niter`` must be a + nonnegative integer. In most cases, the + default value 2 is more than enough. + + M (Tensor, optional): the input tensor's mean of size + :math:`(*, m, n)`. + + References:: + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). + """ + + niter = 2 if niter is None else niter + dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype + matmul = _utils.matmul + + R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device) + + # The following code could be made faster using torch.geqrf + torch.ormqr + # but geqrf is not differentiable + + X = matmul(A, R) + if M is not None: + X = X - matmul(M, R) + Q = torch.linalg.qr(X).Q + for _ in range(niter): + X = matmul(A.mH, Q) + if M is not None: + X = X - matmul(M.mH, Q) + Q = torch.linalg.qr(X).Q + X = matmul(A, Q) + if M is not None: + X = X - matmul(M, Q) + Q = torch.linalg.qr(X).Q + return Q + + +def svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Tensor]: + r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, + batches of matrices, or a sparse matrix :math:`A` such that + :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then + SVD is computed for the matrix :math:`A - M`. + + .. note:: The implementation is based on the Algorithm 5.1 from + Halko et al., 2009. + + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. + + .. note:: This is a randomized method. To obtain repeatable results, + set the seed for the pseudorandom number generator + + .. note:: In general, use the full-rank SVD implementation + :func:`torch.linalg.svd` for dense matrices due to its 10x + higher performance characteristics. The low-rank SVD + will be useful for huge sparse matrices that + :func:`torch.linalg.svd` cannot handle. + + Args:: + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int, optional): a slightly overestimated rank of A. + + niter (int, optional): the number of subspace iterations to + conduct; niter must be a nonnegative + integer, and defaults to 2 + + M (Tensor, optional): the input tensor's mean of size + :math:`(*, m, n)`, which will be broadcasted + to the size of A in this function. + + References:: + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). 
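+
+    Illustrative usage (editor's sketch, not part of the upstream docstring;
+    the sizes and rank below are arbitrary)::
+
+        >>> A = torch.randn(100, 20)
+        >>> U, S, V = torch.svd_lowrank(A, q=6)
+        >>> approx = U @ torch.diag(S) @ V.mH   # rank-6 approximation of A
+        >>> approx.shape
+        torch.Size([100, 20])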
+ + """ + if not torch.jit.is_scripting(): + tensor_ops = (A, M) + if not set(map(type, tensor_ops)).issubset( + (torch.Tensor, type(None)) + ) and has_torch_function(tensor_ops): + return handle_torch_function( + svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M + ) + return _svd_lowrank(A, q=q, niter=niter, M=M) + + +def _svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Tensor]: + # Algorithm 5.1 in Halko et al., 2009 + + q = 6 if q is None else q + m, n = A.shape[-2:] + matmul = _utils.matmul + if M is not None: + M = M.broadcast_to(A.size()) + + # Assume that A is tall + if m < n: + A = A.mH + if M is not None: + M = M.mH + + Q = get_approximate_basis(A, q, niter=niter, M=M) + B = matmul(Q.mH, A) + if M is not None: + B = B - matmul(Q.mH, M) + U, S, Vh = torch.linalg.svd(B, full_matrices=False) + V = Vh.mH + U = Q.matmul(U) + + if m < n: + U, V = V, U + + return U, S, V + + +def pca_lowrank( + A: Tensor, + q: Optional[int] = None, + center: bool = True, + niter: int = 2, +) -> tuple[Tensor, Tensor, Tensor]: + r"""Performs linear Principal Component Analysis (PCA) on a low-rank + matrix, batches of such matrices, or sparse matrix. + + This function returns a namedtuple ``(U, S, V)`` which is the + nearly optimal approximation of a singular value decomposition of + a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}` + + .. note:: The relation of ``(U, S, V)`` to PCA is as follows: + + - :math:`A` is a data matrix with ``m`` samples and + ``n`` features + + - the :math:`V` columns represent the principal directions + + - :math:`S ** 2 / (m - 1)` contains the eigenvalues of + :math:`A^T A / (m - 1)` which is the covariance of + ``A`` when ``center=True`` is provided. + + - ``matmul(A, V[:, :k])`` projects data to the first k + principal components + + .. note:: Different from the standard SVD, the size of returned + matrices depend on the specified rank and q + values as follows: + + - :math:`U` is m x q matrix + + - :math:`S` is q-vector + + - :math:`V` is n x q matrix + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + Args: + + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int, optional): a slightly overestimated rank of + :math:`A`. By default, ``q = min(6, m, + n)``. + + center (bool, optional): if True, center the input tensor, + otherwise, assume that the input is + centered. + + niter (int, optional): the number of subspace iterations to + conduct; niter must be a nonnegative + integer, and defaults to 2. + + References:: + + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). 
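+
+    Illustrative usage (editor's sketch, not part of the upstream docstring;
+    the sizes below are arbitrary)::
+
+        >>> A = torch.randn(500, 10)               # 500 samples, 10 features
+        >>> U, S, V = torch.pca_lowrank(A, q=3)
+        >>> scores = torch.matmul(A, V[:, :3])     # project onto 3 principal directions
+        >>> explained = S**2 / (A.shape[0] - 1)    # variance along each direction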
+ + """ + + if not torch.jit.is_scripting(): + if type(A) is not torch.Tensor and has_torch_function((A,)): + return handle_torch_function( + pca_lowrank, (A,), A, q=q, center=center, niter=niter + ) + + (m, n) = A.shape[-2:] + + if q is None: + q = min(6, m, n) + elif not (q >= 0 and q <= min(m, n)): + raise ValueError( + f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}" + ) + if not (niter >= 0): + raise ValueError(f"niter(={niter}) must be non-negative integer") + + dtype = _utils.get_floating_dtype(A) + + if not center: + return _svd_lowrank(A, q, niter=niter, M=None) + + if _utils.is_sparse(A): + if len(A.shape) != 2: + raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor") + c = torch.sparse.sum(A, dim=(-2,)) / m + # reshape c + column_indices = c.indices()[0] + indices = torch.zeros( + 2, + len(column_indices), + dtype=column_indices.dtype, + device=column_indices.device, + ) + indices[0] = column_indices + C_t = torch.sparse_coo_tensor( + indices, c.values(), (n, 1), dtype=dtype, device=A.device + ) + + ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) + M = torch.sparse.mm(C_t, ones_m1_t).mT + return _svd_lowrank(A, q, niter=niter, M=M) + else: + C = A.mean(dim=(-2,), keepdim=True) + return _svd_lowrank(A - C, q, niter=niter, M=None) diff --git a/phivenv/Lib/site-packages/torch/_meta_registrations.py b/phivenv/Lib/site-packages/torch/_meta_registrations.py new file mode 100644 index 0000000000000000000000000000000000000000..d5073ec1cdc03adb979972b70a98682318e69d43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_meta_registrations.py @@ -0,0 +1,8016 @@ +# mypy: allow-untyped-defs +import math +import operator +from collections.abc import Sequence +from enum import Enum +from functools import reduce, wraps +from typing import Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._prims_common as utils +from torch import SymBool, SymFloat, Tensor +from torch._decomp import ( + _add_op_to_registry, + _convert_out_params, + global_decomposition_table, + meta_table, +) +from torch._ops import OpOverload +from torch._prims import ( + _prim_elementwise_meta, + ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, + view_of, +) +from torch._prims_common import ( + BoolLike, + corresponding_complex_dtype, + corresponding_real_dtype, + definitely_contiguous, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + IntLike, + is_contiguous, + make_contiguous_strides_for, + Number, + suggest_memory_format, + TensorLike, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _resize_output_check, + _safe_copy_out, + out_wrapper, +) +from torch._refs import _broadcast_shapes, _maybe_broadcast +from torch.fx.experimental import _config as exp_config +from torch.utils import _pytree as pytree + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +aten = torch.ops.aten + +_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta") +MODE_SUM, MODE_MEAN, MODE_MAX = range(3) + + +def register_meta(op) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def wrapper(fn): + fn = _convert_out_params(fn) + + def register(op): + _add_op_to_registry(meta_table, op, fn) + + pytree.tree_map_(register, op) + return fn + + return wrapper + + +def elementwise_meta( + *args, + type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND, +): + # Perform type promotion, as this is expected from prim_metafunction + _, result_dtype 
= utils.elementwise_dtypes( + *args, + type_promotion_kind=type_promotion, + ) + args = [_maybe_convert_to_dtype(x, result_dtype) for x in args] + + # Broadcast + args = _maybe_broadcast(*args) + + # Perform prim checks + return _prim_elementwise_meta( + *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT + ) + + +def toRealValueType(dtype): + from_complex = { + torch.complex32: torch.half, + torch.cfloat: torch.float, + torch.cdouble: torch.double, + } + return from_complex.get(dtype, dtype) + + +def check_inplace_broadcast(self_shape, *args_shape): + broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape)) + torch._check( + broadcasted_shape == self_shape, + lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}", + ) + + +@register_meta([aten.linspace, aten.logspace]) +@out_wrapper() +def meta_linspace_logspace( + start, + end, + steps, + base=None, + dtype=None, + device=None, + layout=torch.strided, + pin_memory=False, + requires_grad=False, +): + if isinstance(start, torch.Tensor): + torch._check( + start.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + if isinstance(end, torch.Tensor): + torch._check( + end.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + + if any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + torch._check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + + # steps does not participate in the computation of the dtype + torch._check_type( + isinstance(steps, IntLike), + lambda: f"received an invalid combination of arguments - got \ +({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", + ) + assert isinstance(steps, IntLike) # for mypy + torch._check(steps >= 0, lambda: "number of steps must be non-negative") + + return torch.empty( + (steps,), # type: ignore[arg-type] + dtype=dtype, + layout=layout, + device="meta", + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_meta([aten.take.default, aten.take.out]) +@out_wrapper() +def meta_take(self, index): + # Type and device checks + torch._check( + index.dtype == torch.long, + lambda: f"take(): Expected a long tensor for index, but got {index.dtype}", + ) + # Index checks + torch._check_index( + not (self.numel() == 0 and index.numel() != 0), + lambda: "take(): tried to take from an empty tensor", + ) + return self.new_empty(index.shape) + + +@register_meta([aten.linalg_cross.default, aten.linalg_cross.out]) +@out_wrapper() +def linalg_cross(self, other, *, dim=-1): + x_d = self.ndim + y_d = other.ndim + torch._check( + x_d == y_d, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + self.size(dim) == 3 and other.size(dim) == 3, + lambda: ( + f"linalg.cross: inputs dimension {dim} must have length 3. 
" + f"Got {self.size(dim)} and {other.size(dim)}" + ), + ) + out_shape = _broadcast_shapes(self.shape, other.shape) + return self.new_empty(out_shape) + + +# This function is python match of computeStride_impl in TensorUtils.cpp +def _compute_stride(old_shape, old_stride, new_shape, size_oblivious=False): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + sym_eq, + ) + + def maybe_guard_or_false(x): + if size_oblivious: + return guard_or_false(x) + + return x + + def maybe_guard_or_true(x): + if size_oblivious: + return guard_or_true(x) + + return x + + if len(old_shape) == 0: + return [1] * len(new_shape) + + numel = reduce(operator.mul, old_shape, 1) + zero_numel = maybe_guard_or_false(numel == 0) + if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)): + return old_stride + + new_stride = [0] * len(new_shape) + + if zero_numel: + for view_d in range(len(new_shape) - 1, -1, -1): + if view_d == len(new_shape) - 1: + new_stride[view_d] = 1 + else: + new_stride[view_d] = ( + max(new_shape[view_d + 1], 1) * new_stride[view_d + 1] + ) + return new_stride + + view_d = len(new_shape) - 1 + chunk_base_stride = old_stride[-1] + tensor_numel = 1 + view_numel = 1 + + for tensor_d in range(len(old_shape) - 1, -1, -1): + tensor_numel *= old_shape[tensor_d] + + if tensor_d == 0 or ( + maybe_guard_or_true(old_shape[tensor_d - 1] != 1) + and maybe_guard_or_true( + old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride + ) + ): + while view_d >= 0 and ( + maybe_guard_or_true(view_numel < tensor_numel) + or maybe_guard_or_false(new_shape[view_d] == 1) + ): + new_stride[view_d] = view_numel * chunk_base_stride + view_numel *= new_shape[view_d] + view_d -= 1 + + if maybe_guard_or_true(view_numel != tensor_numel): + return None + + if tensor_d > 0: + chunk_base_stride = old_stride[tensor_d - 1] + tensor_numel = 1 + view_numel = 1 + if view_d != -1: + return None + return new_stride + + +def _view_has_unbacked_input(a, shape): + from torch.fx.experimental.symbolic_shapes import has_hint + + return ( + any(not has_hint(s) for s in a.size()) + or any(not has_hint(s) for s in a.stride()) + or any(not has_hint(s) for s in shape) + ) + + +def _view_unbacked_meta(a, shape, size_oblivious_enabled=True): + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq + + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + torch._check(length == 1) + _a = torch._refs.unsqueeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + torch._check(length == 1) + _a = torch._refs.squeeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + shape_numel = reduce(operator.mul, shape, 1) + + torch._check( + a.numel() == shape_numel, + lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", + ) + + if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)): + return view_of(a) + + if definitely_contiguous(a) if size_oblivious_enabled else is_contiguous(a): + strides = utils.make_contiguous_strides_for(shape) + return a.as_strided(shape, strides) + + new_strides = _compute_stride( + 
a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled + ) + + if new_strides is not None: + return a.as_strided(shape, new_strides) + + # If we fail to do size oblivious view, and backed_size_oblivious was on, + # then we redo everything by looking at hints and guarding instead of failing. + # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False + # to throw a data dependent error. + + if size_oblivious_enabled and ( + torch.fx.experimental._config.backed_size_oblivious + or _view_has_unbacked_input(a, shape) + ): + return _view_unbacked_meta(a, shape, size_oblivious_enabled=False) + + msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" + raise ValueError(msg) + + +@register_meta(aten.view.default) +def _view_meta(a, *shape): + if torch.fx.experimental._config.backed_size_oblivious or _view_has_unbacked_input( + a, shape + ): + return _view_unbacked_meta(a, shape) + else: + return torch._refs._reshape_view_helper(a, *shape, allow_copy=False) + + +@register_meta(aten.linalg_matrix_exp) +@out_wrapper() +def linalg_matrix_exp(self): + squareCheckInputs(self, "linalg.matrix_exp") + checkFloatingOrComplex(self, "linalg.matrix_exp") + return torch.empty_like(self, memory_format=torch.contiguous_format) + + +@register_meta( + [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out] +) +@out_wrapper("values", "indices") +def cummaxmin(self, dim): + values = torch.empty(self.shape, device=self.device, dtype=self.dtype) + indices = torch.empty(self.shape, device=self.device, dtype=torch.int64) + if self.numel() != 0 and self.ndim != 0: + # Checks that dim is within bounds + maybe_wrap_dim(dim, self.ndim) + return values, indices + + +@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out]) +@out_wrapper() +def logcumsumexp(self, dim): + # Checks that dim is within bounds + maybe_wrap_dim(dim, self.ndim) + return torch.empty_like(self, memory_format=torch.contiguous_format) + + +# Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp +# and aten/src/ATen/cuda/SpectralOps.cpp +# +# Although the actual FFT launch is different, all the permuting code appears +# to be the same +def _exec_fft(out, self, out_sizes, dim, *, forward): + ndim = self.ndim + signal_ndim = len(dim) + batch_dims = ndim - signal_ndim + + # Permute dimensions so batch dimensions come first, and in stride order + dim_permute = list(range(ndim)) + + is_transformed_dim = [False for _ in range(ndim)] + for d in dim: + is_transformed_dim[d] = True + + # std::partition + left, right = [], [] + for d in dim_permute: + if not is_transformed_dim[d]: + left.append(d) + else: + right.append(d) + dim_permute = left + right + batch_end = len(left) + + self_strides = self.stride() + tmp = dim_permute[:batch_end] + tmp.sort(key=lambda x: self_strides[x], reverse=True) + dim_permute = tmp + dim_permute[batch_end:] + input = self.permute(dim_permute) + + # Collapse batch dimensions into a single dimension + batched_sizes = [-1] + list(input.shape[batch_dims:]) + input = input.reshape(batched_sizes) + + batch_size = input.size(0) + batched_sizes[0] = batch_size + batched_out_sizes = list(batched_sizes) + for i in range(len(dim)): + batched_out_sizes[i + 1] = out_sizes[dim[i]] + out.resize_(batched_out_sizes, memory_format=torch.contiguous_format) + + # Inplace reshaping to original batch shape and inverting the dimension permutation + out_strides = [0 for _ in range(ndim)] + 
batch_numel = 1 + i = batch_dims - 1 + while i >= 0: + out_strides[dim_permute[i]] = batch_numel * out.stride(0) + batch_numel *= out_sizes[dim_permute[i]] + i -= 1 + for i in range(batch_dims, ndim): + out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims)) + out.as_strided_(out_sizes, out_strides, out.storage_offset()) + + return out + + +def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False): + sorted_dims = list(dim) + self_strides = self.stride() + sorted_dims[: len(sorted_dims) - int(exclude_last)].sort( + key=lambda i: self_strides[i] + ) + return sorted_dims + + +# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp +# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp +@register_meta([aten._fft_c2c.default, aten._fft_c2c.out]) +@out_wrapper() +def meta_fft_c2c(self, dim, normalization, forward): + torch._check(self.dtype.is_complex) + if not dim: + return self.clone() + + sorted_dims = _sort_dims(self, dim) + out = self.new_empty(self.size()) + return _exec_fft(out, self, self.size(), sorted_dims, forward=forward) + + +cufft_max_ndim = 3 + + +def use_optimized_cufft_path(dim: list[int]): + if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1): + return False + else: + return True + + +@register_meta([aten._fft_r2c.default, aten._fft_r2c.out]) +@out_wrapper() +def meta_fft_r2c(self, dim, normalization, onesided): + torch._check(self.dtype.is_floating_point) + input_sizes = list(self.size()) + out_sizes = list(input_sizes) + last_dim = dim[-1] + last_dim_halfsize = input_sizes[last_dim] // 2 + 1 + onesided_sizes = list(input_sizes) + onesided_sizes[last_dim] = last_dim_halfsize + + if onesided: + out_sizes[last_dim] = last_dim_halfsize + + if device_hint(self) == "cuda" or device_hint(self) == "xpu": + # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp + # _fft_r2c_xpu in torch-xpu-ops/src/ATen/native/xpu/SpectralOps.cpp + output = self.new_empty( + out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) + ) + + working_tensor = self + if device_hint(self) == "cuda" and use_optimized_cufft_path(dim): + _exec_fft(output, working_tensor, out_sizes, dim, forward=True) + else: + # First do the R2C transform on the last dimension + target_sizes = out_sizes if len(dim) == 1 else onesided_sizes + _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True) + if len(dim) > 1: + working_tensor = self.new_empty( + out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) + ) + + # Then any remaining C2C transforms + sorted_dims = dim[:-1] + while sorted_dims: + output, working_tensor = working_tensor, output + strides = working_tensor.stride() + sorted_dims.sort( + key=lambda i: strides[i], reverse=True + ) # NB reverse! 
Not sure if this is og bug + max_dims = min(cufft_max_ndim, len(sorted_dims)) + last_dims = sorted_dims[len(sorted_dims) - max_dims :] + _exec_fft( + output, working_tensor, onesided_sizes, last_dims, forward=True + ) + sorted_dims = sorted_dims[: len(sorted_dims) - max_dims] + + if not onesided: + if output.size(last_dim) != out_sizes[last_dim]: + working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format) + output = working_tensor + + return output + + else: + return self.new_empty( + out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) + ) + + +@register_meta(aten.randperm.generator_out) +def meta_randperm(n, *, generator=None, out): + return _maybe_resize_out(out, torch.Size([n])) + + +@register_meta(aten.randperm.default) +def meta_randperm_default( + n, + *, + dtype=torch.long, + layout=None, + device=None, + pin_memory=None, +): + return torch.empty( + n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta([aten.randint.default, aten.randint.out]) +@out_wrapper() +def meta_randint( + high, + size, + *, + dtype=torch.long, + layout=None, + device=None, + pin_memory=None, +): + low = 0 + torch._check( + high > low, + lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}", + ) + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta([aten.randint.low, aten.randint.low_out]) +@out_wrapper() +def meta_randint_low( + low, + high, + size, + *, + dtype=torch.long, + layout=None, + device=None, + pin_memory=None, +): + torch._check( + high > low, + lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}", + ) + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta([aten.rand.default, aten.rand.out]) +@out_wrapper() +def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta([aten._fft_c2r.default, aten._fft_c2r.out]) +@out_wrapper() +def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int): + # _fft_c2r_mkl + torch._check(self.dtype.is_complex) + + if device_hint(self) == "cuda": + out_sizes = list(self.size()) + out_sizes[dim[-1]] = lastdim + + output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype)) + + if use_optimized_cufft_path(dim): + return _exec_fft( + output, + self.clone(memory_format=torch.contiguous_format), + out_sizes, + dim, + forward=False, + ) + else: + # First complete any C2C transforms + if len(dim) > 1: + temp = meta_fft_c2c(self, dim[:-1], 0, lastdim) # fft_norm_mode::none + else: + temp = self.clone(memory_format=torch.contiguous_format) + return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False) + + else: + input = self + if len(dim) > 1: + c2c_dims = dim[:-1] + input = meta_fft_c2c(self, c2c_dims, normalization, forward=False) + dim = dim[-1:] + + out_sizes = list(input.size()) + out_sizes[dim[-1]] = lastdim + out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype)) + return _exec_fft(out, input, out_sizes, dim, forward=False) + + +@register_meta(aten.copy_.default) +def meta_copy_(self, src, non_blocking=False): + # This code simulates the original decomp from inductor, + # which runs most of the meta checks that we care about. 
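+ # Illustrative sketch (added, not from the original kernel): the size handling + # modelled here mirrors eager copy_, which broadcasts src up to self's shape, e.g. + #   dst = torch.empty(2, 3, device="meta"); src = torch.empty(3, device="meta") + #   dst.copy_(src)  # src is expanded to dst.size() before the copy + # Only the shape/broadcast behaviour is sketched; dtype/device conversion is + # handled by the src.to(self, non_blocking) call below.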
+ # In theory, we should make this more robust by carefully + # auditing our C++ copy_() kernel and copying the checks here. + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are + # calling an actual copy_, you'll get that automatically + # https://github.com/pytorch/pytorch/issues/122477 + if ( + not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1 + ): # 1 == MemOverlap::Yes + raise RuntimeError( + "more than one element of the written-to tensor refers to a single memory location" + ) + + if isinstance(src, Tensor): + intermediate = src.to(self, non_blocking) + if self.size() != intermediate.size(): + aten.expand_copy.default(intermediate, self.size()) + return self + + +def inferUnsqueezeGeometry(tensor, dim): + result_sizes = list(tensor.size()) + result_strides = list(tensor.stride()) + new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim] + result_sizes.insert(dim, 1) + result_strides.insert(dim, new_stride) + return result_sizes, result_strides + + +@register_meta(aten.unsqueeze_.default) +def meta_unsqueeze_(self, dim): + dim = maybe_wrap_dim(dim, self.dim() + 1) + g_sizes, g_strides = inferUnsqueezeGeometry(self, dim) + self.as_strided_(g_sizes, g_strides) + return self + + +@register_meta(aten._sparse_semi_structured_linear) +def meta_sparse_structured_linear( + input: Tensor, + weight: Tensor, + _meta: Tensor, + bias: Optional[Tensor] = None, + _activation_opt: Optional[str] = None, + out_dtype: Optional[torch.dtype] = None, +): + output_sizes = list(input.shape) + if bias is not None: + assert weight.size(0) == bias.size(0), "output size mismatch" + assert weight.size(1) == input.size(-1) / 2 + output_sizes[-1] = weight.size(0) + + # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375 + # We assume that we have already squashed the inputs into a 2-D tensor + # Then, as the output is transposed, we need to propagate the transposed + # stride information to the output tensor + assert len(input.shape) == 2, "we can only handle the squashed input case" + transposed_strides = (1, input.size(0)) + + if out_dtype is not None: + assert input.dtype == torch.int8 and out_dtype == torch.int32, ( + "out_dtype is only supported for i8i8->i32 linear operator" + ) + output = input.new_empty( + output_sizes, + dtype=input.dtype if out_dtype is None else out_dtype, + ).as_strided(output_sizes, transposed_strides) + + return output + + +@register_meta(aten._sparse_semi_structured_mm) +def meta_sparse_structured_mm( + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + out_dtype: Optional[torch.dtype] = None, +): + assert len(mat1.shape) == 2 + assert len(mat1_meta.shape) == 2 + assert len(mat2.shape) == 2 + assert mat1.size(1) == mat2.size(0) / 2 + output_sizes = [mat1.size(0), mat2.size(1)] + + if out_dtype is not None: + assert mat2.dtype == torch.int8 and out_dtype == torch.int32, ( + "out_dtype is only supported for i8i8->i32 linear operator" + ) + output = mat2.new_empty( + output_sizes, + dtype=mat2.dtype if out_dtype is None else out_dtype, + ) + + return output + + +@register_meta(aten._sparse_semi_structured_addmm) +def meta_sparse_structured_addmm( + input: Tensor, + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + *, + alpha=1, + beta=1, + out_dtype: Optional[torch.dtype] = None, +): + assert len(input.shape) == 1, ( + "only input broadcasted to columns of mat1 * mat2 product is supported" + ) + assert 
len(mat1.shape) == 2 + assert len(mat1_meta.shape) == 2 + assert len(mat2.shape) == 2 + assert input.size(0) == mat1.size(0), ( + "only input broadcasted to columns of mat1 * mat2 product is supported" + ) + assert mat1.size(1) == mat2.size(0) / 2 + output_sizes = [mat1.size(0), mat2.size(1)] + + if out_dtype is not None: + assert mat2.dtype == torch.int8 and out_dtype == torch.int32, ( + "out_dtype is only supported for i8i8->i32 linear operator" + ) + output = mat2.new_empty( + output_sizes, + dtype=mat2.dtype if out_dtype is None else out_dtype, + ) + + return output + + +@register_meta(aten._cslt_sparse_mm) +def meta__cslt_sparse_mm( + compressed_A: torch.Tensor, + dense_B: torch.Tensor, + bias: Optional[Tensor] = None, + alpha: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + transpose_result: bool = False, + alg_id: int = 0, + split_k: int = 1, + split_k_mode: int = -1, +): + assert dense_B.dtype in { + torch.float32, + torch.float16, + torch.bfloat16, + torch.int8, + torch.float8_e4m3fn, + }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3" + assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype" + assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs" + + is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn] + compression_factor = 10 if is_8bit_input_type else 9 + + if is_8bit_input_type: + assert not dense_B.is_contiguous(), ( + "dense input must be transposed for 8bit dtypes" + ) + + k = dense_B.size(0) + n = dense_B.size(1) + m = (compressed_A.numel() * 16) // (compression_factor * k) + if bias is not None: + assert m == bias.size(0) + + if out_dtype is not None: + assert is_8bit_input_type and out_dtype in { + torch.float16, + torch.bfloat16, + torch.int32, + torch.float8_e4m3fn, + }, ( + "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!" + ) + output_shape = (n, m) if transpose_result else (m, n) + return dense_B.new_empty(output_shape, dtype=out_dtype) + + +@register_meta(aten.index_reduce.default) +def meta_index_reduce( + self: Tensor, + dim: int, + index: Tensor, + source: torch.Tensor, + reduce: str, + *, + include_self: bool = True, +) -> Tensor: + return torch.empty_like(self, memory_format=torch.contiguous_format) + + +@register_meta(aten.index_reduce_.default) +def meta_index_reduce_( + self: Tensor, + dim: int, + index: Tensor, + source: torch.Tensor, + reduce: str, + *, + include_self: bool = True, +) -> Tensor: + return self + + +# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py +@out_wrapper() +@register_meta(aten.index_select.default) +def meta_index_select(self, dim, index): + result_size = list(self.size()) + if self.dim() > 0: + result_size[dim] = index.numel() + return self.new_empty(result_size) + + +@register_meta(aten.segment_reduce.default) +def meta_segment_reduce( + data: Tensor, + reduce: str, + *, + lengths: Optional[Tensor] = None, + indices: Optional[Tensor] = None, + offsets: Optional[Tensor] = None, + axis: int = 0, + unsafe: bool = False, + initial=None, +) -> Tensor: + if indices is not None: + raise NotImplementedError( + "segment_reduce(): indices based reduction is not supported yet." 
+ ) + + def segment_reduce_lengths_tensor(lengths_shape): + return torch.empty( + lengths_shape + data.shape[axis + 1 :], + dtype=data.dtype, + device="meta", + memory_format=torch.contiguous_format, + ) + + if lengths is not None: + return segment_reduce_lengths_tensor(lengths.shape) + # FIXME should probably check that lengths and offset aren't both set, but + # the ATen implementation neglects this too + if offsets is not None: + # lengths == torch.diff(offsets) + lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,) + return segment_reduce_lengths_tensor(lengths_shape) + raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.") + + +@register_meta([aten.max.default, aten.max.unary_out]) +@out_wrapper() +def meta_max(self): + return self.new_empty(()) + + +@register_meta(aten.max.dim) +def meta_max_dim(self, dim, keepdim=False): + dim = utils.reduction_dims(self.shape, (dim,)) + output_shape = _compute_reduction_shape(self, dim, keepdim) + return ( + self.new_empty(output_shape), + self.new_empty(output_shape, dtype=torch.long), + ) + + +@register_meta([aten.min.default, aten.min.unary_out]) +@out_wrapper() +def meta_min(self): + return self.new_empty(()) + + +@register_meta(aten.min.dim) +def meta_min_dim(self, dim, keepdim=False): + dim = utils.reduction_dims(self.shape, (dim,)) + output_shape = _compute_reduction_shape(self, dim, keepdim) + return ( + self.new_empty(output_shape), + self.new_empty(output_shape, dtype=torch.long), + ) + + +@register_meta(aten.angle.default) +def meta_angle(self): + if self.is_complex(): + result_dtype = corresponding_real_dtype(self.dtype) + else: + _, result_dtype = elementwise_dtypes( + self, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + return torch.empty_like(self, dtype=result_dtype) + + +@register_meta(aten.angle.out) +def meta_angle_out(self, out): + torch._resize_output_(out, self.size(), self.device) + return out.copy_(torch.angle(self)) + + +@register_meta(aten._assert_async.default) +def assert_async(val): + return + + +@register_meta(aten._assert_async.msg) +def assert_async_meta(val, assert_msg): + return + + +@register_meta(aten._print.default) +def print_meta(s): + return + + +@register_meta(aten._make_dep_token.default) +def make_dep_token( + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + return torch.empty(0, device="meta") + + +@register_meta(aten.sym_constrain_range.default) +def sym_constrain_range(size, min=None, max=None): + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import constrain_range + + if isinstance(size, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat or Symbool is nyi") + constrain_range(size, min=min, max=max) + + +@register_meta(aten._functional_sym_constrain_range.default) +def functional_sym_constrain_range(size, min=None, max=None, dep_token=None): + aten.sym_constrain_range(size, min=min, max=max) + return dep_token + + +@register_meta(aten.sym_constrain_range_for_size.default) +def sym_constrain_range_for_size(size, min=None, max=None): + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + if min is None and max is None: + torch._check_is_size(size) + return + + if isinstance(size, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat or Symbool is nyi") + if type(size) is int: + if min is not None: + torch._check(size >= min) + if max is not None: + torch._check(size 
<= max) + return + _constrain_range_for_size(size, min=min, max=max) + + +@register_meta(aten._functional_sym_constrain_range_for_size.default) +def functional_sym_constrain_range_for_size(size, min, max, dep_token): + aten.sym_constrain_range_for_size(size, min=min, max=max) + return dep_token + + +@register_meta(aten._functional_assert_async.msg) +def functional_assert_async_meta(val, assert_msg, dep_token): + return dep_token + + +# From aten/src/ATen/native/LinearAlgebraUtils.h +def squareCheckInputs(self: Tensor, f_name: str): + assert self.dim() >= 2, ( + f"{f_name}: The input tensor must have at least 2 dimensions." + ) + assert self.size(-1) == self.size(-2), ( + f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices" + ) + + +# Validates input shapes and devices +# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve) +# From aten/src/ATen/native/LinearAlgebraUtils.h +def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str): + torch._check( + self.device == A.device, + lambda: ( + f"Expected b and A to be on the same device, but found b on " + f"{self.device} and A on {A.device} instead." + ), + ) + + torch._check( + self.dtype == A.dtype, + lambda: ( + f"Expected b and A to have the same dtype, but found b of type " + f"{self.dtype} and A of type {A.dtype} instead." + ), + ) + + torch._check( + A.size(-1) == A.size(-2), + lambda: ( + f"A must be batches of square matrices, " + f"but they are {A.size(-2)} by {A.size(-1)} matrices" + ), + ) + + torch._check( + A.size(-1) == self.size(-2), + lambda: ( + f"Incompatible matrix sizes for {name}: each A " + f"matrix is {A.size(-1)} by {A.size(-1)}" + f" but each b matrix is {self.size(-2)} by {self.size(-1)}" + ), + ) + + +# From aten/src/ATen/native/LinearAlgebraUtils.h +def checkFloatingOrComplex( + t: Tensor, + f_name: str, + allow_low_precision_dtypes: bool = True, +): + dtype = t.dtype + torch._check( + t.is_floating_point() or t.is_complex(), + lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}", + ) + if not allow_low_precision_dtypes: + torch._check( + dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble), + lambda: f"{f_name}: Low precision dtypes not supported. 
Got {dtype}", + ) + + +# From aten/src/ATen/native/LinearAlgebraUtils.h +def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"): + torch._check( + A.dim() >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + +def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str): + squareCheckInputs(A, f_name) + checkIsMatrix(B, f_name) + torch._check( + A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1), + lambda: ( + f"{f_name}: Incompatible shapes of A and B for the equation " + f"{'AX = B' if left else 'XA = B'}" + f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})" + ), + ) + + +def checkSameDevice( + fn_name: str, + result: Tensor, + input: Tensor, + result_name: str = "result", +): + torch._check( + result.device == input.device, + lambda: ( + f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got " + f"{result_name} on {result.device} and input on {input.device}" + ), + ) + + +def checkUplo(UPLO: str): + UPLO_uppercase = UPLO.upper() + torch._check( + len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"), + lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}", + ) + + +@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues]) +@out_wrapper("eigenvalues", "eigenvectors") +def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True): + squareCheckInputs(A, "linalg.eigh") + checkUplo(UPLO) + + shape = list(A.shape) + if compute_v: + vecs = A.new_empty(shape) + vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False)) + else: + vecs = A.new_empty([0]) + + shape.pop() + vals = A.new_empty(shape, dtype=toRealValueType(A.dtype)) + + return vals, vecs + + +@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out]) +@out_wrapper() +def meta__linalg_eigvals(input: Tensor) -> Tensor: + squareCheckInputs(input, "linalg.eigvals") + complex_dtype = ( + input.dtype + if utils.is_complex_dtype(input.dtype) + else utils.corresponding_complex_dtype(input.dtype) + ) + return input.new_empty(input.shape[:-1], dtype=complex_dtype) + + +@register_meta([aten.linalg_eig]) +@out_wrapper("eigenvalues", "eigenvectors") +def meta_linalg_eig(input: Tensor): + squareCheckInputs(input, "linalg.eig") + complex_dtype = ( + input.dtype + if utils.is_complex_dtype(input.dtype) + else utils.corresponding_complex_dtype(input.dtype) + ) + values = input.new_empty(input.shape[:-1], dtype=complex_dtype) + vectors = input.new_empty(input.shape, dtype=complex_dtype) + return values, vectors + + +def cloneBatchedColumnMajor(src: Tensor) -> Tensor: + return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1) + + +@register_meta(aten._cholesky_solve_helper) +@out_wrapper() +def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor: + return cloneBatchedColumnMajor(self) + + +@register_meta(aten.cholesky_solve) +@out_wrapper() +def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor: + torch._check( + self.ndim >= 2, + lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead", + ) + torch._check( + A.ndim >= 2, + lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead", + ) + self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name( + self, A, "cholesky_solve" + ) + return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper) + + +@register_meta(aten.cholesky) +@out_wrapper() +def cholesky(self: Tensor, 
upper: bool = False) -> Tensor: + if self.numel() == 0: + return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) + squareCheckInputs(self, "cholesky") + return cloneBatchedColumnMajor(self) + + +@register_meta(aten.cholesky_inverse) +@out_wrapper() +def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor: + squareCheckInputs(self, "cholesky_inverse") + return cloneBatchedColumnMajor(self) + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_cholesky_ex.default) +def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False): + squareCheckInputs(A, "linalg.cholesky") + checkFloatingOrComplex(A, "linalg.cholesky") + + A_shape = A.shape + ndim = len(A_shape) + + # L + L_strides = make_contiguous_strides_for(A_shape, False) + L = A.new_empty(A_shape) + L.as_strided_(A_shape, L_strides) + + # infos + infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32) + return L, infos + + +@register_meta( + [aten.linalg_householder_product.default, aten.linalg_householder_product.out] +) +@out_wrapper() +def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor: + torch._check( + input.ndim >= 2, + lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.", + ) + torch._check( + input.size(-2) >= input.size(-1), + lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]", + ) + torch._check( + input.size(-1) >= tau.size(-1), + lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]", + ) + + torch._check( + input.ndim - tau.ndim == 1, + lambda: ( + f"torch.linalg.householder_product: Expected tau to have one dimension less than input, " + f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}" + ), + ) + if input.ndim > 2: + expected_batch_tau_shape = input.shape[:-2] + actual_batch_tau_shape = tau.shape[:-1] + torch._check( + actual_batch_tau_shape == expected_batch_tau_shape, + lambda: ( + f"torch.linalg.householder_product: Expected batch dimensions of tau to be " + f"equal to input.shape[:-2], but got {actual_batch_tau_shape}" + ), + ) + + torch._check( + tau.dtype == input.dtype, + lambda: ( + f"torch.linalg.householder_product: tau dtype {tau.dtype}" + f" does not match input dtype {input.dtype}" + ), + ) + checkSameDevice("torch.linalg.householder_product", tau, input, "tau") + + return torch.empty_strided( + size=input.shape, + stride=make_contiguous_strides_for(input.shape, row_major=False), + dtype=input.dtype, + device=input.device, + ) + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_inv_ex.default) +def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False): + squareCheckInputs(A, "linalg.inv_ex") + checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False) + + L = A.new_empty(A.shape) + L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False)) + + infos = A.new_empty(A.shape[:-2], dtype=torch.int32) + return L, infos + + +@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out]) +@out_wrapper("LD", "pivots", "info") +def linalg_ldl_factor_ex_meta( + self: Tensor, + *, + hermitian: bool = False, + check_errors: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + squareCheckInputs(self, "torch.linalg.ldl_factor_ex") + checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex") + LD = torch.empty_strided( + size=self.shape, + 
stride=make_contiguous_strides_for(self.shape, row_major=False), + dtype=self.dtype, + device=self.device, + ) + pivots = self.new_empty(self.shape[:-1], dtype=torch.int) + info = self.new_empty(self.shape[:-2], dtype=torch.int) + return LD, pivots, info + + +@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out]) +@out_wrapper() +def linalg_ldl_solve_meta( + LD: Tensor, + pivots: Tensor, + B: Tensor, + *, + hermitian: bool = False, +) -> Tensor: + squareCheckInputs(LD, "torch.linalg.ldl_solve") + checkFloatingOrComplex(LD, "torch.linalg.ldl_solve") + linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve") + torch._check( + B.ndim >= 2, + lambda: ( + f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, " + f"but it has {B.ndim} dimensions instead" + ), + ) + expected_pivots_shape = LD.shape[:-1] + torch._check( + expected_pivots_shape == pivots.shape, + lambda: ( + f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, " + f"but got pivots with shape {pivots.shape} instead" + ), + ) + torch._check( + utils.is_integer_dtype(pivots.dtype), + lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}", + ) + torch._check( + LD.dtype == B.dtype, + lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}", + ) + B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD) + return torch.empty_strided( + size=B_broadcast_size, + stride=make_contiguous_strides_for(B_broadcast_size, row_major=False), + dtype=B.dtype, + device=B.device, + ) + + +@register_meta([aten.linalg_lu.default, aten.linalg_lu.out]) +@out_wrapper("P", "L", "U") +def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> tuple[Tensor, Tensor, Tensor]: + torch._check( + A.ndim >= 2, + lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead", + ) + + sizes = list(A.shape) + m = sizes[-2] + n = sizes[-1] + k = min(m, n) + + sizes[-1] = m + if pivot: + P = A.new_empty(sizes) + else: + P = A.new_empty([0]) + + sizes[-1] = k + L = A.new_empty(sizes) + + sizes[-2] = k + sizes[-1] = n + U = A.new_empty(sizes) + return P, L, U + + +@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out]) +@out_wrapper("LU", "pivots", "info") +def linalg_lu_factor_ex_meta( + A: Tensor, + *, + pivot: bool = True, + check_errors: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + torch._check( + A.ndim >= 2, + lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. 
Got size: {A.shape} instead", + ) + + sizes = list(A.shape) + m = sizes[-2] + n = sizes[-1] + + LU = torch.empty_strided( + size=sizes, + stride=make_contiguous_strides_for(sizes, row_major=False), + dtype=A.dtype, + device=A.device, + ) + + # Sets sizes to the size of pivots + sizes.pop() + sizes[-1] = min(m, n) + pivots = A.new_empty(sizes, dtype=torch.int) + + # Sets sizes to the size of info + sizes.pop() + info = A.new_empty(sizes, dtype=torch.int) + + return LU, pivots, info + + +@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out]) +@out_wrapper() +def linalg_lu_solve_meta( + LU: Tensor, + pivots: Tensor, + B: Tensor, + *, + left: bool = True, + adjoint: bool = False, +) -> Tensor: + # dtype + checkFloatingOrComplex(LU, "torch.linalg.lu_solve") + torch._check( + LU.dtype == B.dtype, + lambda: ( + f"linalg.lu_solve: Expected LU and B to have the same dtype, " + f"but found LU of type {LU.dtype} and B of type {B.dtype} instead" + ), + ) + torch._check( + pivots.dtype == torch.int, + lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32", + ) + + # matrix shapes + squareCheckInputs(LU, "torch.linalg.lu_solve") + checkInputsSolver(LU, B, left, "linalg.lu_solve") + torch._check( + LU.size(-1) == pivots.size(-1), + lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix", + ) + + # batches + torch._check( + LU.shape[:-1] == pivots.shape, + lambda: ( + f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, " + f"but got pivots with shape {pivots.shape} instead" + ), + ) + + B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU) + + result = torch.empty_strided( + size=B_broadcast_size, + stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left), + dtype=B.dtype, + device=B.device, + ) + + if result.numel() != 0 and not left: + if result.is_complex(): + result = result.conj() + + return result + + +@register_meta(aten.lu_unpack) +@out_wrapper("P", "L", "U") +def lu_unpack_meta( + LU: Tensor, + pivots: Tensor, + unpack_data: bool = True, + unpack_pivots: bool = True, +) -> tuple[Tensor, Tensor, Tensor]: + torch._check( + LU.ndim >= 2, + lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. 
Got size: {LU.shape} instead", + ) + if unpack_pivots: + torch._check( + pivots.dtype == torch.int32, + lambda: ( + "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n" + "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor" + ), + ) + sizes = list(LU.shape) + m = sizes[-2] + n = sizes[-1] + k = min(m, n) + sizes[-1] = m + if unpack_pivots: + P = LU.new_empty(sizes) + else: + P = LU.new_empty([0]) + if unpack_data: + sizes[-1] = k + L = LU.new_empty(sizes) + sizes[-2] = k + sizes[-1] = n + U = LU.new_empty(sizes) + else: + L = LU.new_empty([0]) + U = LU.new_empty([0]) + return P, L, U + + +# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) +def _parse_qr_mode(mode: str) -> tuple[bool, bool]: + if mode == "reduced": + compute_q = True + reduced = True + elif mode == "complete": + compute_q = True + reduced = False + elif mode == "r": + compute_q = False + reduced = True # this is actually irrelevant in this mode + else: + torch._check( + False, + lambda: ( + f"qr received unrecognized mode '{mode}' " + f"but expected one of 'reduced' (default), 'r', or 'complete'" + ), + ) + return compute_q, reduced # type: ignore[possibly-undefined] + + +@register_meta([aten.linalg_qr.default, aten.linalg_qr.out]) +@out_wrapper("Q", "R") +def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> tuple[Tensor, Tensor]: + checkIsMatrix(A, "linalg.qr") + checkFloatingOrComplex(A, "linalg.qr") + + compute_q, reduced_mode = _parse_qr_mode(mode) + + m = A.shape[-2] + n = A.shape[-1] + k = min(m, n) + + if compute_q: + Q_shape = list(A.shape) + Q_shape[-1] = k if reduced_mode else m + Q = A.new_empty(Q_shape) + Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False)) + else: + Q = A.new_empty([0]) + + # For readability + R_shape = list(A.shape) + R_shape[-2] = k if reduced_mode or not compute_q else m + R = A.new_empty(R_shape) + R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False)) + return Q, R + + +@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign]) +@out_wrapper("sign", "logabsdet", "LU", "pivots") +def _linalg_slogdet(A: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + squareCheckInputs(A, "linalg.slogdet") + checkFloatingOrComplex(A, "linalg.slogdet", False) + shape = A.shape + sign = A.new_empty(shape[:-2]) + logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype)) + LU = torch.empty_strided( + size=shape, + stride=make_contiguous_strides_for(shape, False), + dtype=A.dtype, + device=A.device, + ) + pivots = A.new_empty(shape[:-1], dtype=torch.int32) + return sign, logabsdet, LU, pivots + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml +@register_meta(aten._linalg_svd.default) +def _linalg_svd_meta( + A: Tensor, + full_matrices: bool = False, + compute_uv: bool = True, + driver: Optional[str] = None, +): + checkIsMatrix(A, "linalg.svd") + checkFloatingOrComplex(A, "linalg.svd") + + batch_dims = list(A.shape[:-2]) + m = A.shape[-2] + n = A.shape[-1] + k = min(m, n) + + if compute_uv: + U_shape = batch_dims + [m, m if full_matrices else k] + U = A.new_empty(U_shape) + U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False)) + + V_shape = batch_dims + [n if full_matrices else k, n] + V = A.new_empty(V_shape) + # NB: This checks for CUDA since there is no way to check for cuSolver. 
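+ # Hedged clarification (added): the backend distinction matters here because the + # two paths are modelled with different V layouts, the CUDA (cuSolver) path gets + # row-major (C-contiguous) strides while the CPU/LAPACK path gets column-major + # (F-contiguous) strides, which is why row_major=is_cuda is passed below.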
+ # Also, this might not work correctly on CPU when fake_device is not + # available as device_hint just defaults to CUDA in that case. See + # _linalg_svd meta in core. + is_cuda = device_hint(A) == "cuda" + V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda)) + else: + # doesn't matter + U = A.new_empty([0]) + V = A.new_empty([0]) + + # S is always real, even when A is complex. + S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype)) + return U, S, V + + +def _linalg_broadcast_batch_dims( + arg1: Tensor, + arg2: Tensor, +) -> tuple[list[int], list[int]]: + # broadcast the batch dimensions of arg1 and arg2. + arg1_batch_sizes = arg1.shape[:-2] + arg2_batch_sizes = arg2.shape[:-2] + expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes) + + arg1_expand_size = list(expand_batch_portion) + arg1_expand_size += [arg1.size(-2), arg1.size(-1)] + + arg2_expand_size = list(expand_batch_portion) + arg2_expand_size += [arg2.size(-2), arg2.size(-1)] + return arg1_expand_size, arg2_expand_size + + +def _linalg_broadcast_batch_dims_name( + arg1: Tensor, + arg2: Tensor, + name: Optional[str], +) -> tuple[Tensor, Tensor]: + # If there's no name we assume we don't want to check the errors + if name: + linearSolveCheckInputs(arg1, arg2, name) + + arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2) + + arg1_broadcasted = ( + arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size) + ) + arg2_broadcasted = ( + arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size) + ) + return arg1_broadcasted, arg2_broadcasted + + +def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool: + expected_batched_rhs_shape = input.shape[:-1] + vector_case = other.ndim == 1 or ( + input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape + ) + return vector_case + + +@register_meta(aten._linalg_solve_ex) +def _linalg_solve_ex( + A: Tensor, + B: Tensor, + *, + left: bool = True, + check_errors: bool = False, + result: Optional[Tensor] = None, + LU: Optional[Tensor] = None, + pivots: Optional[Tensor] = None, + info: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + checkFloatingOrComplex(A, "linalg.solve") + torch._check( + A.dtype == B.dtype, + lambda: ( + f"linalg.solve: Expected A and B to have the same dtype, but found A of type " + f"{A.dtype} and B of type {B.dtype} instead" + ), + ) + vector_case = linalg_solve_is_vector_rhs(A, B) + B_ = B.unsqueeze(-1) if vector_case else B + checkInputsSolver(A, B_, left, "linalg.solve") + B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A) + torch._check( + left or not vector_case, + lambda: ( + "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. 
" + "In this case linalg.solve is equivalent to B / A.squeeze(-1)" + ), + ) + result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape + result_ = torch.empty_strided( + size=result_shape, + stride=make_contiguous_strides_for(result_shape, not left), + dtype=B.dtype, + device=B.device, + ) + shape = A.shape + LU_ = torch.empty_strided( + size=shape, + stride=make_contiguous_strides_for(shape, False), + dtype=A.dtype, + device=A.device, + ) + pivots_ = A.new_empty(shape[:-1], dtype=torch.int32) + info_ = A.new_empty(shape[:-2], dtype=torch.int32) + out = (result, LU, pivots, info) + res = (result_, LU_, pivots_, info_) + if all(x is not None for x in out): + for r, o in zip(res, out): + # resize and copy operations are done in-place + _maybe_resize_out(o, r.shape) # type: ignore[arg-type] + # strides are not copied in out_wrapper + o.as_strided_(r.shape, r.stride()) # type: ignore[union-attr] + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False) # type: ignore[arg-type] + return res + + +@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out]) +def linalg_solve_triangular_meta( + A: Tensor, + B: Tensor, + *, + upper: bool, + left: bool = True, + unitriangular: bool = False, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + out = A.new_empty([0]) + assert isinstance(out, TensorLike) + checkInputsSolver(A, B, left, "linalg.solve_triangular") + B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None) + avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj() + if avoid_copy_A: + out = _maybe_resize_out(out, B_.shape) + else: + # reimplementation of resize_output with result F-contig + if _resize_output_check(out, B_.shape): + out.resize_(B_.transpose(-2, -1).shape) + out.transpose_(-2, -1) + return out # type: ignore[return-value] + + +@register_meta(aten.triangular_solve) +@out_wrapper("X", "M", exact_dtype=True) +def triangular_solve_meta( + self: Tensor, + A: Tensor, + upper: bool = True, + transpose: bool = False, + unitriangular: bool = False, +) -> tuple[Tensor, Tensor]: + torch._check( + self.ndim >= 2, + lambda: ( + f"torch.triangular_solve: Expected b to have at least 2 dimensions, " + f"but it has {self.ndim} dimensions instead" + ), + ) + torch._check( + A.ndim >= 2, + lambda: ( + f"torch.triangular_solve: Expected A to have at least 2 dimensions, " + f"but it has {A.ndim} dimensions instead" + ), + ) + + linearSolveCheckInputs(self, A, "triangular_solve") + + if A.layout == torch.strided: + self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A) + solution = torch.empty_strided( + size=self_broadcast_size, + stride=make_contiguous_strides_for(self_broadcast_size, row_major=False), + dtype=self.dtype, + device=self.device, + ) + cloned_coefficient = torch.empty_strided( + size=A_broadcast_size, + stride=make_contiguous_strides_for(A_broadcast_size, row_major=False), + dtype=A.dtype, + device=A.device, + ) + elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr: + solution = torch.empty_like(self) + cloned_coefficient = self.new_empty([0]) + else: + torch._check(False, lambda: "triangular_solve: Got an unexpected layout.") + return solution, cloned_coefficient # type: ignore[possibly-undefined] + + +# From aten/src/ATen/native/LinearAlgebra.cpp +@register_meta(aten._linalg_det.default) +def _linalg_det_meta(A): + squareCheckInputs(A, "linalg.det") + checkFloatingOrComplex(A, "linalg.det") + + det = A.new_empty(A.shape[:-2]) + + LU = A.new_empty(A.shape) + 
LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False)) + + pivots = A.new_empty(A.shape[:-1], dtype=torch.int32) + return det, LU, pivots + + +@register_meta(aten.ormqr) +@out_wrapper() +def ormqr( + input: Tensor, + tau: Tensor, + other: Tensor, + left: bool = True, + transpose: bool = False, +) -> Tensor: + torch._check( + input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions." + ) + torch._check( + other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions." + ) + + left_size_condition = -2 if left else -1 + torch._check( + other.shape[left_size_condition] >= tau.shape[-1], + lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]", + ) + torch._check( + other.shape[left_size_condition] == input.shape[-2], + lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]", + ) + + torch._check( + tau.shape[-1] <= input.shape[-1], + lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]", + ) + + torch._check( + input.ndim - tau.ndim == 1, + lambda: ( + f"torch.ormqr: Expected tau to have one dimension less than input, " + f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}" + ), + ) + torch._check( + input.ndim == other.ndim, + lambda: ( + f"torch.ormqr: Expected other to have the same number of dimensions as input, " + f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}" + ), + ) + + if input.ndim > 2: + expected_batch_shape = input.shape[:-2] + actual_batch_tau_shape = tau.shape[:-1] + torch._check( + actual_batch_tau_shape == expected_batch_shape, + lambda: ( + f"torch.ormqr: Expected batch dimensions of tau to be " + f"equal to input.shape[:-2], but got {actual_batch_tau_shape}" + ), + ) + + actual_batch_other_shape = other.shape[:-2] + torch._check( + actual_batch_other_shape == expected_batch_shape, + lambda: ( + f"torch.ormqr: Expected batch dimensions of other to be " + f"equal to input.shape[:-2], but got {actual_batch_other_shape}" + ), + ) + + torch._check( + tau.dtype == input.dtype, + lambda: ( + f"torch.ormqr: Expected input and tau to have the same dtype, " + f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}" + ), + ) + torch._check( + other.dtype == input.dtype, + lambda: ( + f"torch.ormqr: Expected input and other to have the same dtype, " + f"but input has dtype {input.dtype} and other has dtype {other.dtype}" + ), + ) + + checkSameDevice("torch.ormqr", tau, input, "tau") + checkSameDevice("torch.ormqr", other, input, "other") + + return torch.empty_strided( + size=other.shape, + stride=make_contiguous_strides_for(other.shape, row_major=False), + dtype=other.dtype, + device=other.device, + ) + + +def _padding_check_valid_input(input, padding, *, dim): + torch._check( + len(padding) == 2 * dim, + lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}", + ) + + input_dim = input.ndim + + is_batch_mode = input_dim == (dim + 2) + + valid_batch_mode = is_batch_mode + valid_non_batch_mode = not is_batch_mode + + if is_batch_mode: + # allow batch size of 0-dim. + for d in range(1, input_dim): + valid_batch_mode = valid_batch_mode and input.size(d) != 0 + else: + for d in range(0, input_dim): + valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0 + + # allow empty batch size but not other dimensions. 
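+ # Illustrative shapes (added): for dim=2 a 4-D batch-mode input of shape + # (0, 3, 4, 5) passes because only the leading batch dimension is zero, while + # (3, 0, 4, 5), or a 3-D input with any zero-sized dimension, fails the check below.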
+ torch._check( + valid_batch_mode or valid_non_batch_mode, + lambda: ( + f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size " + f"and other non-zero dimensions for input, but got: {input.shape}" + ), + ) + + +def _pad1d_common(input, padding, *, is_reflection): + dim_plane = 0 + dim_w = 1 + nbatch = 1 + + if input.ndim == 3: + nbatch = input.size(0) + dim_w += 1 + dim_plane += 1 + + _padding_check_valid_input(input, padding, dim=1) + + pad_l, pad_r = padding + + nplane = input.size(dim_plane) + input_w = input.size(dim_w) + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + + torch._check( + output_w >= 1, + lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}", + ) + + if input.ndim == 2: + return input.new_empty((nplane, output_w)) + else: + return input.new_empty((nbatch, nplane, output_w)) + + +@register_meta(aten.reflection_pad1d) +@out_wrapper() +def meta_reflection_pad1d(input, padding): + return _pad1d_common(input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad1d) +@out_wrapper() +def meta_replication_pad1d(input, padding): + torch._check( + input.dtype != torch.bool, + lambda: f""""replication_pad1d" not implemented for '{input.dtype.__str__()}'""", + ) + return _pad1d_common(input, padding, is_reflection=False) + + +def _pad1d_backward_common(grad_output, input, padding, *, is_reflection): + dim_w = 1 + if not is_reflection: + torch._check(len(padding) == 2, lambda: "padding size is expected to be 2") + + if input.ndim == 3: + dim_w += 1 + + pad_l, pad_r = padding + + input_w = input.size(dim_w) + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + + torch._check( + output_w == grad_output.size(dim_w), + lambda: f"grad_output width unexpected. 
Expected: {output_w}, Got: {grad_output.size(dim_w)}", + ) + + return input.new_empty(input.shape) + + +@register_meta(aten.reflection_pad1d_backward) +@out_wrapper("grad_input") +def meta_reflection_pad1d_backward(grad_output, input, padding): + return _pad1d_backward_common(grad_output, input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad1d_backward) +@out_wrapper("grad_input") +def meta_replication_pad1d_backward(grad_output, input, padding): + return _pad1d_backward_common(grad_output, input, padding, is_reflection=False) + + +def _pad2d_common(input, padding, *, is_reflection): + dim_w = 2 + dim_h = 1 + dim_slices = 0 + nbatch = 1 + + _padding_check_valid_input(input, padding, dim=2) + + ndim = input.ndim + if ndim == 4: + nbatch = input.size(0) + dim_w += 1 + dim_h += 1 + dim_slices += 1 + + pad_l, pad_r, pad_t, pad_b = padding + + nplane = input.size(dim_slices) + input_h = input.size(dim_h) + input_w = input.size(dim_w) + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + torch._check( + pad_t < input_h and pad_b < input_h, + lambda: ( + f"Argument #6: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}" + ), + ) + + torch._check( + output_w >= 1 or output_h >= 1, + lambda: ( + f"input (H: {input_h} W: {input_w}) is too small. " + f"Calculated output H: {output_h} W: {output_w}" + ), + ) + + if input.ndim == 3: + return input.new_empty((nplane, output_h, output_w)) + else: + return input.new_empty((nbatch, nplane, output_h, output_w)) + + +@register_meta(aten.reflection_pad2d) +@out_wrapper() +def meta_reflection_pad2d(input, padding): + return _pad2d_common(input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad2d) +@out_wrapper() +def meta_replication_pad2d(input, padding): + torch._check( + input.dtype != torch.bool, + lambda: f""""replication_pad2d" not implemented for '{input.dtype.__str__()}'""", + ) + return _pad2d_common(input, padding, is_reflection=False) + + +@register_meta( + [ + aten.reflection_pad2d_backward.default, + aten.reflection_pad2d_backward.grad_input, + aten.replication_pad2d_backward.default, + aten.replication_pad2d_backward.grad_input, + ] +) +@out_wrapper("grad_input") +def meta_pad2d_backward(grad_output, self, padding): + dim_w = 2 + dim_h = 1 + dim_plane = 0 + + self_shape = self.shape + if self.dim() == 4: + dim_w += 1 + dim_h += 1 + dim_plane += 1 + + pad_l, pad_r, pad_t, pad_b = padding + + input_h = self_shape[dim_h] + input_w = self_shape[dim_w] + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + torch._check( + output_w == grad_output.size(dim_w), + lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}", + ) + torch._check( + output_h == grad_output.size(dim_h), + lambda: f"grad_output height unexpected. 
Expected: {output_h}, Got: {grad_output.size(dim_h)}", + ) + return self.new_empty(self.shape) + + +def _pad3d_common(input, padding, *, is_reflection): + dim_w = 3 + dim_h = 2 + dim_d = 1 + dim_plane = 0 + + _padding_check_valid_input(input, padding, dim=3) + + batch_mode = input.ndim == 5 + if batch_mode: + nbatch = input.size(0) + dim_w += 1 + dim_h += 1 + dim_d += 1 + dim_plane += 1 + + pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding + + nplane = input.size(dim_plane) + input_d = input.size(dim_d) + input_h = input.size(dim_h) + input_w = input.size(dim_w) + output_d = input_d + pad_f + pad_bk + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + torch._check( + pad_t < input_h and pad_b < input_h, + lambda: ( + f"Argument #6: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}" + ), + ) + torch._check( + pad_f < input_d and pad_bk < input_d, + lambda: ( + f"Argument #8: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}" + ), + ) + + torch._check( + output_w >= 1 or output_h >= 1 or output_d >= 1, + lambda: ( + f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. " + f"Calculated output D: {output_d} H: {output_h} W: {output_w}" + ), + ) + + if batch_mode: + return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined] + else: + return input.new_empty((nplane, output_d, output_h, output_w)) + + +@register_meta(aten.reflection_pad3d) +@out_wrapper() +def meta_reflection_pad3d(input, padding): + return _pad3d_common(input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad3d) +@out_wrapper() +def meta_replication_pad3d(input, padding): + torch._check( + input.dtype != torch.bool, + lambda: f""""replication_pad3d" not implemented for '{input.dtype.__str__()}'""", + ) + return _pad3d_common(input, padding, is_reflection=False) + + +@register_meta( + [ + aten.reflection_pad3d_backward.default, + aten.reflection_pad3d_backward.grad_input, + aten.replication_pad3d_backward.default, + aten.replication_pad3d_backward.grad_input, + ] +) +@out_wrapper("grad_input") +def meta_pad3d_backward(grad_output, input, padding): + torch._check(len(padding) == 6, lambda: "padding size is expected to be 6") + assert input.ndim > 3 + assert grad_output.ndim == input.ndim + + dim_w = 3 + dim_h = 2 + dim_d = 1 + + if input.ndim == 5: + dim_w += 1 + dim_h += 1 + dim_d += 1 + + pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding + + input_d = input.size(dim_d) + input_h = input.size(dim_h) + input_w = input.size(dim_w) + output_d = input_d + pad_f + pad_bk + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + torch._check( + output_w == grad_output.size(dim_w), + lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}", + ) + torch._check( + output_h == grad_output.size(dim_h), + lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}", + ) + torch._check( + output_d == grad_output.size(dim_d), + lambda: f"grad_output depth unexpected. 
Expected: {output_d}, Got: {grad_output.size(dim_d)}", + ) + + return input.new_empty(input.shape) + + +@register_meta(aten._pdist_forward) +@out_wrapper() +def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor: + torch._check( + self.is_contiguous(), lambda: "_pdist_forward requires contiguous input" + ) + n = self.size(0) + if n <= 1: + return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format) # type: ignore[call-overload] + else: + return self.new_empty((n * (n - 1) // 2,)).to( + memory_format=torch.legacy_contiguous_format + ) # type: ignore[call-overload] + + +@register_meta(aten._pdist_backward) +@out_wrapper() +def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor: + torch._check( + self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous" + ) + torch._check( + pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous" + ) + return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) + + +@register_meta([aten.baddbmm.default, aten.baddbmm.out]) +@out_wrapper(exact_dtype=True) +def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): + from torch.fx.experimental.symbolic_shapes import guard_or_true, sym_eq + + dim1 = batch1.size(0) + dim2 = batch1.size(1) + dim3 = batch2.size(2) + if guard_or_true(torch.sym_not(sym_eq(self.shape, (dim1, dim2, dim3)))): + self = self.expand((dim1, dim2, dim3)) + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + if not exp_config.skip_dtype_check_in_meta_registrations: + torch._check( + self.dtype == batch1.dtype == batch2.dtype, + lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", + ) + batch1_sizes = batch1.shape + batch2_sizes = batch2.shape + bs = batch1_sizes[0] + contraction_size = batch1_sizes[2] + torch._check( + batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, + lambda: ( + f"Expected size for first two dimensions of batch2 tensor to be: " + f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]." 
+ ), + ) + return self.new_empty(self.size()) + + +@register_meta([aten.bernoulli.default, aten.bernoulli.out]) +@out_wrapper() +def meta_bernoulli(self, *, generator=None): + # https://github.com/pytorch/pytorch/issues/88612 + return torch.empty_like(self, memory_format=torch.contiguous_format) + + +@register_meta(aten.bernoulli_.float) +def meta_bernoulli_(self, p=0.5, generator=None): + return self + + +@register_meta(aten.bernoulli.p) +def meta_bernoulli_p(self, p=0.5, generator=None): + # https://github.com/pytorch/pytorch/issues/88612 + return torch.empty_like(self, memory_format=torch.contiguous_format) + + +@register_meta([aten.poisson.default, aten.poisson.out]) +@out_wrapper() +def meta_poisson(self, generator=None): + return torch.empty_like(self) + + +@register_meta(aten._fused_moving_avg_obs_fq_helper.default) +def meta__fused_moving_avg_obs_fq_helper( + self, + observer_on, + fake_quant_on, + running_min, + running_max, + scale, + zero_point, + averaging_const, + quant_min, + quant_max, + ch_axis, + per_row_fake_quant=False, + symmetric_quant=False, +): + torch._check( + ch_axis < self.dim(), + lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()", + ) + mask = torch.empty_like(self, dtype=torch.bool) + return (torch.empty_like(self), mask) + + +@register_meta(aten.mm) +@out_wrapper(exact_dtype=True) +def meta_mm(a, b): + torch._check(a.dim() == 2, lambda: "a must be 2D") + torch._check(b.dim() == 2, lambda: "b must be 2D") + N, M1 = a.shape + M2, P = b.shape + torch._check( + M1 == M2, + lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].", + ) + return a.new_empty(N, P) + + +def _compute_reduction_shape(self, dims, keepdim): + if keepdim: + return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim)) + + return utils.compute_reduction_output_shape(self.shape, dims) + + +# FakeTensors (meta tensors with a device) will report device as meta +# when running meta kernels. Here, access the "fake device" of FakeTensor if it +# exists so meta kernels which have diverge per device will be more +# accurate when run with FakeTensors +def device_hint(tensor) -> "str": + if isinstance(tensor, torch._subclasses.FakeTensor): + return tensor.fake_device.type + elif ( + hasattr(tensor, "device") + and hasattr(tensor.device, "type") + and tensor.device.type != "meta" + ): + return tensor.device.type + else: + return "cuda" # default to cuda + + +def calc_conv_nd_return_shape( + input_tensor: torch.Tensor, + weight: torch.Tensor, + stride: Union[list[int], int], + padding: Union[list[int], int], + dilation: Union[list[int], int], + is_transposed: bool, + groups: int, + output_padding: Optional[Union[list[int], int]] = None, +): + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + Returns: + The output length + """ + return (ln + 2 * p - d * (k - 1) - 1) // s + 1 + + def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + if transposed convolution is used. 
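+ For example (illustrative values, not in the original docstring): with ln=4, p=1, + d=1, k=3, s=2 and op=1 the output length is (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 8.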
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + op: output padding in that dim + + Returns: + The output length + """ + return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 + + kernel_size = weight.shape[2:] + dims = input_tensor.shape[2:] + if is_transposed: + out_channels = groups * weight.shape[1] + else: + out_channels = weight.shape[0] + if weight.shape[1] * groups != input_tensor.shape[1]: + raise RuntimeError("Invalid channel dimensions") + + ret_shape = [input_tensor.shape[0], out_channels] + if isinstance(stride, IntLike): + stride = [stride] * len(dims) + elif len(stride) == 1: + stride = [stride[0]] * len(dims) + + if isinstance(padding, IntLike): + padding = [padding] * len(dims) + elif len(padding) == 1: + padding = [padding[0]] * len(dims) + + if isinstance(dilation, IntLike): + dilation = [dilation] * len(dims) + elif len(dilation) == 1: + dilation = [dilation[0]] * len(dims) + + output_padding_list: Optional[list[int]] = None + if output_padding: + if isinstance(output_padding, IntLike): + output_padding_list = [output_padding] * len(dims) + elif len(output_padding) == 1: + output_padding_list = [output_padding[0]] * len(dims) + else: + output_padding_list = output_padding + + for i in range(len(dims)): + # If output_padding is present, we are dealing with a transposed convolution + if output_padding_list: + ret_shape.append( + _formula_transposed( + dims[i], + padding[i], + dilation[i], + kernel_size[i], + stride[i], + output_padding_list[i], + ) + ) + else: + ret_shape.append( + _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) + ) + + torch._check( + any(x > 0 for x in ret_shape[2:]), + lambda: f"Given input size per channel: {list(dims)}. " + f"Calculated output size per channel: {ret_shape[2:]}. " + f"Output size is too small", + ) + + return ret_shape + + +def is_channels_last(ten): + return torch._prims_common.suggest_memory_format(ten) == torch.channels_last + + +@register_meta(aten.miopen_batch_norm.default) +def meta_miopen_batch_norm( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +): + # In batch norm the output is of the same shape as the input + out_shape = input_tensor.shape + + # If tensor is provided for running_mean and running_var then use this. If these are not + # provded then we return the shape of weight tensor. 
Similar to how this is handled in the decomposition + save_mean_shape = running_mean.shape if running_mean is not None else weight.shape + save_var_shape = running_var.shape if running_var is not None else weight.shape + + def pick_memory_format(): + if is_channels_last(input_tensor): + return torch.channels_last + if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + return torch.contiguous_format + + out = input_tensor.new_empty(out_shape).to(memory_format=pick_memory_format()) + + if training: + save_mean = input_tensor.new_empty(save_mean_shape) + save_var = input_tensor.new_empty(save_var_shape) + else: + save_mean = input_tensor.new_empty((0,)) + save_var = input_tensor.new_empty((0,)) + + return out, save_mean, save_var + + +@register_meta(aten.convolution.default) +def meta_conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: list[int], + padding: list[int], + dilation: list[int], + is_transposed: bool, + output_padding: list[int], + groups: int, +): + def pick_memory_format(): + if device_hint(input_tensor) == "cuda": + if is_channels_last(input_tensor) or is_channels_last(weight): + return torch.channels_last + else: + if is_channels_last(input_tensor): + return torch.channels_last + if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif input_tensor.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + shape_out = calc_conv_nd_return_shape( + input_tensor, + weight, + stride, + padding, + dilation, + is_transposed, + groups, + output_padding if is_transposed else None, + ) + + input_channels_dim = 1 + output_channels_dim = 1 + if input_tensor.size(input_channels_dim) == 0: + shape_out[output_channels_dim] = 0 + + out = input_tensor.new_empty(shape_out) + out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] + return out + + +if torch._C._has_mkldnn: + _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library( + "mkldnn", "IMPL", "Meta" + ) + + @register_meta(torch.ops.mkldnn._convolution_pointwise.default) + def meta_mkldnn_convolution_default( + input_tensor, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + shape_out = calc_conv_nd_return_shape( + input_tensor, weight, stride, padding, dilation, False, groups, [] + ) + out = input_tensor.new_empty(shape_out) + out_memory_format = torch.channels_last + if input_tensor.dim() == 5: + out_memory_format = torch.channels_last_3d + out = out.to(memory_format=out_memory_format) # type: ignore[call-overload] + return out + + @register_meta(torch.ops.mkldnn._linear_pointwise.default) + def meta_linear_pointwise_default( + input_tensor, weight, bias, attr, scalars, algorithm + ): + return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0])) + + if torch._C.has_mkl: + _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library( + "mkl", "IMPL", "Meta" + ) + + @register_meta(torch.ops.mkl._mkl_linear) + def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size): + return input_tensor.new_empty( + (*input_tensor.shape[:-1], orig_weight.shape[0]) + ) + + _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library( + "onednn", "IMPL", "Meta" + ) + + @register_meta(torch.ops.onednn.qconv2d_pointwise.default) + @register_meta(torch.ops.onednn.qconv_pointwise.default) + def meta_qconv_pointwise( + x, + x_scale, + x_zp, 
+ w, # prepacked_weight + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + shape_out = calc_conv_nd_return_shape( + x, + w, + stride, + padding, + dilation, + False, + groups, + None, + ) + assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] + out = x.new_empty(shape_out, dtype=output_dtype) + assert len(shape_out) in [3, 4], "only conv1d/2d are supported" + format = torch.channels_last if len(shape_out) == 4 else torch.contiguous_format + out = out.to(memory_format=format) + return out + + @register_meta(torch.ops.onednn.qconv2d_pointwise.binary) + def meta_qconv2d_pointwise_binary( + x, + x_scale, + x_zp, + w, + w_scale, + w_zp, + accum, + bias, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + accum_scale, + accum_zero_point, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ): + assert binary_op_name == "sum" + return accum + + @register_meta(torch.ops.onednn.qlinear_pointwise.default) + @register_meta(torch.ops.onednn.qlinear_pointwise.tensor) + def meta_qlinear_pointwise( + x, + x_scale, + x_zp, + w, + w_scale, + w_zp, + bias, + output_scale, + output_zero_point, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + output_shape = list(x.shape) + # The weight has been transposed during the qlinear weight prepack process. + output_shape[-1] = w.shape[1] + assert output_dtype in [torch.float32, torch.bfloat16, torch.int8, torch.uint8] + out = x.new_empty(output_shape, dtype=output_dtype) + return out + + @register_meta(torch.ops.onednn.qlinear_pointwise.binary) + @register_meta(torch.ops.onednn.qlinear_pointwise.binary_tensor) + def meta_qlinear_pointwise_binary( + x, + x_scale, + x_zp, + w, + w_scale, + w_zp, + x_2, + bias, + output_scale, + output_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ): + if binary_op_name == "sum": + return x_2 + output_shape = list(x.shape) + # The weight has been transposed during the qlinear weight prepack process. + output_shape[-1] = w.shape[1] + assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] + out = x.new_empty(output_shape, dtype=output_dtype) + return out + + @register_meta(torch.ops.onednn.linear_dynamic_fp16.default) + @register_meta(torch.ops.onednn.linear_relu_dynamic_fp16.default) + def meta_linear_dynamic_fp16( + x, + w, + bias, + ): + output_shape = list(x.shape) + # The weight has been transposed during the qlinear weight prepack process. 
+ output_shape[-1] = w.shape[1] + out = x.new_empty(output_shape) + return out + + _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library( + "quantized", "IMPL", "Meta" + ) + + @register_meta(torch.ops.quantized.max_pool2d) + def meta_quantized_max_pool2d( + input, + kernel_size, + stride=(), + padding=(0,), + dilation=(1,), + ceil_mode=False, + ): + ( + nInputPlane, + outputHeight, + outputWidth, + ) = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = torch.channels_last + if input.dim() == 3: + size = [nInputPlane, outputHeight, outputWidth] + else: + size = [nbatch, nInputPlane, outputHeight, outputWidth] + return torch.empty( + size, + dtype=input.dtype, + device=input.device, + memory_format=memory_format, + ) + + @register_meta(torch.ops.quantized.int4mm_packed_weight_cpu) + def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros): + torch._check(x.dim() == 2, f"x must be a 2D tensor, got {x.dim()}D") + torch._check(w.dim() == 2, f"w must be a 2D tensor, got {w.dim()}D") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check(w.dtype == torch.uint8, f"expected w to be uint8, got {w.dtype}") + torch._check( + q_group_size.dtype == torch.int64, + f"q_group_size must be int64, got {q_group_size.dtype}", + ) + torch._check( + q_scale_and_zeros.dtype == x.dtype, + f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + +# from check_dim_size() in aten/src/ATen/TensorUtils.cpp. +def check_dim_size(tensor, dim, dim_size, size): + torch._check( + tensor.dim() == dim and tensor.shape[dim_size] == size, + lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, " + + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}", + ) + + +@register_meta(aten.avg_pool2d.default) +def meta_avg_pool2d( + input, + kernel_size, + stride=(), + padding=(0,), + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + def unpack(name, val): + torch._check( + len(val) in [1, 2], + lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints", + ) + H = val[0] + W = H if len(val) == 1 else val[1] + return H, W + + kH, kW = unpack("kernel_size", kernel_size) + torch._check( + len(stride) in [0, 1, 2], + lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", + ) + torch._check( + input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64], + lambda: f""""avg_pool2d" not implemented for '{input.dtype.__str__()}'""", + ) + if len(stride) == 0: + dH, dW = kH, kW + elif len(stride) == 1: + dH, dW = stride[0], stride[0] + else: + dH, dW = unpack("stride", stride) + + padH, padW = unpack("padding", padding) + + torch._check( + divisor_override is None or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) + + memory_format = utils.suggest_memory_format(input) + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + 
padH, + padW, + 1, + 1, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format, + ) + + if input.dim() == 3: + size = [nInputPlane, outputHeight, outputWidth] + else: + size = [nbatch, nInputPlane, outputHeight, outputWidth] + return torch.empty( + size, + dtype=input.dtype, + device=input.device, + memory_format=memory_format, + ) + + +# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h. +def avg_pool2d_backward_shape_check( + input, + gradOutput, + nbatch, + kH, + kW, + dH, + dW, + padH, + padW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + mem_format, +): + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + 1, + 1, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + mem_format, + ) + + ndim = input.dim() + nOutputPlane = nInputPlane + + check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane) + check_dim_size(gradOutput, ndim, ndim - 2, outputHeight) + check_dim_size(gradOutput, ndim, ndim - 1, outputWidth) + + +# Don't override the C++ registration. +@register_meta(aten.avg_pool2d_backward.default) +def meta_avg_pool2d_backward( + gradOutput_, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, +): + # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func. + torch._check( + len(kernel_size) == 1 or len(kernel_size) == 2, + lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints", + ) + kH = kernel_size[0] + kW = kH if len(kernel_size) == 1 else kernel_size[1] + torch._check( + len(stride) == 0 or len(stride) == 1 or len(stride) == 2, + lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", + ) + dH = kH if len(stride) == 0 else stride[0] + dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1] + torch._check( + len(padding) == 1 or len(padding) == 2, + lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints", + ) + padH = padding[0] + padW = padH if len(padding) == 1 else padding[1] + + torch._check( + divisor_override is None or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + input_size = input.shape + nbatch = input_size[-4] if input.dim() == 4 else 1 + nInputPlane = input_size[-3] + inputHeight = input_size[-2] + inputWidth = input_size[-1] + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) + + mem_format = utils.suggest_memory_format(input) + + avg_pool2d_backward_shape_check( + input, + gradOutput_, + nbatch, + kH, + kW, + dH, + dW, + padH, + padW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + mem_format, + ) + + return torch.empty( + input_size, + dtype=input.dtype, + device=input.device, + memory_format=mem_format, + ) + + +@register_meta(aten.avg_pool3d) +@out_wrapper() +def meta_avg_pool3d( + input, + kernel_size, + stride=(), + padding=(0,), + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", + ) + 
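+ # avg_pool3d rejects unsigned integer dtypes up front, and when stride is
+ # omitted it falls back to the kernel size, matching nn.AvgPool3d's default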
torch._check( + input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64], + lambda: f""""avg_pool3d" not implemented for '{input.dtype.__str__()}'""", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", + ) + padT = padding[0] + padH = padT if len(padding) == 1 else padding[1] + padW = padT if len(padding) == 1 else padding[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + torch._check( + not divisor_override or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + nbatch = input.size(0) + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode) + oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode) + owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode) + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + padT, + padH, + padW, + 1, + 1, + 1, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + "avg_pool3d()", + check_input_size=True, + ) + + if input.ndim == 4: + return input.new_empty((nslices, otime, oheight, owidth)) + else: + return input.new_empty((nbatch, nslices, otime, oheight, owidth)) + + +@register_meta(aten.avg_pool3d_backward) +@out_wrapper("grad_input") +def meta_avg_pool3d_backward( + grad_output, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", + ) + padT = padding[0] + padH = padT if len(padding) == 1 else padding[1] + padW = padT if len(padding) == 1 else padding[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + torch._check( + not divisor_override or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode) + oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode) + owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode) + + avg_pool3d_backward_shape_check( + input, + grad_output, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + padT, + padH, + padW, + itime, + iheight, + iwidth, + otime_for_shape_check, + oheight_for_shape_check, + owidth_for_shape_check, + "avg_pool3d_backward()", + ) + + return input.new_empty(input.shape) + + 
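+# A minimal usage sketch (the helper name and the sizes in it are arbitrary,
+# for illustration only): the pooling meta kernels above only compute output
+# shapes, so they can be exercised with tensors on the "meta" device, which
+# allocate no storage. With dilation 1 each output length is
+# (l + 2 * pad - k) // stride + 1.
+def _example_avg_pool3d_meta_shape():
+ import torch
+
+ x = torch.empty(2, 3, 10, 12, 14, device="meta")  # (N, C, T, H, W)
+ y = torch.nn.functional.avg_pool3d(x, kernel_size=3, stride=2, padding=1)
+ # T: (10 + 2 - 3) // 2 + 1 = 5, H: (12 + 2 - 3) // 2 + 1 = 6, W: (14 + 2 - 3) // 2 + 1 = 7
+ assert y.shape == (2, 3, 5, 6, 7) and y.device.type == "meta"
+ return y.shape
+
+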
+@register_meta(aten._adaptive_avg_pool2d.default) +def meta_adaptive_avg_pool2d(self, output_size): + torch._check( + self.ndim == 3 or self.ndim == 4, + lambda: f"Expected 3D or 4D tensor, but got {self.shape}", + ) + output_shape = self.shape[:-2] + tuple(output_size) + memory_format = utils.suggest_memory_format(self) + # need to set memory_format to preserve the memory format of the input + # channel last input should have channel last output + return torch.empty( + output_shape, + dtype=self.dtype, + device=self.device, + memory_format=memory_format, + ) + + +@register_meta(aten._adaptive_avg_pool3d.default) +def meta_adaptive_avg_pool3d(self, output_size): + torch._check( + self.ndim == 4 or self.ndim == 5, + lambda: f"Expected 4D or 5D tensor, but got {self.shape}", + ) + return self.new_empty(self.shape[:-3] + tuple(output_size)) + + +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta__adaptive_avg_pool2d_backward(grad_out, self): + ndim = grad_out.ndim + for i in range(1, ndim): + torch._check( + grad_out.size(i) > 0, + lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ + size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", + ) + torch._check( + ndim == 3 or ndim == 4, + lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", + ) + torch._check( + self.dtype == grad_out.dtype, + lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", + ) + memory_format = torch.contiguous_format + if is_channels_last(self): + memory_format = torch.channels_last + return self.new_empty(self.shape).to(memory_format=memory_format) + + +@register_meta(aten._adaptive_avg_pool3d_backward) +@out_wrapper("grad_input") +def meta__adaptive_avg_pool3d_backward(grad_output, self): + _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward") + return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) + + +def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str): + ndim = grad_output.ndim + for i in range(1, ndim): + torch._check( + grad_output.size(i) > 0, + lambda: ( + f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, " + f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty" + ), + ) + + +@register_meta(aten.adaptive_max_pool2d) +@out_wrapper("out", "indices") +def meta_adaptive_max_pool2d(input, output_size): + ndim = input.ndim + torch._check( + ndim in (3, 4), + lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}", + ) + for i in range(1, ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, " + f"but input has sizes {input.shape} with dimension {i} being empty" + ), + ) + + torch._check( + len(output_size) == 2, + lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2", + ) + + dimH = 1 + sizeB = 1 + sizeD = 0 + + if input.ndim == 4: + sizeB = input.size(0) + dimH += 1 + + sizeD = input.size(dimH - 1) + osizeH, osizeW = output_size + + if input.ndim == 3: + out_shape = (sizeD, osizeH, osizeW) + out = input.new_empty(out_shape) + indices = input.new_empty(out_shape, dtype=torch.int64) + return out, indices + else: + out_shape = (sizeB, sizeD, osizeH, osizeW) # type: ignore[assignment] + memory_format = utils.suggest_memory_format(input) + out = input.new_empty(out_shape).to(memory_format=memory_format) + 
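+ # the returned indices mirror the values tensor: same shape, int64 dtype,
+ # and the same suggested memory format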
indices = input.new_empty(out_shape, dtype=torch.int64).to( + memory_format=memory_format + ) + return out, indices + + +@register_meta(aten.adaptive_max_pool2d_backward) +@out_wrapper("grad_input") +def meta_adaptive_max_pool2d_backward(grad_output, input, indices): + ndim = grad_output.ndim + torch._check( + ndim in (3, 4), + lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}", + ) + + _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward") + + torch._check( + input.dtype == grad_output.dtype, + lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}", + ) + + memory_format = utils.suggest_memory_format(input) + return input.new_empty(input.shape).to(memory_format=memory_format) + + +@register_meta(aten.adaptive_max_pool3d) +@out_wrapper("out", "indices") +def meta_adaptive_max_pool3d(input, output_size): + ndim = input.ndim + torch._check( + ndim in (4, 5), + lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}", + ) + for i in range(1, ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, " + f"but input has sizes {input.shape} with dimension {i} being empty" + ), + ) + + torch._check( + len(output_size) == 3, + lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3", + ) + + dimD = 0 + sizeB = 1 + sizeD = 0 + + if ndim == 5: + sizeB = input.size(0) + dimD += 1 + + sizeD = input.size(dimD) + osizeT, osizeH, osizeW = output_size + + if ndim == 4: + out_shape = (sizeD, osizeT, osizeH, osizeW) + else: + out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW) # type: ignore[assignment] + + out = input.new_empty(out_shape) + indices = input.new_empty(out_shape, dtype=torch.int64) + + return out, indices + + +@register_meta(aten.adaptive_max_pool3d_backward) +@out_wrapper("grad_input") +def meta_adaptive_max_pool3d_backward(grad_output, input, indices): + _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward") + return input.new_empty(input.shape) + + +@register_meta(aten.repeat_interleave.Tensor) +def meta_repeat_interleave_Tensor(repeats, output_size=None): + if output_size is None: + raise RuntimeError("cannot repeat_interleave a meta tensor without output_size") + return repeats.new_empty(output_size) + + +@register_meta([aten.complex.default, aten.complex.out]) +@out_wrapper() +def meta_complex(real, imag): + assert real.dtype.is_floating_point + assert imag.dtype.is_floating_point + out_shape = _broadcast_shapes(real.shape, imag.shape) + return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype)) + + +@register_meta([aten.nonzero_static.default, aten.nonzero_static.out]) +@out_wrapper() +def nonzero_static(self, *, size, fill_value: int = -1): + return self.new_empty((size, self.dim()), dtype=torch.long) + + +@register_meta([torch.ops.aten.nonzero.default, torch.ops.aten.nonzero.out]) +@out_wrapper() +def nonzero(self): + torch._check_not_implemented( + exp_config.meta_nonzero_assume_all_nonzero, + lambda: "The register_meta function for torch.nonzero() raises unimplemented by default, " + "as a correct data-independent implementation does not exist. This implementation " + "returns a fake value, assuming all elements of the tensor are non-zero. 
" + "To enable this registration, please set " + "'torch.fx.experimental._config.meta_nonzero_assume_all_nonzero' to True.", + ) + return torch.empty_strided( + (self.numel(), self.dim()), + (1, self.numel()), + dtype=torch.long, + device=self.device, + ) + + +@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor]) +def meta_index_Tensor(self, indices): + torch._check(bool(indices), lambda: "at least one index must be provided") + # aten::index is the internal advanced indexing implementation + # checkIndexTensorTypes and expandTensors + result: list[Optional[Tensor]] = [] + for i, index in enumerate(indices): + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int, torch.int8, torch.bool], + lambda: "tensors used as indices must be long, int, byte or bool tensors", + ) + if index.dtype in [torch.int8, torch.bool]: + nonzero = index.nonzero() + k = len(result) + torch._check_index( + k + index.ndim <= self.ndim, + lambda: f"too many indices for tensor of dimension {self.ndim}", + ) + for j in range(index.ndim): + torch._check_index( + index.shape[j] == self.shape[k + j], + lambda: f"The shape of the mask {index.shape} at index {i} " + f"does not match the shape of the indexed tensor {self.shape} at index {k + j}", + ) + result.append(nonzero.select(1, j)) + else: + result.append(index) + else: + result.append(index) + indices = result + torch._check( + len(indices) <= self.ndim, + lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})", + ) + # expand_outplace + import torch._refs as refs # avoid import cycle in mypy + + indices = list(refs._maybe_broadcast(*indices)) + # add missing null tensors + while len(indices) < self.ndim: + indices.append(None) + + # hasContiguousSubspace + # true if all non-null tensors are adjacent + # See: + # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency + state = 0 + has_contiguous_subspace = False + for index in indices: + if state == 0: + if index is not None: + state = 1 + elif state == 1: + if index is None: + state = 2 + else: + if index is not None: + break + else: + has_contiguous_subspace = True + + # transposeToFront + # This is the logic that causes the newly inserted dimensions to show up + # at the beginning of the tensor, if they're not contiguous + if not has_contiguous_subspace: + dims = [] + transposed_indices = [] + for i, index in enumerate(indices): + if index is not None: + dims.append(i) + transposed_indices.append(index) + for i, index in enumerate(indices): + if index is None: + dims.append(i) + transposed_indices.append(index) + self = self.permute(dims) + indices = transposed_indices + + # AdvancedIndex::AdvancedIndex + # Now we can assume the indices have contiguous subspace + # This is simplified from AdvancedIndex which goes to more effort + # to put the input and indices in a form so that TensorIterator can + # take them. 
If we write a ref for this, probably that logic should + # get implemented + before_shape: list[int] = [] + after_shape: list[int] = [] + replacement_shape: list[int] = [] + for dim, index in enumerate(indices): + if index is None: + if replacement_shape: + after_shape.append(self.shape[dim]) + else: + before_shape.append(self.shape[dim]) + else: + replacement_shape = list(index.shape) + + def _restride_src(self): + """ + This follows restride_src in TensorAdvancedIndexing.cpp + """ + shape = before_shape + replacement_shape + after_shape + strides = list(self.stride()) + strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len( + replacement_shape + ) + return self.as_strided(shape, strides) + + out = self.new_empty(before_shape + replacement_shape + after_shape) + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(self.numel() == 0): + # No need to worry about the output strides if self is empty. + return out + + # Try to follow eager to decide the output stride based on self. + # Note that perm here is the reverse of the 'perm_' decided by + # TensorIteratorBase::reorder_dimensions + restrided_self = _restride_src(self) + perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self) + + # Follow TensorIteratorBase::allocate_or_resize_outputs + if list(perm) != list(range(len(perm))): + perm_shape = utils.apply_perm(out.shape, perm) + new_stride = utils.make_contiguous_strides_for(perm_shape) + new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm)) + out = out.as_strided(out.size(), new_stride) + return out + + +@register_meta([aten.convolution_backward.default]) +def meta_convolution_backward( + grad_output_, + input_, + weight_, + bias_sizes_opt, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + # High level logic taken from slow_conv3d_backward_cpu which should + # be representative of all convolution_backward impls + backend_grad_input = None + backend_grad_weight = None + backend_grad_bias = None + + if output_mask[0]: + backend_grad_input = grad_output_.new_empty(input_.size()) + if output_mask[1]: + backend_grad_weight = grad_output_.new_empty(weight_.size()) + if output_mask[2]: + backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) + + return (backend_grad_input, backend_grad_weight, backend_grad_bias) + + +@register_meta([aten.addbmm.default, aten.addbmm.out]) +@out_wrapper(exact_dtype=True) +def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): + dim1 = batch1.size(1) + dim2 = batch2.size(2) + self = self.expand((dim1, dim2)) + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check( + batch1.size(0) == batch2.size(0), + lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}", + ) + torch._check( + batch1.size(2) == batch2.size(1), + lambda: ( + f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} " + f"and {batch2.size(1)}x{batch2.size(2)})" + ), + ) + torch._check( + self.size(0) == dim1 and self.size(1) == dim2, + lambda: "self tensor does not match matmul output shape", + ) + return self.new_empty(self.size()) + + +@register_meta([aten.randint_like.Tensor]) +def meta_randint_like(self, high, **kwargs): + return self.new_empty(self.size()) + + +@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default]) +def meta__fused_adam_( + self, + grads, + 
exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + *, + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale=None, + found_inf=None, +): + for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: + torch._check( + isinstance(l, list), + lambda: f"exponent must be a tensor list but got {type(l)}", + ) + + +@register_meta([aten._fused_adam.default]) +def meta__fused_adam( + self, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + *, + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale=None, + found_inf=None, +): + for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: + torch._check( + isinstance(l, list), + lambda: f"exponent must be a tensor list but got {type(l)}", + ) + + def empty_like_list(tensor_list): + return [torch.empty_like(t) for t in tensor_list] + + return ( + empty_like_list(self), + empty_like_list(grads), + empty_like_list(exp_avgs), + empty_like_list(exp_avg_sqs), + empty_like_list(max_exp_avg_sqs), + ) + + +@register_meta([aten._int_mm]) +@out_wrapper() +def meta__int_mm(a, b): + torch._check(a.dim() == 2, lambda: "a must be a 2D tensor") + torch._check(b.dim() == 2, lambda: "b must be a 2D tensor") + torch._check( + a.dtype is torch.int8, + lambda: f"expected self to be int8, got {a.dtype}", + ) + torch._check( + b.dtype is torch.int8, + lambda: f"expected mat2 to be int8, got {b.dtype}", + ) + torch._check( + a.size(1) == b.size(0), + lambda: ( + f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} " + f"and {b.size(0)}x{b.size(1)})" + ), + ) + return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32) + + +@register_meta([aten._convert_weight_to_int4pack]) +def meta__convert_weight_to_int4pack(w, inner_k_tiles): + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + w.dtype is torch.uint8, + lambda: f"expected w to be uint8, got {w.dtype}", + ) + n = w.size(0) + k = w.size(1) * 2 # w is [n][k / 2] uint8 + return w.new_empty( + ( + n // 8, + k // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ) + + +@register_meta([aten._convert_weight_to_int4pack_for_cpu]) +def meta__convert_weight_to_int4pack_for_cpu(w, inner_k_tiles): + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + n = w.size(0) + k = w.size(1) # w is [n][k] int32 + return w.new_empty( + (n, k // 2), + dtype=torch.uint8, + ) + + +@register_meta([aten._weight_int4pack_mm]) +def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check(w.dim() == 4, lambda: "w must be a 4D tensor") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype) + + +@register_meta([aten._weight_int4pack_mm_for_cpu]) +def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check( + w.dtype is torch.uint8, + lambda: 
f"expected w to be uint8, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + +@register_meta([aten._weight_int4pack_mm_with_scales_and_zeros]) +def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + +def kai_roundup(a: int, b: int) -> int: + return ((a + b - 1) // b) * b + + +def get_kai_packed_weight_size(n_bits, N, K, groupsize): + if n_bits == 4: + if groupsize == K: # channelwise + # dotprod params only [1x8x32_neon_dotprod] + kai_nr = 8 + kai_kr = 16 + kai_sr = 2 + kai_num_bytes_sum_rhs = 4 # sizeof(int32_t) + kai_num_bytes_multiplier_rhs = 4 # sizeof(float) + kai_num_bytes_bias = 4 # sizeof(float) + + def kai_k_roundedup(k, kr, sr): + # Since we pack a float and int32 value at the end of the row, + # we must make sure that k is a multiple of 4 for alignment + kr_sr_roundedup4 = kai_roundup(kr * sr, 4) + return kai_roundup(k, kr_sr_roundedup4) + + def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + k, nr, kr, sr + ): + k_internal = kai_k_roundedup(k, kr, sr) + + assert (k_internal % 2) == 0, "k_internal must be even" + + return nr * ( + (k_internal // 2) + + kai_num_bytes_multiplier_rhs + + kai_num_bytes_sum_rhs + + kai_num_bytes_bias + ) + + def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + n, k, nr, kr, sr + ): + num_rows = kai_roundup(n, nr) // nr + + return ( + num_rows + * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + k, nr, kr, sr + ) + ) + + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + N, K, kai_nr, kai_kr, kai_sr + ) + elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise + kai_nr = 8 + kai_kr = 16 + kai_sr = 2 + kai_num_bytes_sum_rhs = 4 + kai_num_bytes_bias = 4 + kai_nr_multiple_of = 4 + kai_bl_multiple_of = 32 + + def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + n, k, nr, kr, sr, bl + ): + assert (bl % kr) == 0 + assert (nr % kai_nr_multiple_of) == 0 + assert (bl % kai_bl_multiple_of) == 0 + + num_rows = kai_roundup(n, nr) // nr + + return ( + num_rows + * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, bl + ) + ) + + def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, bl + ): + assert (bl % kr) == 0 + assert (nr % kai_nr_multiple_of) == 0 + assert (bl % kai_bl_multiple_of) == 0 + + # kr and sr are unused in the calculation + num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes() + num_blocks_per_row = kai_num_blocks_per_row(k, bl) + num_bytes_per_block = kai_num_bytes_per_block( + bl, num_bytes_multiplier_rhs + ) + + return nr * ( + (num_bytes_per_block * num_blocks_per_row) + + kai_num_bytes_sum_rhs + + kai_num_bytes_bias + ) + + # This funtion retuns size of these datatypes stored as enum. 
We modify it to just return bf16 datatype + # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55 + def kai_get_bf16_datatype_size_in_bytes(): + return 2 # 2 bytes + + def kai_num_blocks_per_row(k, bl): + assert (bl % kai_bl_multiple_of) == 0 + return kai_roundup(k, bl) // bl + + def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): + assert (bl % kai_bl_multiple_of) == 0 + return (bl // 2) + num_bytes_multiplier_rhs + + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + N, K, kai_nr, kai_kr, kai_sr, groupsize + ) + + +@register_meta([aten._dyn_quant_pack_4bit_weight]) +def meta__dyn_quant_pack_4bit_weight( + weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features +): + torch._check( + weights.dtype is torch.uint8, + lambda: f"expected w to be uint8, got {weights.dtype}", + ) + if torch.backends.kleidiai.is_available() and ( + (block_size == in_features and scales_zeros.dtype == torch.float) + or ( + block_size < in_features + and block_size % 32 == 0 + and in_features % block_size == 0 + and scales_zeros.dtype == torch.bfloat16 + ) + ): + packed_weight_size = get_kai_packed_weight_size( + 4, out_features, in_features, block_size + ) + return weights.new_empty(int(packed_weight_size), dtype=torch.uint8) + packed_weight_size = weights.numel() + scales_zeros.numel() + return weights.new_empty(packed_weight_size, dtype=torch.float) + + +@register_meta([aten._dyn_quant_matmul_4bit]) +def meta__dyn_quant_matmul_4bit( + inp, + packed_weights, + block_size, + in_features, + out_features, +): + torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor") + torch._check( + inp.dtype in [torch.float32], + lambda: f"expected input to be f32, got {inp.dtype}", + ) + M = inp.size(0) + return inp.new_empty(M, out_features, dtype=inp.dtype) + + +@register_meta([aten._weight_int8pack_mm]) +def meta__weight_int8pack_mm(x, w, q_scales): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + w.dtype is torch.int8, + lambda: f"expected w to be int8, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + +@register_meta(aten._cdist_forward.default) +def meta_cdist_forward(x1, x2, p, compute_mode): + torch._check( + x1.dim() >= 2, + lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D", + ) + torch._check( + x2.dim() >= 2, + lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D", + ) + torch._check( + x1.size(-1) == x2.size(-1), + lambda: f"X1 and X2 must have the same number of columns. 
X1: {x1.size(-1)} X2: {x2.size(-1)}", + ) + torch._check( + utils.is_float_dtype(x1.dtype), + lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}", + ) + torch._check( + utils.is_float_dtype(x2.dtype), + lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}", + ) + torch._check(p >= 0, lambda: "cdist only supports non-negative p values") + torch._check( + compute_mode in (None, 1, 2), + lambda: f"possible modes: None, 1, 2, but was: {compute_mode}", + ) + r1 = x1.size(-2) + r2 = x2.size(-2) + batch_tensor1 = x1.shape[:-2] + batch_tensor2 = x2.shape[:-2] + output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) + output_shape.extend([r1, r2]) + return x1.new_empty(output_shape) + + +@register_meta(aten._cdist_backward) +@out_wrapper() +def meta_cdist_backward(grad, x1, x2, p, cdist): + c1 = x1.shape[-1] + r1 = x1.shape[-2] + r2 = x2.shape[-2] + batch_tensor1 = x1.shape[:-2] + batch_tensor2 = x2.shape[:-2] + expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) + tensor1_expand_size = expand_batch_portion.copy() + tensor1_expand_size.extend([r1, c1]) + batch_product = math.prod(expand_batch_portion) + if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0: + return torch.zeros_like(x1) + if tensor1_expand_size != list(x1.shape): + x1 = x1.expand(tensor1_expand_size) + return torch.empty_like(x1, memory_format=torch.contiguous_format) + + +# NB: This meta function accepts non-meta arguments! When this behavior +# was originally introduced this was accidental, but it is now load bearing +# as people are using this so that they can conveniently test code involving +# embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module) +@register_meta(aten._embedding_bag.default) +def meta_embedding_bag( + weight, + indices, + offsets, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, +): + torch._check( + indices.dtype in (torch.long, torch.int), + lambda: f"expected indices to be long or int, got {indices.dtype}", + ) + torch._check( + offsets.dtype in (torch.long, torch.int), + lambda: f"expected offsets to be long or int, got {offsets.dtype}", + ) + torch._check( + utils.is_float_dtype(weight.dtype), + lambda: f"expected weight to be floating point type, got {weight.dtype}", + ) + + num_bags = offsets.size(0) + if include_last_offset: + torch._check( + num_bags >= 1, + lambda: "include_last_offset: numBags should be at least 1", + ) + num_bags -= 1 + + output = weight.new_empty(num_bags, weight.size(1)) + + if per_sample_weights is not None: + torch._check( + mode == MODE_SUM, + lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", + ) + torch._check( + per_sample_weights.ndim == 1, + lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", + ) + torch._check( + per_sample_weights.numel() == indices.numel(), + lambda: ( + f"expected per_sample_weights.numel() ({per_sample_weights.numel()} " + f"to be the same as indices.numel() ({indices.numel()})" + ), + ) + + def is_fast_path_index_select_scale(src, scale, output, padding_idx): + return ( + is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1 + ) + + def is_fast_path_index_select(src, output, padding_idx): + return ( + (src.dtype == torch.float or src.dtype == torch.half) + and src.stride(1) == 1 + and output.stride(1) == 1 + and padding_idx < 0 + ) + + def is_fast_path(src, scale, output, 
padding_idx): + if scale is not None: + return is_fast_path_index_select_scale(src, scale, output, padding_idx) + else: + return is_fast_path_index_select(src, output, padding_idx) + + if device_hint(offsets) != "cpu": + offset2bag = indices.new_empty(indices.size(0)) + bag_size = indices.new_empty(offsets.size()) + if mode == MODE_MAX: + max_indices = indices.new_empty(num_bags, weight.size(1)) + else: + max_indices = indices.new_empty(0) + else: + fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx) + if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum: + offset2bag = offsets.new_empty(indices.size(0)) + else: + offset2bag = offsets.new_empty(0) + bag_size = offsets.new_empty(num_bags) + # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp + numBags = offsets.shape[0] + if mode == MODE_MAX: + if include_last_offset: + torch._check( + numBags >= 1, + lambda: "include_last_offset: numBags should be at least 1", + ) + numBags -= 1 + max_indices = offsets.new_empty(numBags, weight.shape[1]) + else: + max_indices = offsets.new_empty(bag_size.size()) + return output, offset2bag, bag_size, max_indices + + +@register_meta(aten._embedding_bag_forward_only.default) +def meta_embedding_bag_forward_only(weight, indices, offsets, *args): + output, offset2bag, bag_size, max_indices = meta_embedding_bag( + weight, indices, offsets, *args + ) + if device_hint(offsets) == "cpu": + bag_size = offsets.new_empty(offsets.size()) + return output, offset2bag, bag_size, max_indices + + +def _get_reduction_dtype(input, dtype, promote_int_to_long=True): + # if specified, dtype takes precedence + if dtype: + return dtype + + if input.dtype.is_floating_point or input.dtype.is_complex: + return input.dtype + elif promote_int_to_long: + return torch.long + + return input.dtype + + +@register_meta([aten.nansum.default, aten.nansum.out]) +@out_wrapper() +def meta_nansum(input, dims=None, keepdim=False, *, dtype=None): + output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True) + dims = utils.reduction_dims(input.shape, dims) + output_shape = _compute_reduction_shape(input, dims, keepdim) + return input.new_empty(output_shape, dtype=output_dtype) + + +@register_meta([aten.median.default, aten.nanmedian.default]) +def meta_median(input): + output_shape = utils.compute_reduction_output_shape( + input.shape, tuple(range(input.dim())) + ) + return input.new_empty(output_shape) + + +@register_meta( + [ + aten.median.dim, + aten.median.dim_values, + aten.nanmedian.dim, + aten.nanmedian.dim_values, + aten.mode.default, + aten.mode.values, + ] +) +@out_wrapper("values", "indices") +def meta_median_mode_dim(input, dim=-1, keepdim=False): + if device_hint(input) == "cuda": + utils.alert_not_deterministic("median CUDA with indices output") + dim = utils.reduction_dims(input.shape, (dim,)) + output_shape = _compute_reduction_shape(input, dim, keepdim) + return ( + input.new_empty(output_shape), + input.new_empty(output_shape, dtype=torch.long), + ) + + +@register_meta(aten.logical_not_.default) +def meta_logical_not_(self): + return self + + +@register_meta(aten.repeat.default) +def meta_repeat(self, repeats): + torch._check( + len(repeats) >= self.dim(), + lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", + ) + for i, rep in enumerate(repeats): + torch._check( + rep >= 0, + lambda: f"Repeats cannot be negative, found {rep} at index {i}", + ) + # Add new leading dimensions to the tensor if the + # number of 
target dimensions is larger than the + # number of source dimensions. + num_new_dimensions = len(repeats) - self.dim() + padded_size = (1,) * num_new_dimensions + tuple(self.shape) + target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))] + return self.new_empty(target_size) + + +@register_meta(aten.zero_.default) +def meta_zero_(self): + return self + + +@register_meta( + [ + aten.mul_.Scalar, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.div_.Tensor, + aten.logical_and_.default, + aten.logical_or_.default, + aten.logical_xor_.default, + ], +) +def meta_binop_inplace(self, other): + if isinstance(other, torch.Tensor): + check_inplace_broadcast(self.shape, other.shape) + return self + + +@register_meta( + [ + aten.add_.Scalar, + aten.sub_.Scalar, + aten.add_.Tensor, + aten.sub_.Tensor, + ], +) +def meta_binop_inplace_alpha(self, other, alpha=1): + """ + Some checks for inplace ops. + Checks for promotion rules for some dtypes. + int.add/sub_(float) and bool.add/sub_(others) are rejected. + Promoting in these in-place operations would require reallocating + and copying over elements, hence not allowed. + Checks for alpha param. + """ + + def is_integeric(arg): + if isinstance(arg, TensorLike): + return utils.is_integer_dtype(arg.dtype) + else: + return isinstance(arg, IntLike) + + def is_floatic(arg): + if isinstance(arg, TensorLike): + return utils.is_float_dtype(arg.dtype) + else: + return isinstance(arg, FloatLike) + + def is_booleanic(arg): + if isinstance(arg, TensorLike): + return utils.is_boolean_dtype(arg.dtype) + else: + return isinstance(arg, BoolLike) + + # Do not allow int+float->int in-place + if is_integeric(self) and is_floatic(other): + raise RuntimeError( + "Promotion of int.add/sub_(float) in in-place ops are not possible due to element size change." + ) + + # Do not allow bool+other->bool in-place + if is_booleanic(self) and not is_booleanic(other): + raise RuntimeError( + "Promotion of book.add/sub_(others) in in-place ops are not possible due to element size change." + ) + + if isinstance(other, torch.Tensor): + check_inplace_broadcast(self.shape, other.shape) + return self + + +@register_meta([aten.round.default, aten.round.decimals]) +def meta_round(self, **kwargs): + return elementwise_meta( + self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +def shift_dtype_check(fn_name, self, val): + torch._check( + utils.is_integer_dtype(self.dtype), + lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}", + ) + if isinstance(val, torch.Tensor): + torch._check( + utils.is_integer_dtype(val.dtype), + lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}", + ) + else: + torch._check( + isinstance(val, IntLike), + lambda: f"{fn_name}: Expected shift value to be an int. 
Got {val}", + ) + + +@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar]) +def meta_rshifts(self, other): + shift_dtype_check("rshift", self, other) + return elementwise_meta( + self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar]) +def meta_lshifts(self, other): + shift_dtype_check("lshift", self, other) + return elementwise_meta( + self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.zero.default) +def meta_zero(self): + return self.new_empty(self.shape) + + +@register_meta([aten.fill_.Tensor, aten.fill_.Scalar]) +def meta_fill_(self, val): + return self + + +@register_meta([aten.fill.Tensor, aten.fill.Scalar]) +def meta_fill(self, val): + return torch.empty_like(self) + + +@register_meta(aten.relu_.default) +def meta_relu_(self): + return self + + +@register_meta(aten._add_relu.Tensor) +@out_wrapper() +def meta__add_relu(self, other, alpha=1) -> Tensor: + return elementwise_meta( + self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta([aten.rrelu_with_noise]) +@out_wrapper() +def meta_rrelu_with_noise( + self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None +): + return torch.empty_like(self) + + +@register_meta([aten.rrelu_with_noise_functional]) +def meta_rrelu_with_noise_functional( + self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None +): + return torch.empty_like(self), torch.empty_like(noise) + + +@register_meta([aten.rrelu_with_noise_]) +def meta_rrelu_with_noise_( + self, lower=0.125, upper=0.3333333333333333, training=False, generator=None +): + return self + + +@register_meta([aten.index_put.default, aten._unsafe_index_put.default]) +def meta_index_put(self, indices, values, accumulate=False): + return torch.empty_like(self) + + +@register_meta(aten.masked_fill_.Scalar) +def meta_masked_fill_(self, mask, value): + check_inplace_broadcast(self.shape, mask.shape) + return self + + +@register_meta(aten._masked_scale.default) +def meta__masked_scale(self, mask, scale): + masked_scale = self.new_empty(self.size()).to( + memory_format=utils.suggest_memory_format(self) + ) + return masked_scale + + +@register_meta(aten.masked_scatter_) +def meta_masked_scatter_(self, mask, source): + torch._check( + mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8" + ) + torch._check( + self.dtype == source.dtype, + lambda: "masked_scatter: expected self and source to have same " + f"dtypes but got {self.dtype} and {source.dtype}", + ) + return self + + +@register_meta(aten.masked_scatter) +@out_wrapper() +def meta_masked_scatter(self, mask, source): + self, mask = _maybe_broadcast(self, mask) + output = torch.empty_like(self, memory_format=torch.contiguous_format) + return meta_masked_scatter_(output, mask, source) + + +@register_meta(aten.masked_scatter_backward) +def meta_masked_scatter_backward(self, mask, sizes): + return self.new_empty(sizes) + + +@register_meta(aten.index_put_.default) +def meta_index_put_(self, indices, values, accumulate=False): + return self + + +@register_meta(aten.alias.default) +def meta_alias(self): + return self.view(self.shape) + + +def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None): + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + + batch1_sizes = batch1.size() + 
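+ # bmm/baddbmm compute (bs, n, m) @ (bs, m, p) -> (bs, n, p); the checks
+ # below enforce the shared batch size bs and contraction size m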
batch2_sizes = batch2.size() + + bs = batch1_sizes[0] + contraction_size = batch1_sizes[2] + res_rows = batch1_sizes[1] + res_cols = batch2_sizes[2] + output_size = (bs, res_rows, res_cols) + + torch._check( + batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, + lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}" + f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].", + ) + if out_dtype: + supported_out_dtype = ( + batch1.dtype == torch.float16 or batch1.dtype == torch.bfloat16 + ) and out_dtype == torch.float32 + torch._check( + out_dtype == batch1.dtype or supported_out_dtype, + lambda: "out_dtype only supported for torch.float32 output with float16/bfloat16 inputs or same as input dtypes", + ) + output = batch2.new_empty(output_size).to(out_dtype) + else: + # TODO: handle out + output = batch2.new_empty(output_size) + + if not is_bmm and self_baddbmm is not None: + torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor") + torch._check( + self_baddbmm.size() == output_size, + lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}", + ) + + return output + + +@register_meta(aten.bmm.default) +def meta_bmm(self, mat2): + return common_meta_baddbmm_bmm(self, mat2, True) + + +@register_meta(aten.bmm.dtype) +def meta_bmm_dtype(self, mat2, out_dtype): + return common_meta_baddbmm_bmm(self, mat2, True, out_dtype=out_dtype) + + +def div_rtn(x, y): + q = x // y + r = x % y + # WARNING: explicit bool conversion here is necessary; + # would be fixed by SymBool + if r != 0 and (bool(r < 0) != bool(y < 0)): + q -= 1 + return q + + +def pooling_output_shape_pad_lr( + inputSize, + kernelSize, + pad_l, + pad_r, + stride, + dilation, + ceil_mode, +): + outputSize = ( + div_rtn( + inputSize + + pad_l + + pad_r + - dilation * (kernelSize - 1) + - 1 + + (stride - 1 if ceil_mode else 0), + stride, + ) + + 1 + ) + if ceil_mode: + if (outputSize - 1) * stride >= inputSize + pad_l: + outputSize -= 1 + return outputSize + + +def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode): + torch._check(stride != 0, lambda: "stride should not be zero") + torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}") + torch._check( + pad <= ((kernelSize - 1) * dilation + 1) // 2, + lambda: ( + f"pad should be at most half of effective kernel size, but got pad={pad}, " + f"kernel_size={kernelSize} and dilation={dilation}" + ), + ) + return pooling_output_shape_pad_lr( + inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode + ) + + +def pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format, +): + ndim = input.dim() + nOutputPlane = nInputPlane + + torch._check( + kW > 0 and kH > 0, + lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}", + ) + torch._check( + dW > 0 and dH > 0, + lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}", + ) + torch._check( + dilationH > 0 and dilationW > 0, + lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}", + ) + + valid_dims = input.size(1) != 0 and input.size(2) != 0 + + if memory_format == torch.channels_last: + torch._check( + ndim == 4 and valid_dims and input.size(3) != 0, + lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim 
batch size for input, but got: {input.size()}", + ) + else: + torch._check( + (ndim == 3 and input.size(0) != 0 and valid_dims) + or (ndim == 4 and valid_dims and input.size(3) != 0), + lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}", + ) + + torch._check( + kW // 2 >= padW and kH // 2 >= padH, + lambda: "pad should be smaller than or equal to half of kernel size, but got " + f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}", + ) + + torch._check( + outputWidth >= 1 and outputHeight >= 1, + lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). " + f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). " + "Output size is too small", + ) + + +def pool3d_shape_check( + input: Tensor, + nslices: int, + kT: int, + kH: int, + kW: int, + dT: int, + dH: int, + dW: int, + pT: int, + pH: int, + pW: int, + dilationT: int, + dilationH: int, + dilationW: int, + itime: int, + iheight: int, + iwidth: int, + otime: int, + oheight: int, + owidth: int, + fn_name: str, + check_input_size: bool = False, +): + ndim = input.ndim + + torch._check( + kT > 0 and kW > 0 and kH > 0, + lambda: ( + f"kernel size should be greater than zero, but got " + f"kT: {kT}, kH: {kH}, kW: {kW}" + ), + ) + torch._check( + dT > 0 and dW > 0 and dH > 0, + lambda: ( + f"stride should be greater than zero, but got dT: {dT}, dH: {dH}, dW: {dW}" + ), + ) + torch._check( + dilationT > 0 and dilationW > 0 and dilationH > 0, + lambda: ( + f"dilation should be greater than zero, but got " + f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}" + ), + ) + + torch._check( + ndim in (4, 5), + lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}", + ) + + for i in range(ndim): + if ndim == 5 and i == 0: + # size of batch-dim can be 0. + continue + torch._check( + input.size(i) > 0, + lambda: ( + f"{fn_name}: Expected input's non-batch dimensions to have positive length," + f" but input has a shape of {input.shape}" + f" and non-batch dimension {input.size(i)} has length zero!" + ), + ) + + if check_input_size: # AveragePool3d + torch._check( + itime >= kT and iheight >= kH and iwidth >= kW, + lambda: ( + f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than " + f"kernel size (kT: {kT} kH: {kH} kW: {kW})" + ), + ) + + torch._check( + kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH, + lambda: ( + f"pad should be smaller than or equal to half of kernel size, but got " + f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}" + ), + ) + + torch._check( + otime >= 1 and owidth >= 1 and oheight >= 1, + lambda: ( + f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). " + f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). 
" + f"Output size is too small" + ), + ) + + +def max_pool3d_backward_shape_check( + input, + grad_output, + indices, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + fn_name, +): + ndim = input.ndim + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + fn_name, + ) + + check_dim_size(grad_output, ndim, ndim - 4, nslices) + check_dim_size(grad_output, ndim, ndim - 3, otime) + check_dim_size(grad_output, ndim, ndim - 2, oheight) + check_dim_size(grad_output, ndim, ndim - 1, owidth) + + check_dim_size(indices, ndim, ndim - 4, nslices) + check_dim_size(indices, ndim, ndim - 3, otime) + check_dim_size(indices, ndim, ndim - 2, oheight) + check_dim_size(indices, ndim, ndim - 1, owidth) + + +def avg_pool3d_backward_shape_check( + input: Tensor, + grad_output: Tensor, + nslices: int, + kT: int, + kH: int, + kW: int, + dT: int, + dH: int, + dW: int, + pT: int, + pH: int, + pW: int, + itime: int, + iheight: int, + iwidth: int, + otime: int, + oheight: int, + owidth: int, + fn_name: str, +): + ndim = input.ndim + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + 1, + 1, + 1, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + fn_name, + True, + ) + + check_dim_size(grad_output, ndim, ndim - 4, nslices) + check_dim_size(grad_output, ndim, ndim - 3, otime) + check_dim_size(grad_output, ndim, ndim - 2, oheight) + check_dim_size(grad_output, ndim, ndim - 1, owidth) + + +def max_pool2d_checks_and_compute_shape( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp + def unpack(name, val): + torch._check( + len(val) in [1, 2], + lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints", + ) + H = val[0] + W = H if len(val) == 1 else val[1] + return H, W + + kH, kW = unpack("kernel_size", kernel_size) + + torch._check( + len(stride) in [0, 1, 2], + lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints", + ) + if len(stride) == 0: + dH, dW = kH, kW + else: + dH, dW = unpack("stride", stride) + + padH, padW = unpack("padding", padding) + dilationH, dilationW = unpack("dilation", dilation) + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) + + memory_format = utils.suggest_memory_format(input) + if memory_format == torch.channels_last: + torch._check( + input.dim() == 4, + lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout", + ) + elif memory_format == torch.contiguous_format: + torch._check( + input.dim() in [3, 4], + lambda: "non-empty 3D or 4D (batch mode) tensor expected for input", + ) + else: + torch._check( + False, + lambda: "Unsupport memory format. 
Supports only ChannelsLast, Contiguous", + ) + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format, + ) + + return nInputPlane, outputHeight, outputWidth + + +@register_meta(aten.max_pool2d_with_indices_backward.default) +def meta_max_pool2d_with_indices_backward( + grad_output, + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, +): + ( + nInputPlane, + outputHeight, + outputWidth, + ) = max_pool2d_checks_and_compute_shape( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + torch._check( + self.dtype == grad_output.dtype, + lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", + ) + + nOutputPlane = nInputPlane + ndim = self.ndim + + def _check_dim_size(t): + check_dim_size(t, ndim, ndim - 3, nOutputPlane) + check_dim_size(t, ndim, ndim - 2, outputHeight) + check_dim_size(t, ndim, ndim - 1, outputWidth) + + _check_dim_size(grad_output) + _check_dim_size(indices) + + memory_format = utils.suggest_memory_format(self) + return torch.empty( + self.shape, + dtype=self.dtype, + device=self.device, + memory_format=memory_format, + ) + + +@register_meta(aten.max_pool2d_with_indices.default) +def meta_max_pool2d_with_indices( + input, + kernel_size, + stride=(), + padding=(0,), + dilation=(1,), + ceil_mode=False, +): + ( + nInputPlane, + outputHeight, + outputWidth, + ) = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = utils.suggest_memory_format(input) + if input.dim() == 3: + size = [nInputPlane, outputHeight, outputWidth] + else: + size = [nbatch, nInputPlane, outputHeight, outputWidth] + return ( + torch.empty( + size, + dtype=input.dtype, + device=input.device, + memory_format=memory_format, + ), + torch.empty( + size, + dtype=torch.int64, + device=input.device, + memory_format=memory_format, + ), + ) + + +@register_meta(aten.fractional_max_pool2d.default) +def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples): + torch._check( + self.ndim in (3, 4), + lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}", + ) + ndim = self.ndim + + for d in range(ndim - 3, ndim): + torch._check( + self.size(d) > 0, + f"fractional_max_pool2d: Expected input to have non-zero " + f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty", + ) + + # the check and message are out of sync, but this matches the structured meta + torch._check( + len(kernel_size) == 2, + lambda: "fractional_max_pool2d: kernel_size must" + "either be a single int or tuple of Ints", + ) + torch._check( + len(output_size) == 2, + lambda: "fractional_max_pool2d: output_size must " + "either be a single int or tuple of Ints", + ) + + input_channels = self.size(-3) + input_height = self.size(-2) + input_width = self.size(-1) + if ndim == 4: + input_batch = self.size(0) + else: + input_batch = 1 + + torch._check( + self.dtype == random_samples.dtype, + lambda: "Expect _random_samples to have the same dtype as input", + ) + torch._check( + random_samples.ndim == 3, + lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}", + ) + + n = 
random_samples.size(0) + c = random_samples.size(1) + d = random_samples.size(2) + torch._check( + n >= input_batch, + "Expect _random_samples.size(0) no less then input batch size.", + ) + torch._check( + c == input_channels, + lambda: "Expect _random_samples.size(1) equals to input channel size.", + ) + torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.") + + torch._check( + output_size[0] + kernel_size[0] - 1 <= input_height, + lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}", + ) + torch._check( + output_size[1] + kernel_size[1] - 1 <= input_width, + lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}", + ) + + if self.dim() == 4: + size = [input_batch, input_channels, output_size[0], output_size[1]] + else: + size = [input_channels, output_size[0], output_size[1]] + + return ( + torch.empty( + size, + dtype=self.dtype, + device=self.device, + ), + torch.empty( + size, + dtype=torch.int64, + device=self.device, + ), + ) + + +@register_meta(aten.max_pool3d_with_indices) +@out_wrapper("out", "indices") +def meta_max_pool3d_with_indices( + input, + kernel_size, + stride=(), + padding=(0,), + dilation=(1,), + ceil_mode=False, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints", + ) + pT = padding[0] + pH = pT if len(padding) == 1 else padding[1] + pW = pT if len(padding) == 1 else padding[2] + + torch._check( + len(dilation) in (1, 3), + lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints", + ) + dilationT = dilation[0] + dilationH = dilationT if len(dilation) == 1 else dilation[1] + dilationW = dilationT if len(dilation) == 1 else dilation[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + nbatch = input.size(-5) if input.ndim == 5 else 1 + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode) + oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode) + owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode) + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + "max_pool3d_with_indices()", + ) + + channels_last = ( + input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d + ) + if input.ndim == 4: + input_channels_last_check = input.unsqueeze(0) + channels_last = ( + not input_channels_last_check.is_contiguous() + ) and input_channels_last_check.is_contiguous( + memory_format=torch.channels_last_3d + ) + out_shape = (nslices, 
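+        # Unbatched (4D) input: the output keeps (C, T_out, H_out, W_out) with no
+        # leading batch dimension; the spatial extents come from pooling_output_shape.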
otime, oheight, owidth) + else: + out_shape = (nbatch, nslices, otime, oheight, owidth) # type: ignore[assignment] + + out = input.new_empty(out_shape) + indices = input.new_empty(out_shape, dtype=torch.int64) + + if channels_last: + out = out.to(memory_format=torch.channels_last_3d) + indices = indices.to(memory_format=torch.channels_last_3d) + + return out, indices + + +@register_meta(aten.max_pool3d_with_indices_backward) +@out_wrapper("grad_input") +def meta_max_pool3d_with_indices_backward( + grad_output, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints", + ) + pT = padding[0] + pH = pT if len(padding) == 1 else padding[1] + pW = pT if len(padding) == 1 else padding[2] + + torch._check( + len(dilation) in (1, 3), + lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints", + ) + dilationT = dilation[0] + dilationH = dilationT if len(dilation) == 1 else dilation[1] + dilationW = dilationT if len(dilation) == 1 else dilation[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime = grad_output.size(-3) + oheight = grad_output.size(-2) + owidth = grad_output.size(-1) + + max_pool3d_backward_shape_check( + input, + grad_output, + indices, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + "max_pool3d_with_indices_backward()", + ) + + channels_last = ( + input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d + ) + if input.ndim == 4: + input_channels_last_check = input.unsqueeze(0) + channels_last = ( + not input_channels_last_check.is_contiguous() + ) and input_channels_last_check.is_contiguous( + memory_format=torch.channels_last_3d + ) + + grad_input = input.new_empty(input.shape) + + if channels_last: + grad_input = grad_input.to(memory_format=torch.channels_last_3d) + + return grad_input + + +def check_grid_sampler_common(input: Tensor, grid: Tensor): + torch._check( + input.device == grid.device, + lambda: ( + f"grid_sampler(): expected input and grid to be on same device, but input " + f"is on {input.device} and grid is on {grid.device}" + ), + ) + torch._check( + input.layout == torch.strided and grid.layout == torch.strided, + lambda: ( + f"grid_sampler(): expected input and grid to have torch.strided layout, but " + f"input has {input.layout} and grid has {grid.layout}" + ), + ) + torch._check( + input.shape[0] == grid.shape[0], + lambda: ( + f"grid_sampler(): expected grid and input to have same batch size, but got " + f"input with sizes {input.shape} and grid with sizes 
{grid.shape}" + ), + ) + torch._check( + grid.shape[-1] == input.ndim - 2, + lambda: ( + f"grid_sampler(): expected grid to have size {input.ndim - 2} in last " + f"dimension, but got grid with sizes {grid.shape}" + ), + ) + + for i in range(2, input.ndim): + torch._check( + input.shape[i] > 0, + lambda: ( + f"grid_sampler(): expected input to have non-empty spatial dimensions, " + f"but input has sizes {input.shape} with dimension {i} being empty" + ), + ) + + +class GridSamplerInterpolation(Enum): + BILINEAR = 0 + NEAREST = 1 + BICUBIC = 2 + + +def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int): + torch._check( + input.ndim == 5 and input.ndim == grid.ndim, + lambda: ( + f"grid_sampler(): expected 5D input and grid with same number of " + f"dimensions, but got input with sizes {input.shape}" + f" and grid with sizes {grid.shape}" + ), + ) + torch._check( + not ( + input.ndim == 5 + and interpolation_mode == GridSamplerInterpolation.BICUBIC.value + ), + lambda: "grid_sampler(): bicubic interpolation only supports 4D input", + ) + + +@register_meta(aten.grid_sampler_2d_backward.default) +def grid_sampler_2d_backward_meta( + grad_output, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + output_mask, +): + input_requires_grad = output_mask[0] + if input_requires_grad: + grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format) + else: + grad_input = None + grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format) + return (grad_input, grad_grid) + + +@register_meta(aten.grid_sampler_3d) +@out_wrapper() +def grid_sampler_3d( + input, + grid, + interpolation_mode, + padding_mode, + align_corners, +): + check_grid_sampler_common(input, grid) + check_grid_sampler_3d(input, grid, interpolation_mode) + N = input.shape[0] + C = input.shape[1] + out_D = grid.shape[1] + out_H = grid.shape[2] + out_W = grid.shape[3] + return input.new_empty((N, C, out_D, out_H, out_W)) + + +@register_meta(aten.grid_sampler_3d_backward) +@out_wrapper("grad_input", "grad_grid") +def grid_sampler_3d_backward( + grad_output, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + output_mask, +): + check_grid_sampler_common(input, grid) + check_grid_sampler_3d(input, grid, interpolation_mode) + input_requires_grad = output_mask[0] + if input_requires_grad: + grad_input = torch.zeros_like( + input, memory_format=torch.legacy_contiguous_format + ) + else: + grad_input = None + grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format) + return grad_input, grad_grid + + +@register_meta([aten.full.default]) +def full(size, fill_value, *args, **kwargs): + dtype = kwargs.get("dtype", None) + if not dtype: + dtype = utils.get_dtype(fill_value) + kwargs["dtype"] = dtype + return torch.empty(size, *args, **kwargs) + + +# zeros_like is special cased to work for sparse +@register_meta(aten.zeros_like.default) +def zeros_like( + self, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + if layout == torch.sparse_coo: + torch._check( + memory_format is None, + lambda: "memory format option is only supported by strided tensors", + ) + + res = torch.empty( + 0, + dtype=self.dtype if dtype is None else dtype, + layout=layout, + device=self.device if device is None else device, + pin_memory=pin_memory, + ) + + if self.is_sparse: + res.sparse_resize_and_clear_( + self.size(), self.sparse_dim(), self.dense_dim() + ) + else: + res.sparse_resize_and_clear_(self.size(), self.dim(), 
0) + + res._coalesced_(True) + return res + res = aten.empty_like.default( + self, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format, + ) + # device can be not "meta" + res.fill_(0) + return res + + +@register_meta([aten.ones.default, aten.ones.out]) +@out_wrapper() +def meta_ones( + size, + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + requires_grad=False, +): + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.get_default_device() + if layout is None: + layout = torch.strided + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta([aten.zeros.default, aten.zeros.out]) +@out_wrapper() +def meta_zeros( + size, + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + requires_grad=False, +): + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.get_default_device() + if layout is None: + layout = torch.strided + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.select.int) +def meta_select(self, dim, index): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + torch._check_index( + not ( + guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size) + ), + lambda: f"select(): index {index} out of range for tensor of size " + f"{self.size()} at dimension {dim}", + ) + + index = index if index >= 0 else index + size + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = self.storage_offset() + index * new_stride[dim] + del new_size[dim] + del new_stride[dim] + + return self.as_strided(new_size, new_stride, new_storage_offset) + + +@register_meta(aten.select_scatter.default) +def meta_select_scatter(self, src, dim, index): + return utils.clone_preserve_strides(self) + + +@register_meta(aten.slice_scatter.default) +def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1): + return utils.clone_preserve_strides(self) + + +# TODO: Deduplicate this with canonicalize_dim +def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): + if dim_post_expr <= 0: + assert wrap_scalar + dim_post_expr = 1 + min = -dim_post_expr + max = dim_post_expr - 1 + assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})" + if dim < 0: + dim += dim_post_expr + return dim + + +def ensure_nonempty_size(t, dim): + return 1 if t.dim() == 0 else t.shape[dim] + + +# From aten/src/ATen/native/ScatterGatherChecks.h +def gather_shape_check(self, dim, index): + self_dims = max(self.dim(), 1) + index_dims = max(index.dim(), 1) + torch._check( + self_dims == index_dims, + lambda: "Index tensor must have the same number of dimensions as input tensor", + ) + for i in range(self_dims): + if i != dim: + torch._check( + ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), + lambda: f"Size does not match at dimension {i} expected index {index.shape}" + + f" to be no larger than self {self.shape} apart from dimension {dim}", + ) + + +@register_meta(aten.gather.default) +def meta_gather(self, dim, index, sparse_grad=False): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + wrapped_dim = 
maybe_wrap_dim(dim, self.dim()) + is_index_empty = guard_size_oblivious(index.numel() == 0) + if not is_index_empty: + torch._check( + index.dtype == torch.long or index.dtype == torch.int, + lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}", + ) + gather_shape_check(self, wrapped_dim, index) + return self.new_empty(index.shape) + + +# From aten/src/ATen/native/TensorAdvancedIndexing.cpp +def get_operator_enum(reduce_, use_new_options=False): + if use_new_options: + if reduce_ == "sum": + return "REDUCE_ADD" + elif reduce_ == "prod": + return "REDUCE_MULTIPLY" + elif reduce_ == "mean": + return "REDUCE_MEAN" + elif reduce_ == "amax": + return "REDUCE_MAXIMUM" + elif reduce_ == "amin": + return "REDUCE_MINIMUM" + torch._check( + False, + lambda: "reduce argument must be either sum, prod, mean, amax or amin.", + ) + return + else: + if reduce_ == "add": + return "REDUCE_ADD" + elif reduce_ == "multiply": + return "REDUCE_MULTIPLY" + torch._check(False, lambda: "reduce argument must be either add or multiply.") + return + + +# From aten/src/ATen/native/ScatterGatherChecks.h +def scatter_gather_dtype_check(method_name, self, index, src_opt=None): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(index.numel() != 0): + torch._check( + index.dtype == torch.long or index.dtype == torch.int, + lambda: f"{method_name}(): Expected dtype int32/int64 for index", + ) + + if src_opt is not None: + torch._check( + self.dtype == src_opt.dtype, + lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype", + ) + + +def ensure_nonempty_dim(dim): + return max(dim, 1) + + +# From aten/src/ATen/native/ScatterGatherChecks.h +def scatter_shape_check(self, dim, index, src_opt=None): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(index.numel() == 0): + return + torch._check( + ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), + lambda: "Index tensor must have the same number of dimensions as self tensor", + ) + + is_wrong_shape = False + self_dims = ensure_nonempty_dim(self.dim()) + + # Check: index.size(d) <= self.size(d) for all d != dim + for d in range(self_dims): + index_d_size = ensure_nonempty_size(index, d) + if d == dim: + continue + if index_d_size > ensure_nonempty_size(self, d): + is_wrong_shape = True + break + + # Check: index.size(d) <= src.size(d) for all d if src is Tensor + if not is_wrong_shape and src_opt is not None: + for d in range(self_dims): + index_d_size = ensure_nonempty_size(index, d) + if index_d_size > ensure_nonempty_size(src_opt, d): + is_wrong_shape = True + break + + if src_opt is not None: + torch._check( + ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), + lambda: "Index tensor must have the same number of dimensions as self tensor", + ) + torch._check( + not is_wrong_shape, + lambda: f"Expected index {index.shape} to be no larger than self {self.shape}" + + f" apart from dimension {dim} and to be no larger than src {src_opt.shape}", + ) + else: + torch._check( + not is_wrong_shape, + lambda: f"Expected index {index.shape} to be no larger than self {self.shape}" + + f" apart from dimension {dim}", + ) + + +# From aten/src/ATen/native/TensorAdvancedIndexing.cpp +def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False): + wrapped_dim = maybe_wrap_dim(dim, self.dim()) + scatter_gather_dtype_check("scatter", self, index, src) + scatter_shape_check(self, wrapped_dim, 
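+    # src is None for the scalar-value overloads, in which case the shape check
+    # only compares index against self.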
index, src) + if reduce_ is not None: + # Check if we have a valid reduce operator. + get_operator_enum(reduce_, use_new_options) + + +@register_meta(aten.scatter_add.default) +def meta_scatter_add(self, dim, index, src): + scatter_meta_impl(self, dim, index, src, "add") + return self.new_empty(self.shape) + + +@register_meta(aten.scatter_add_) +def meta_scatter_add_(self, dim, index, src): + scatter_meta_impl(self, dim, index, src, "add") + return self + + +@register_meta( + [ + aten.scatter.src, + aten.scatter.value, + aten.scatter.reduce, + aten.scatter.value_reduce, + ] +) +@out_wrapper() +def meta_scatter(self, dim, index, src_or_value, reduce=None): + src = src_or_value if isinstance(src_or_value, torch.Tensor) else None + scatter_meta_impl(self, dim, index, src, reduce) + return self.new_empty(self.shape) + + +@register_meta( + [ + aten.scatter_.src, + aten.scatter_.value, + aten.scatter_.reduce, + aten.scatter_.value_reduce, + ] +) +def meta_scatter_(self, dim, index, src_or_value, reduce=None): + src = src_or_value if isinstance(src_or_value, torch.Tensor) else None + scatter_meta_impl(self, dim, index, src, reduce) + return self + + +@register_meta([aten._scaled_dot_product_flash_attention]) +def meta__scaled_dot_product_flash_attention( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_seqlen_batch_q = query.size(2) + head_dim = query.size(3) + max_seqlen_batch_k = key.size(2) + + query_t = query.transpose(1, 2) + attention = torch.empty_like(query_t).transpose(1, 2) + logsumexp = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device=query.device, + ) + else: + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + + # Note [Seed and Offset]: device for seed and offset below depends on whether we are + # capturing or not, but at the time of tracing we don't know if we + # are going to use cudagraphs or not, so we return meta tensors here + # it's possible we'll need to have some special handling in inductor for sdpa + # See [Note] BC breaking change to flash seed/offset + if torch.version.hip and torch.cuda.is_available(): + # Maintian old path on AMD + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + else: + seed = torch.empty((2), dtype=torch.uint64, device="meta") + offset = torch.empty((), dtype=torch.uint64, device="meta") + + return ( + attention, + logsumexp, + None, + None, + max_seqlen_batch_q, + max_seqlen_batch_k, + seed, + offset, + debug_mask, + ) + + +def alloc_with_matching_layout( + query: Tensor, + res_shape: tuple[int, ...], +): + if tuple(query.shape) == res_shape: + query_t = query.transpose(1, 2) + res = torch.empty_like(query_t).transpose(1, 2) + else: + dim_order = sorted( + [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True + ) + permuted_shape = [res_shape[idx] for idx in dim_order] + final_permute = [dim_order.index(i) for i in range(len(dim_order))] + res = 
torch.empty( + permuted_shape, dtype=query.dtype, device=query.device + ).permute(final_permute) + + return res + + +@register_meta([aten._scaled_dot_product_cudnn_attention]) +def meta__scaled_dot_product_cudnn_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + compute_log_sumexp: bool, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + B = query.size(0) + H = query.size(1) + S_Q = query.size(2) + S_KV = key.size(2) + D_V = value.size(-1) + + res_shape = (B, H, S_Q, D_V) + res = alloc_with_matching_layout(query, res_shape) + + logsum_exp = torch.empty( + (B, H, S_Q), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset] + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return ( + res, + logsum_exp, + None, + None, + S_Q, + S_KV, + seed, + offset, + None, + ) + + +@register_meta([aten._scaled_dot_product_fused_attention_overrideable]) +def meta__scaled_dot_product_fused_attention_overrideable( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + B = query.size(0) + H_Q = query.size(1) + S_Q = query.size(2) + S_KV = key.size(2) + D_V = value.size(-1) + + res_shape = (B, H_Q, S_Q, D_V) + res = alloc_with_matching_layout(query, res_shape) + + logsum_exp = torch.empty( + (B, H_Q, S_Q), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset] + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return ( + res, + logsum_exp, + None, + None, + S_Q, + S_KV, + seed, + offset, + None, + ) + + +@register_meta( + [ + aten._scaled_dot_product_flash_attention_backward, + ] +) +def meta__scaled_dot_product_flash_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + logsumexp: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: Tensor, + philox_offset: Tensor, + scale: Optional[float] = None, +): + grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2) + grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2) + grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2) + return grad_q, grad_k, grad_v + + +@register_meta( + [ + aten._scaled_dot_product_flash_attention_for_cpu, + ] +) +def meta__scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_seqlen_batch_q = query.size(2) + + attention = torch.empty_like(query) + logsumexp = torch.empty( + ( + batch_size, + max_seqlen_batch_q, + num_heads, + ), + dtype=torch.float, + device=query.device, + ).transpose(1, 2) + return ( + attention, + logsumexp, + ) + + +@register_meta( + [ + aten._scaled_dot_product_flash_attention_for_cpu_backward, + ] +) +def meta__scaled_dot_product_flash_attention_for_cpu_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + logsumexp: Tensor, + dropout_p: float, + is_causal: bool, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +): + # 
cpus's grad layout is different from cuda's, + # i.e. (batch_size, seq_len,num_heads, head_dim) + batch_size = query.size(0) + num_heads = query.size(1) + head_dim = query.size(3) + len_q = query.size(2) + len_k = key.size(2) + + grad_q = torch.empty_permuted( + (batch_size, num_heads, len_q, head_dim), + (0, 2, 1, 3), + dtype=query.dtype, + device=query.device, + ) + grad_k = torch.empty_permuted( + (batch_size, num_heads, len_k, head_dim), + (0, 2, 1, 3), + dtype=key.dtype, + device=key.device, + ) + grad_v = torch.empty_permuted( + (batch_size, num_heads, len_k, head_dim), + (0, 2, 1, 3), + dtype=value.dtype, + device=value.device, + ) + + return grad_q, grad_k, grad_v + + +@register_meta([aten._scaled_dot_product_efficient_attention]) +def meta__scaled_dot_product_efficient_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + compute_log_sumexp: bool, + dropout_p=0.0, + is_causal: bool = False, + scale: Optional[float] = None, +): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + B = query.size(0) + M = query.size(1) + num_heads = query.size(-2) + Kv = value.size(-1) + + res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) + + if torch.version.hip and torch.cuda.is_available(): + """Please see: https://github.com/pytorch/pytorch/issues/146848 + longsumexp last dim should be seq length + """ + logsumexp_dim = M if compute_log_sumexp else 0 + else: + logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 + + logsum_exp = torch.empty( + (B, num_heads, logsumexp_dim), + dtype=torch.float, + device=query.device, + ) + + res = res.transpose(1, 2) + + # See Note [Seed and Offset]: + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return res, logsum_exp, seed, offset + + +@register_meta( + [ + aten._scaled_dot_product_efficient_attention_backward, + ] +) +def meta__scaled_dot_product_efficient_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + out: Tensor, + logsumexp: Tensor, + philox_seed: Tensor, + philox_offset: Tensor, + dropout_p: float, + grad_input_mask: list[bool], + is_causal: bool = False, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_q = query.size(2) + head_dim = query.size(3) + head_dim_v = value.size(3) + + max_k = key.size(2) + + grad_q = torch.empty_permuted( + (batch_size, num_heads, max_q, head_dim), + (0, 2, 1, 3), + dtype=query.dtype, + device=query.device, + ) + grad_k = torch.empty_permuted( + (batch_size, num_heads, max_k, head_dim), + (0, 2, 1, 3), + dtype=key.dtype, + device=key.device, + ) + grad_v = torch.empty_permuted( + (batch_size, num_heads, max_k, head_dim_v), + (0, 2, 1, 3), + dtype=value.dtype, + device=value.device, + ) + grad_bias = None + if attn_bias is not None and grad_input_mask[3]: + lastDim = attn_bias.size(-1) + lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 + new_sizes = list(attn_bias.size()) + new_sizes[-1] = lastDimAligned + grad_bias = torch.empty( + new_sizes, dtype=attn_bias.dtype, device=attn_bias.device + ) + grad_bias = grad_bias[..., :lastDim] + + return grad_q, grad_k, grad_v, grad_bias + + +@register_meta( + [ + aten._scaled_dot_product_cudnn_attention_backward, + ] +) +def meta__scaled_dot_product_cudnn_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + 
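+    # out, logsumexp and the philox seed/offset below are the values saved by the
+    # forward pass; the meta backward just returns q/k/v-shaped empty gradients.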
logsumexp: Tensor, + philox_seed: Tensor, + philox_offset: Tensor, + attn_bias: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + scale: Optional[float] = None, +): + grad_q = torch.empty_like(query) + grad_k = torch.empty_like(key) + grad_v = torch.empty_like(value) + return grad_q, grad_k, grad_v + + +@register_meta( + [ + aten._flash_attention_forward, + ] +) +def meta__flash_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + cum_seq_q: Optional[Tensor], + cum_seq_k: Optional[Tensor], + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + return_debug_mask: bool, + scale: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + seqused_k: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, +): + # NB: there are two underlying paths: + # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) + # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total + # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total + batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 + max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q + max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k + num_heads = query.size(-2) + head_dim = query.size(-1) + + # Cuda Path + attention = torch.empty_like(query) + if cum_seq_q is None: + logsumexp = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device=query.device, + ) + else: + total_q = query.size(0) + logsumexp = torch.empty( + (num_heads, total_q), dtype=torch.float, device=query.device + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device=query.device, + ) + else: + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + + # See Note [Seed and Offset] + # See [Note] BC breaking change to flash seed/offset + seed, offset = None, None + if torch.version.hip and torch.cuda.is_available(): + # Maintian old path on AMD + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + else: + seed = torch.empty((2), dtype=torch.uint64, device="meta") + offset = torch.empty((), dtype=torch.uint64, device="meta") + return ( + attention, + logsumexp, + seed, + offset, + debug_mask, + ) + + +@register_meta( + [ + aten._flash_attention_backward, + ] +) +def meta__flash_attention_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + logsumexp: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: Tensor, + philox_offset: Tensor, + scale: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, +): + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + + return grad_query, grad_key, grad_value + + +@register_meta( + [ + aten._efficient_attention_forward, + ] +) +def meta__efficient_attention_forward( + query: Tensor, + key: Tensor, + 
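+    # On the dense path query/key/value are (B, seqlen, num_heads, head_dim); the
+    # result below is allocated as (B, M, num_heads, Kv) to match.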
value: Tensor, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + custom_mask_type: int, + compute_log_sumexp: bool = False, + scale: Optional[float] = None, + causal_diagonal: Optional[Tensor] = None, + seqlen_k: Optional[Tensor] = None, + window_size: Optional[int] = None, +): + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + Kv = value.size(-1) + + res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) + + logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B + actual_max_seqlen_q = M + if cu_seqlens_q is not None: + assert max_seqlen_q is not None + actual_max_seqlen_q = max_seqlen_q + actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N + logsumexp_dim = ( + math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 + ) + logsum_exp = torch.empty( + (logsumexp_batch_dim, num_heads, logsumexp_dim), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset]: + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k + + +@register_meta( + [ + aten._efficient_attention_backward, + ] +) +def meta__efficient_attention_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + max_seqlen_q: torch.SymInt, + max_seqlen_k: torch.SymInt, + logsumexp: Tensor, + dropout_p: float, + philox_seed: Tensor, + philox_offset: Tensor, + custom_mask_type: int, + bias_requires_grad: bool, + scale: Optional[float] = None, + num_splits_key: Optional[int] = None, + shared_storage_dqdkdv: bool = False, +): + if shared_storage_dqdkdv: + torch._check( + query.shape[1] == key.shape[1], + lambda: "seqlen must match for `shared_storage_dqdkdv", + ) + torch._check( + query.shape[3] == key.shape[3], + lambda: "embedding dim must match for `shared_storage_dqdkdv", + ) + chunk = torch.empty( + (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]), + dtype=query.dtype, + device=query.device, + ) + grad_query = chunk.select(-3, 0) + grad_key = chunk.select(-3, 1) + grad_value = chunk.select(-3, 2) + else: + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + + if bias is not None: + lastDim = bias.size(-1) + lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 + new_sizes = list(bias.size()) + new_sizes[-1] = lastDimAligned + grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device) + grad_bias = grad_bias[..., :lastDim] + else: + grad_bias = torch.empty((), device=query.device) + + return grad_query, grad_key, grad_value, grad_bias + + +@register_meta([aten._scaled_mm.default]) +def meta_scaled_mm( + self: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +): + def is_fp8_or_fp4_type(dtype): + return dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + torch.float4_e2m1fn_x2, + ) + + torch._check( + self.dim() == 2 and mat2.dim() == 2, + lambda: 
f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", + ) + torch._check( + is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype), + lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", + ) + + if device_hint(self) == "cuda": + + def is_row_major(stride): + return stride[0] > stride[1] and stride[1] == 1 + + def is_col_major(stride): + return stride[0] == 1 and stride[1] > 1 + + def has_zero_dim(tensor_2d): + return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0 + + torch._check( + is_row_major(self.stride()) or has_zero_dim(self), + lambda: f"self must be row_major, got stride {self.stride()}", + ) + torch._check( + is_col_major(mat2.stride()) or has_zero_dim(mat2), + lambda: f"mat2 must be col_major, got stride {mat2.stride()}", + ) + torch._check( + self.size(1) % 16 == 0, + lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", + ) + torch._check( + mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, + lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", + ) + + # determine scaling type and check input dimensions (refer to Blas.cpp op) + + m, _k = self.shape + n = mat2.size(1) + + is_blockwise_scaling = ( + scale_a.dtype == torch.float8_e8m0fnu + and scale_b.dtype == torch.float8_e8m0fnu + ) or ( + scale_a.dtype == torch.float8_e4m3fn + and scale_b.dtype == torch.float8_e4m3fn + ) + + if scale_a.numel() == 1 and scale_b.numel() == 1: + # tensorwise scaling + torch._check( + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.", + ) + elif is_blockwise_scaling: + # blockwise scaling + + if scale_a.dtype == torch.float8_e4m3fn: + # NVIDIA's nvfp4 recipe: + # * block size is 16 elements packed (32 unpacked) + # * _k needs to be translated to the unpacked version + block_size_k = 16 + _k = _k * 2 + else: + block_size_k = 32 + + block_size_mn = 128 + + def ceil_div(a, b): + return (a + b - 1) // b + + num_k_blocks = ceil_div(_k, block_size_k) + padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4 + + expected_a_size = ( + block_size_mn * ceil_div(m, block_size_mn) * padded_num_k_blocks + ) + expected_b_size = ( + block_size_mn * ceil_div(n, block_size_mn) * padded_num_k_blocks + ) + + if ( + scale_a.numel() == expected_a_size + and scale_b.numel() == expected_b_size + ): + torch._check( + scale_a.is_contiguous(), + lambda: "scale_a must be contiguous", + ) + torch._check( + scale_b.is_contiguous(), + lambda: "scale_b must be contiguous", + ) + else: + torch._check( + False, + lambda: ( + "Invalid blockwise scaling configuration. " + f"For blockwise scaling, scale_a should have {expected_a_size} elements, got {scale_a.numel()}, " + f"scale_b should have {expected_b_size} elements, got {scale_b.numel()}." 
+ ), + ) + else: + torch._check( + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "For rowwise scaling, both scale_a and scale_b must be float (fp32) tensors.", + ) + # for rowwise scaling, enforce 2D input tensors + torch._check( + scale_a.dim() == 2 and scale_b.dim() == 2, + lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}", + ) + + if ( + scale_a.size(0) == m + and scale_a.size(1) == 1 + and scale_b.size(0) == 1 + and scale_b.size(1) == n + ): + # rowwise scaling + torch._check( + scale_a.is_contiguous() and scale_b.is_contiguous(), + lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.", + ) + else: + # does not match any valid scaling type + torch._check( + False, + lambda: ( + "Invalid scaling configuration. " + "For tensorwise scaling, both scales should be scalar. " + f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). " + f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " + f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" + ), + ) + + _out_dtype = out_dtype if out_dtype is not None else self.dtype + return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device) + + +@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out]) +@out_wrapper() +def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True): + scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) + return self.new_empty(self.shape) + + +@register_meta(aten.scatter_reduce_.two) +def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True): + scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) + return self + + +@register_meta([aten.multinomial.default, aten.multinomial.out]) +@out_wrapper() +def meta_multinomial(input, num_samples, replacement=False, *, generator=None): + torch._check( + 0 < input.dim() <= 2, + lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}", + ) + if input.dim() == 1: + return torch.empty(num_samples, dtype=torch.long, device=input.device) + return torch.empty( + input.size(0), num_samples, dtype=torch.long, device=input.device + ) + + +def multiply_integers(vs): + r = 1 + for v in vs: + r *= v + return r + + +def upsample_common_check(input_size, output_size, num_spatial_dims): + torch._check( + len(output_size) == num_spatial_dims, + lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}", + ) + expected_input_dims = num_spatial_dims + 2 # N, C, ... 
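+    # Worked example with hypothetical sizes: a 2D upsample of input_size
+    # (1, 3, 8, 8) to output_size (16, 16) passes both checks below and
+    # returns (1, 3, 16, 16).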
+ torch._check( + len(input_size) == expected_input_dims, + lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}", + ) + + torch._check( + all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size), + lambda: f"Input and output sizes should be greater than 0, but got " + f"input size {input_size} and output size {output_size}", + ) + + nbatch, channels = input_size[:2] + return (nbatch, channels, *output_size) + + +@register_meta( + [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default] +) +def upsample_nearest1d(input, output_size, scales=None): + torch._check( + input.numel() != 0 or multiply_integers(input.size()[1:]), + lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}", + ) + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=1 + ) + return input.new_empty(full_output_size).to( + memory_format=utils.suggest_memory_format(input) + ) + + +@register_meta( + [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default] +) +def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): + torch._check( + input.numel() != 0 or multiply_integers(input.size()[1:]), + lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", + ) + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=2 + ) + output = input.new_empty(full_output_size) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + _, n_channels, _, _ = input.shape + if input.device.type == "cuda" and n_channels < 4: + memory_format = torch.contiguous_format + + output = output.contiguous(memory_format=memory_format) + + return output + + +@register_meta( + [ + aten.upsample_nearest2d_backward.default, + aten._upsample_nearest_exact2d_backward.default, + ] +) +def upsample_nearest2d_backward( + grad_output: Tensor, + output_size: Sequence[Union[int, torch.SymInt]], + input_size: Sequence[Union[int, torch.SymInt]], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + full_output_size = upsample_common_check( + input_size, output_size, num_spatial_dims=2 + ) + torch._check( + grad_output.ndim == 4, + lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}", + ) + for i in range(4): + torch._check( + grad_output.size(i) == full_output_size[i], + lambda: ( + f"Expected grad_output to have the same shape as output;" + f" output.size({i}) = {full_output_size[i]}" + f" but got grad_output.size({i}) = {grad_output.size(i)}" + ), + ) + + return grad_output.new_empty(input_size).to( + memory_format=utils.suggest_memory_format(grad_output) + ) # type: ignore[call-overload] + + +@register_meta( + [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default] +) +def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None): + torch._check( + input.numel() != 0 or multiply_integers(input.size()[1:]), + lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}", + ) + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=3 + ) + return input.new_empty(full_output_size).to( + memory_format=utils.suggest_memory_format(input) + ) + + +@register_meta( + [ + aten.sort.default, + aten.sort.stable, + 
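+        # the .values / .values_stable out= overloads share the same meta below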
aten.sort.values, + aten.sort.values_stable, + ] +) +def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None): + v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64) + if values is not None and indices is not None: + assert isinstance(values, TensorLike) + assert isinstance(indices, TensorLike) + # Makes sure values and indices have the same strides. For cases where + # these have different shapes, like (5, 10, 5) and (0) in msort. + out_shape = v.shape + out_stride = v.stride() + values = _maybe_resize_out(values, out_shape) + indices = _maybe_resize_out(indices, out_shape) + values.as_strided_(out_shape, out_stride) + indices.as_strided_(out_shape, out_stride) + _safe_copy_out(copy_from=v, copy_to=values) # type: ignore[arg-type] + _safe_copy_out(copy_from=i, copy_to=indices) # type: ignore[arg-type] + return values, indices + return v, i + + +def rnn_cell_checkSizes( + input_gates, + hidden_gates, + input_bias, + hidden_bias, + factor, + prev_hidden, +): + torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2") + torch._check( + input_gates.shape == hidden_gates.shape, + lambda: f"{input_gates.shape} != {hidden_gates.shape}", + ) + gates_size = input_gates.size(1) + if input_bias is not None: + torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1") + torch._check( + input_bias.numel() == gates_size, + lambda: f"{input_bias.numel()} != {gates_size}", + ) + torch._check( + input_bias.shape == hidden_bias.shape, + lambda: f"{input_bias.shape} != {hidden_bias.shape}", + ) + torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2") + expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor + torch._check( + prev_hidden.numel() == expected_prev_hidden_numel, + lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})", + ) + torch._check( + all( + x.device == input_gates.device + for x in [hidden_gates, input_bias, hidden_bias, prev_hidden] + ), + lambda: "expected all inputs to be same device", + ) + + +@register_meta(aten._thnn_fused_lstm_cell.default) +def _thnn_fused_lstm_cell_meta( + input_gates, + hidden_gates, + cx, + input_bias=None, + hidden_bias=None, +): + rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx) + workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format) + hy = torch.empty_like(cx, memory_format=torch.contiguous_format) + cy = torch.empty_like(cx, memory_format=torch.contiguous_format) + return (hy, cy, workspace) + + +@register_meta(aten._cudnn_rnn.default) +def _cudnn_rnn( + input, + weight, + weight_stride0, + weight_buf, + hx, + cx, + mode, + hidden_size, + proj_size, + num_layers, + batch_first, + dropout, + train, + bidirectional, + batch_sizes, + dropout_state, +): + is_input_packed = len(batch_sizes) != 0 + if is_input_packed: + seq_length = len(batch_sizes) + mini_batch = batch_sizes[0] + batch_sizes_sum = input.shape[0] + else: + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + batch_sizes_sum = -1 + + num_directions = 2 if bidirectional else 1 + out_size = proj_size if proj_size != 0 else hidden_size + if is_input_packed: + out_shape = [batch_sizes_sum, out_size * num_directions] + else: + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) + output = 
input.new_empty(out_shape) + + cell_shape = [num_layers * num_directions, mini_batch, hidden_size] + if cx is None: + cy = torch.empty(0, device=input.device) + else: + cy = cx.new_empty(cell_shape) + + hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) + + # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) + reserve_shape = 0 if train else 0 + reserve = input.new_empty(reserve_shape, dtype=torch.uint8) + + return output, hy, cy, reserve, weight_buf + + +@register_meta(aten.mkldnn_rnn_layer.default) +def mkldnn_rnn_layer( + input, + w0, + w1, + w2, + w3, + hx_, + cx_, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, +): + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + output_chanels = hidden_size + out_shape = ( + [mini_batch, seq_length, output_chanels] + if batch_first + else [seq_length, mini_batch, output_chanels] + ) + output = input.new_empty(out_shape) + if hx_ is None: + hy = torch.empty(0, device=input.device) + else: + hy = hx_.new_empty(hx_.shape) + if cx_ is None: + cy = torch.empty(0, device=input.device) + else: + cy = cx_.new_empty(cx_.shape) + workspace = torch.empty(0, device=input.device, dtype=torch.uint8) + return output, hy, cy, workspace + + +def zero_numel_check_dims(self, dim, fn_name): + if self.ndim == 0: + torch._check_index( + dim == 0 or dim == -1, + lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", + ) + else: + torch._check_index( + self.size(dim) != 0, + lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", + ) + + +# From aten/src/ATen/native/ReduceOps.cpp +def check_argmax_argmin(name, self, dim): + if dim is not None: + dim = maybe_wrap_dim(dim, self.dim()) + zero_numel_check_dims(self, dim, name) + else: + torch._check( + self.numel() != 0, + lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", + ) + + +@register_meta([aten.argmax.default, aten.argmin.default]) +def argmax_argmin_meta(self, dim=None, keepdim=False): + check_argmax_argmin("argmax", self, dim) + dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None) + shape = _compute_reduction_shape(self, dims, keepdim) + return self.new_empty(shape, dtype=torch.int64) + + +@register_meta(aten.scalar_tensor.default) +def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): + # NB: It's always wrong to try to create a scalar tensor with the jagged layout. + # Rather than fix this everywhere, just use the strided layout and let NJT handle + # scalar tensor broadcasting. 
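+    # Example (illustrative sketch): since this meta only propagates shape and
+    # dtype, a scalar always comes back as a 0-dim tensor, e.g.
+    #   >>> torch.ops.aten.scalar_tensor.default(3.0, dtype=torch.float32,
+    #   ...                                       device="meta").shape
+    #   torch.Size([])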
+ if layout == torch.jagged: + layout = torch.strided + return torch.empty( + (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.topk.default) +def topk_meta(self, k, dim=-1, largest=True, sorted=True): + # From aten/src/ATen/native/Sorting.cpp + dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) + sliceSize = 1 if self.dim() == 0 else self.size(dim) + torch._check_is_size(k) + torch._check(k <= sliceSize, lambda: "k not in range for dimension") + + topKSize = list(self.shape) + if len(topKSize) > 0: + topKSize[dim] = k + return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) + + +@register_meta(aten._segment_reduce_backward) +@out_wrapper() +def meta__segment_reduce_backward( + grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None +): + assert lengths is not None or offsets is not None, ( + "segment_reduce(): Either lengths or offsets must be defined" + ) + data_contig = data.contiguous() + grad_contig = grad.contiguous() + return torch.empty_like( + data_contig, + dtype=grad_contig.dtype, + device=grad_contig.device, + layout=grad_contig.layout, + ) + + +@register_meta([aten.kthvalue.default, aten.kthvalue.values]) +@out_wrapper("values", "indices") +def kthvalue_meta(self, k, dim=-1, keepdim=False): + from torch.fx.experimental.symbolic_shapes import sym_and + + dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) + dimSize = self.size(dim) if self.dim() > 0 else 1 + torch._check( + sym_and(k >= 1, k <= dimSize), + lambda: f"kthvalue(): selected number k out of range for dimension {dim}", + ) + + shape = list(self.shape[:dim] + self.shape[dim + 1 :]) + if keepdim and self.dim() > 0: + shape.insert(dim, 1) + return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64) + + +legacy_contiguous_memory_format = torch.contiguous_format + + +# From aten/src/ATen/native/cuda/RNN.cu +def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace): + defined_grad = grad_hy if grad_hy is not None else grad_cy + torch._check(defined_grad.dim() == 2, lambda: "") + exp_size = defined_grad.size() + if grad_hy is not None: + torch._check(grad_hy.size() == exp_size, lambda: "") + if grad_cy is not None: + torch._check(grad_cy.size() == exp_size, lambda: "") + torch._check(cx.size() == exp_size, lambda: "") + torch._check(cy.size() == exp_size, lambda: "") + torch._check(workspace.dim() == 2, lambda: "") + torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "") + + +# From aten/src/ATen/native/cuda/RNN.cu +@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default) +def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias): + if grad_hy is None and grad_cy is None: + return None, None, None + checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace) + grad_gates = torch.empty_like( + workspace, memory_format=legacy_contiguous_memory_format + ) + grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format) + grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None + return grad_gates, grad_cx, grad_bias + + +# From aten/src/ATen/native/mps/operations/Linear.mm +@register_meta(aten.linear_backward.default) +def linear_backward(input_, grad_output_, weight_, output_mask): + grad_input = None + grad_weight = None + grad_bias = None + if output_mask[0]: + grad_input = grad_output_.new_empty(input_.size()) + if output_mask[1] or output_mask[2]: + grad_weight = grad_output_.new_empty((grad_output_.size(-1), 
input_.size(-1))) + grad_bias = grad_output_.new_empty(grad_output_.size(-1)) + return (grad_input, grad_weight, grad_bias) + + +@register_meta(aten.pixel_shuffle.default) +def meta_pixel_shuffle(self, upscale_factor): + assert ( + len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0 + ), ( + f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}" + ) + + def is_channels_last(ten): + return torch._prims_common.suggest_memory_format(ten) == torch.channels_last + + def pick_memory_format(): + if is_channels_last(self): + if device_hint(self) == "cuda": + return torch.contiguous_format + else: + return torch.channels_last + elif self.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif self.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + C = self.shape[-3] // (upscale_factor * upscale_factor) + Hr = self.shape[-2] * upscale_factor + Wr = self.shape[-1] * upscale_factor + out_shape = (*self.shape[:-3], C, Hr, Wr) + + out = self.new_empty(out_shape) + out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] + return out + + +@register_meta(aten.mkldnn_rnn_layer_backward.default) +def mkldnn_rnn_layer_backward( + input, + weight0, + weight1, + weight2, + weight3, + hx_, + cx_tmp, + output, + hy_, + cy_, + grad_output_r_opt, + grad_hy_r_opt, + grad_cy_r_opt, + reverse, + mode, + hidden_size, + num_layers, + has_biases, + train, + bidirectional, + batch_sizes, + batch_first, + workspace, +): + diff_x = input.new_empty(input.shape) + diff_hx = hx_.new_empty(hx_.shape) + diff_cx = cx_tmp.new_empty(cx_tmp.shape) + diff_w1 = weight0.new_empty(weight0.shape) + diff_w2 = weight1.new_empty(weight1.shape) + diff_b = weight2.new_empty(weight2.shape) + return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx + + +@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out]) +@out_wrapper() +def meta_bucketize(self, boundaries, *, out_int32=False, right=False): + return torch.empty_like( + self, + dtype=torch.int32 if out_int32 else torch.int64, + memory_format=torch.contiguous_format, + ) + + +@register_meta([aten.histc]) +@out_wrapper() +def meta_histc(input, bins=100, min=0, max=0): + fn_name = "histc()" + if device_hint(input) == "cpu": + torch._check( + input.is_floating_point(), + lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'", + ) + if device_hint(input) == "cuda" and input.is_floating_point(): + utils.alert_not_deterministic("_histc_cuda with floating point input") + torch._check( + isinstance(bins, IntLike), + lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}", + ) + torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}") + torch._check( + isinstance(min, Number), + lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}", + ) + torch._check( + isinstance(max, Number), + lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}", + ) + torch._check(max >= min, lambda: "{fn_name}: max must be larger than min") + return torch.empty(bins, device=input.device, dtype=input.dtype) + + +@register_meta( + [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default] +) +def meta_upsample_bimode2d_aa( + input, + output_size, + align_corners, + scales_h=None, + scales_w=None, +): + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=2 + ) + torch._check( + input.numel() != 0 or all(size > 0 for size in 
input.size()[1:]), + lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", + ) + return input.new_empty(full_output_size).to( + memory_format=utils.suggest_memory_format(input) + ) + + +@register_meta([aten._upsample_bilinear2d_aa_backward.default]) +def meta_upsample_bimode2d_aa_backward( + grad_output, + output_size, + input_size, + align_corners, + scales_h=None, + scales_w=None, +): + full_output_size = upsample_common_check( + input_size, output_size, num_spatial_dims=2 + ) + torch._check( + grad_output.ndim == 4, + lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}", + ) + for i in range(4): + torch._check( + grad_output.shape[i] == full_output_size[i], + lambda: f""" +Expected grad_output to have the same shape as output; output.size({i}) = {full_output_size[i]} +but got grad_output_size({i}) = {grad_output.size(i)}""", + ) + return grad_output.new_empty(input_size).to( + memory_format=utils.suggest_memory_format(grad_output) + ) + + +# From aten/src/ATen/native/cuda/AmpKernels.cu +@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default) +def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale): + torch._check( + found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor." + ) + torch._check( + inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor." + ) + torch._check( + found_inf.dtype.is_floating_point, + lambda: "found_inf must be a float tensor.", + ) + torch._check( + inv_scale.dtype.is_floating_point, + lambda: "inv_scale must be a float tensor.", + ) + + +# From aten/src/ATen/native/UnaryOps.cpp +@register_meta([aten.nan_to_num.default, aten.nan_to_num.out]) +@out_wrapper() +def nan_to_num(self, nan=None, posinf=None, neginf=None): + result_size = list(self.size()) + return self.new_empty(result_size) + + +@register_meta(torch.ops.aten.transpose_) +def transpose_(self, dim0, dim1): + assert self.layout not in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }, ( + f"torch.transpose_: in-place transposition is not supported for {self.layout} layout" + ) + + ndims = self.ndim + + dim0 = maybe_wrap_dim(dim0, ndims) + dim1 = maybe_wrap_dim(dim1, ndims) + + if dim0 == dim1: + return self + + size = list(self.size()) + stride = list(self.stride()) + + stride[dim0], stride[dim1] = stride[dim1], stride[dim0] + size[dim0], size[dim1] = size[dim1], size[dim0] + + self.as_strided_(size, stride) + return self + + +@register_meta(torch.ops.aten.t_) +def t_(self): + ndims = self.ndim + + if self.is_sparse: + sparse_dim = self.sparse_dim() + dense_dim = self.dense_dim() + assert sparse_dim <= 2 and dense_dim == 0, ( + f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, " + f"but got {sparse_dim} sparse and {dense_dim} dense dimensions" + ) + else: + assert self.dim() <= 2, ( + f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D" + ) + + return transpose_(self, 0, 0 if ndims < 2 else 1) + + +@register_meta(aten.searchsorted) +@out_wrapper() +def meta_searchsorted( + sorted_sequence, + self, + *, + out_int32=False, + right=False, + side=None, + sorter=None, +): + # If the sorted_sequence is not one-dimensional, its shape must match that of values + # in all but the last dimension. 
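+    # Shape contract, sketched: a 1-D `sorted_sequence` pairs with values of any
+    # shape, while an N-D `sorted_sequence` must match `self` in every dimension
+    # except the last; e.g. boundaries of shape (B, N) with values of shape (B, M)
+    # yield indices of shape (B, M), int64 by default or int32 when out_int32=True.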
+ torch._check( + len(sorted_sequence.shape) <= 1 + or sorted_sequence.shape[:-1] == self.shape[:-1], + lambda: ( + "torch.searchsorted(): boundaries tensor should be 1 dimension or the " + "first N-1 dimensions of boundaries tensor and input value tensor must " + f"match, but we got boundaries tensor {list(sorted_sequence.shape)} and " + f"input value tensor {list(self.shape)}" + ), + ) + + # If a sorter array is provided, its dimensions must exactly match sorted_sequence. + torch._check( + sorter is None or sorted_sequence.shape == sorter.shape, + lambda: ( + "torch.searchsorted(): boundary and sorter must have the same size, but " + f"got boundary tensor {list(sorted_sequence.shape)} and got sorter tensor " + f"{list(sorter.shape) if sorter is not None else []}" + ), + ) + + # Per the docs, if side == "left" and right is True, we error. + torch._check( + side != "left" or not right, + "torch.searchsorted(): side and right can't be set to opposites, got side of " + "left while right was True", + ) + + dtype = torch.int32 if out_int32 else torch.int64 + if isinstance(self, torch.Tensor): + return torch.empty_like( + self, dtype=dtype, memory_format=torch.contiguous_format + ) + else: # Scalar + return torch.empty((), dtype=dtype, device=sorted_sequence.device) + + +def _check_for_unsupported_isin_dtype(dtype): + torch._check( + dtype not in (torch.bool, torch.complex128, torch.complex64), + lambda: f"Unsupported input type encountered for isin(): {dtype}", + ) + + +@register_meta(aten.embedding_dense_backward) +def meta_embedding_dense_backward( + grad_output, + indices, + num_weights, + padding_idx, + scale_grad_by_freq, +): + grad_weight = grad_output.new_empty((num_weights, grad_output.size(-1))) + return grad_weight + + +@register_meta(aten._embedding_bag_backward) +def meta_embedding_bag_backward( + grad, + indices, + offsets, + offset2bag, + bag_size, + maximum_indices, + num_weights, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + padding_idx=-1, +): + if sparse: + return aten._embedding_bag_sparse_backward( + grad, + indices, + offsets, + offset2bag, + bag_size, + num_weights, + scale_grad_by_freq, + mode, + per_sample_weights, + padding_idx, + ) + else: + return meta_embedding_bag_dense_backward( + grad, + indices, + offset2bag, + bag_size, + maximum_indices, + num_weights, + scale_grad_by_freq, + mode, + per_sample_weights, + padding_idx, + ) + + +@register_meta(aten._embedding_bag_dense_backward) +def meta_embedding_bag_dense_backward( + grad, + indices, + offset2bag, + bag_size, + maximum_indices, + num_weights, + scale_grad_by_freq, + mode, + per_sample_weights, + padding_idx=-1, +): + torch._check( + grad.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64], + lambda: f"Unsupported input type encountered: {grad.dtype}", + ) + if mode == MODE_MAX: + torch._check(maximum_indices is not None) + index_grad_weight = grad.new_empty((num_weights, grad.size(1))) + return index_grad_weight + + +@register_meta(aten._embedding_bag_per_sample_weights_backward) +def meta_embedding_bag_per_sample_weights_backward( + grad, + weight, + indices, + offsets, + offset2bag, + mode, + padding_idx=-1, +): + embedding_features = grad.size(1) + torch._check( + mode == MODE_SUM, + "embedding_bag_backward: per_sample_weights only supported for mode='sum'", + ) + torch._check(grad.dim() == 2) + torch._check(indices.dim() == 1) + num_samples = indices.size(0) + torch._check(weight.dim() == 2) + torch._check(weight.size(1) == embedding_features) + output = 
grad.new_empty((num_samples,)) + return output + + +@register_meta(aten.isin) +@out_wrapper() +def meta_isin(elements, test_elements, *, assume_unique=False, invert=False): + torch._check( + isinstance(elements, Tensor) or isinstance(test_elements, Tensor), + lambda: "At least one of elements and test_elements must be a Tensor.", + ) + if not isinstance(elements, Tensor): + elements = torch.tensor(elements, device=test_elements.device) + + if not isinstance(test_elements, Tensor): + test_elements = torch.tensor(test_elements, device=elements.device) + + _check_for_unsupported_isin_dtype(elements.dtype) + _check_for_unsupported_isin_dtype(test_elements.dtype) + return torch.empty_like(elements, dtype=torch.bool) + + +@register_meta(aten.polygamma) +@out_wrapper() +def meta_polygamma(n: int, self: Tensor) -> Tensor: + torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.") + _, result_dtype = elementwise_dtypes( + self, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + return torch.empty_like(self, dtype=result_dtype) + + +@register_meta(aten._local_scalar_dense) +def meta_local_scalar_dense(self: Tensor): + raise RuntimeError("Tensor.item() cannot be called on meta tensors") + + +@register_meta(aten.silu) +@out_wrapper(exact_dtype=True) +def silu(self: Tensor) -> Tensor: + return torch.empty_like(self) + + +@register_meta(aten.sigmoid) +@out_wrapper() +def sigmoid(self: Tensor) -> Tensor: + _, result_dtype = elementwise_dtypes( + self, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + return torch.empty_like(self, dtype=result_dtype) + + +def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): + mat1_is_2d = mat1.dim() == 2 + mat2_is_2d = mat2.dim() == 2 + + if mat1_is_2d: + if mat2_is_2d: + out_size = [offs.size(0), mat1.size(0), mat2.size(1)] + else: + torch._check( + offs.size(0) == mat2.size(0), "matrix batch sizes have to match" + ) + out_size = [mat1.size(0), mat2.size(-1)] + else: + if mat2_is_2d: + torch._check( + offs.size(0) == mat1.size(0), "matrix batch sizes have to match" + ) + out_size = [mat1.size(1), mat2.size(1)] + else: + # regular bmm + torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") + out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] + + out_dtype = out_dtype or mat1.dtype + + alignment = 16 // out_dtype.itemsize + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if mat1_is_2d == mat2_is_2d: + out_stride = [out_size[1] * size_padded, size_padded, 1] + else: + out_stride = [size_padded, 1] + out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device) + return out + + +def _meta_grouped_mm_common( + mat_a: Tensor, + mat_b: Tensor, + scale_a: Optional[torch.Tensor], + scale_b: Optional[torch.Tensor], + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +): + torch._check( + (scale_a is None) == (scale_b is None), + lambda: "Either both scale factors are given, or none", + ) + scaled = scale_a is not None and scale_b is not None + + # Implementing all the checks from + # _grouped_mm_cuda()/_scaled_grouped_mm_cuda() code in + # aten/src/ATen/native/cuda/Blas.cpp. 
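+    # Shape sketch, following _create_grouped_mm_output_tensor above: 2D x 2D
+    # operands with an int32 offsets tensor of length G produce a (G, M, N)
+    # output, a 2D (M, K) mat_a against a 3D (G, K, N) mat_b yields (M, N),
+    # and 3D x 3D operands behave like a regular bmm.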
+ + if scaled: + torch._check( + mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn, + lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", + ) + else: + torch._check( + mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16, + lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", + ) + + torch._check( + mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3], + lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", + ) + + mat_a_is_2d = mat_a.dim() == 2 + mat_b_is_2d = mat_b.dim() == 2 + + if scaled: + + def is_row_major(mat): + mat_stride = mat.stride() + return mat_stride[-2] > 1 and mat_stride[-1] == 1 + + def is_col_major(mat): + mat_stride = mat.stride() + return mat_stride[-2] == 1 and mat_stride[-1] > 1 + + torch._check( + is_row_major(mat_a), + lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", + ) + torch._check( + is_col_major(mat_b), + lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", + ) + + def check_valid_strides(mat_name, mat): + end_dim = mat.dim() - 1 + alignment = 16 // mat.element_size() + mat_stride = mat.stride() + if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max( + 1, mat.shape[end_dim - 1] + ): + torch._check( + mat_stride[end_dim] % alignment == 0, + lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", + ) + elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max( + 1, mat.shape[end_dim] + ): + torch._check( + mat_stride[end_dim - 1] % alignment == 0, + lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.", # noqa: B950 + ) + else: + torch._check( + False, + lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950 + ) + + check_valid_strides("mat_a", mat_a) + check_valid_strides("mat_b", mat_b) + + if scale_a is not None and scale_b is not None: + torch._check( + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950 + ) + + def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): + if mat.dim() == 2: + torch._check( + scale.dim() == 1, + lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.is_contiguous(), + lambda: f"Expected {scale_name} to be contiguous.", + ) + torch._check( + scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier, + lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950 + ) + else: + torch._check( + scale.dim() == 2, + lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.stride(1) == 1, + lambda: f"Expected {scale_name} to be contiguous in the last dimension.", + ) + torch._check( + scale.shape[0] == mat.shape[0], + lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.", + ) + torch._check( + scale.shape[1] == mat.shape[1 + scaled_dim], + lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got 
{scale.shape[1]}.", + ) + + scale_multiplier = ( + offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1 + ) + check_scale("scale_a", scale_a, mat_a, 0, scale_multiplier) + check_scale("scale_b", scale_b, mat_b, 1, scale_multiplier) + + torch._check( + scale_result is None, + lambda: "Scale result tensor provided, but it is not supported yet.", + ) + + if mat_a_is_2d or mat_b_is_2d: + torch._check( + offs is not None, + lambda: f"Offsets tensor not provided, but is needed for {mat_a.dim()}D/{mat_b.dim()}D multiplicand layouts.", + ) + if offs is not None: # to silence Mypy + torch._check( + offs.dim() == 1, + lambda: f"Offsets tensor must be 1D, but got offs.dim()={offs.dim()}.", + ) + torch._check( + offs.dtype == torch.int32, + lambda: f"Offsets tensor must be integer (int32) tensor, but got {offs.dtype}.", + ) + else: + torch._check( + offs is None, + lambda: "Offsets tensor provided, but is not needed for 3D/3D multiplicand layouts.", + ) + + torch._check( + bias is None, + lambda: "Bias tensor provided, but it is not supported yet.", + ) + + torch._check( + out_dtype is None or out_dtype == torch.bfloat16, + lambda: "If output dtype provided, it must be torch.bfloat16.", + ) + + return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype) + + +@register_meta(aten._grouped_mm) +@out_wrapper() +def grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + return _meta_grouped_mm_common( + mat_a, + mat_b, + scale_a=None, + scale_b=None, + offs=offs, + bias=bias, + scale_result=None, + out_dtype=out_dtype, + ) + + +@register_meta([aten._scaled_grouped_mm.default]) +def meta_scaled_grouped_mm( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + offs: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +): + return _meta_grouped_mm_common( + mat_a, + mat_b, + scale_a=scale_a, + scale_b=scale_b, + offs=offs, + bias=bias, + scale_result=scale_result, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + + +@register_meta(aten._softmax) +@out_wrapper() +def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor: + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + result_dtype = result_dtype if not half_to_float else computation_dtype + res = torch.empty_like(x, dtype=result_dtype, memory_format=torch.contiguous_format) + return res + + +@register_meta(aten.constant_pad_nd) +@out_wrapper() +def _constant_pad_nd_meta(input, pad, value=0): + # same checks as decomposition in torch/_refs/__init__.py:constant_pad_nd() + torch._check( + len(pad) % 2 == 0, + lambda: f"Length of pad must be even but instead it equals {len(pad)}", + ) + + input_sizes = input.shape + l_inp = len(input_sizes) + l_pad = len(pad) // 2 + l_diff = l_inp - l_pad + + torch._check( + l_inp >= l_pad, + lambda: "Length of pad should be no more than twice the number of " + f"dimensions of the input. 
Pad length is {len(pad)} while the input has " + f"{l_inp} dimensions.", + ) + + new_shape = list(input_sizes[:l_diff]) + for i in range(l_pad): + pad_idx = len(pad) - ((i + 1) * 2) + new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] + torch._check( + new_dim >= 0, + lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " + f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " + f"which is invalid. Check dimension {l_diff + i} of your input.", + ) + new_shape.append(new_dim) + + return torch.empty( + new_shape, + dtype=input.dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=suggest_memory_format(input), + ) + + +@register_meta(aten.embedding) +@out_wrapper() +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + assert weight.dim() == 2, "'weight' must be 2-D" + weight_shape = weight.shape + indices_shape = indices.shape + + if indices.ndim == 0: + out_shape: tuple[int, ...] = (weight_shape[1],) + elif indices.ndim == 1: + out_shape = (indices_shape[0], weight_shape[1]) + else: + out_shape = (*indices_shape, weight_shape[1]) + + out_dtype = weight.dtype + return weight.new_empty(out_shape, dtype=out_dtype) + + +@register_meta(aten._jagged_to_padded_dense_forward.default) +def meta__jagged_to_padded_dense_forward( + values: Tensor, + offsets: list[Tensor], + max_lengths: list[int], + padding_value: float = 0.0, +): + # only one jagged dim is supported for now + assert len(offsets) == 1 + assert len(max_lengths) == 1 + + B = offsets[0].shape[0] - 1 + S = max_lengths[0] + output_shape = (B, S, *values.shape[1:]) + return values.new_empty(output_shape) + + +def _create_unary_float_meta_func(func): + @register_meta(func) + @out_wrapper() + def _f(x): + return elementwise_meta( + x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + return _f + + +def _create_binary_float_meta_func(func): + @register_meta(func) + @out_wrapper() + def _f(x, y): + return elementwise_meta( + x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + return _f + + +_create_unary_float_meta_func(aten.special_airy_ai) +_create_unary_float_meta_func(aten.special_bessel_y0) +_create_unary_float_meta_func(aten.special_bessel_y1) +_create_unary_float_meta_func(aten.special_modified_bessel_i0) +_create_unary_float_meta_func(aten.special_modified_bessel_i1) +_create_unary_float_meta_func(aten.special_modified_bessel_k0) +_create_unary_float_meta_func(aten.special_modified_bessel_k1) +_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0) +_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1) + + +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t) +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u) +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v) +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w) +_create_binary_float_meta_func(aten.special_hermite_polynomial_h) +_create_binary_float_meta_func(aten.special_hermite_polynomial_he) +_create_binary_float_meta_func(aten.special_laguerre_polynomial_l) 
+_create_binary_float_meta_func(aten.special_legendre_polynomial_p) + + +def _register_inplace_meta(fn): + @wraps(fn) + def _fn(self, *args, **kwargs): + out = fn(self, *args, **kwargs) + check_inplace_broadcast(self.shape, out.shape) + return self + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_meta(getattr(aten, inplace_name))(_fn) # type: ignore[assignment] + + return _fn + + +@register_meta(aten.lerp) +@out_wrapper() +def lerp(start, end, weight): + torch._check( + start.dtype == end.dtype, + lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}", + ) + args = [start, end] + if isinstance(weight, TensorLike): + if weight.ndim != 0: + torch._check( + start.dtype == weight.dtype, + lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}", + ) + args.append(weight) + return elementwise_meta( + *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.addcmul) +@out_wrapper() +def addcmul(input, tensor1, tensor2, *, value=1): + return elementwise_meta( + input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.addcdiv) +@out_wrapper() +def addcdiv(input, tensor1, tensor2, *, value=1): + torch._check( + not ( + utils.is_integer_dtype(tensor1.dtype) + and utils.is_integer_dtype(tensor2.dtype) + ), + lambda: ( + "Integer division with addcdiv is no longer supported, and in a future ", + "release addcdiv will perform a true division of tensor1 and tensor2. ", + "The historic addcdiv behavior can be implemented as ", + "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ", + "for integer inputs and as ", + "(input + value * tensor1 / tensor2) for float inputs. ", + "The future addcdiv behavior is just the latter implementation: ", + "(input + value * tensor1 / tensor2), for all dtypes.", + ), + ) + return elementwise_meta( + input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +lerp_ = _register_inplace_meta(aten.lerp) +addcmul_ = _register_inplace_meta(aten.addcmul) +addcdiv_ = _register_inplace_meta(aten.addcdiv) + + +# We must also trigger meta registrations from PrimTorch ref +# decompositions +import torch._refs +import torch._refs.nn.functional +import torch._refs.special + + +def activate_meta(): + activate_meta_table = {} + + # For a given op, we pick the most specific decomp function from + # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd + for type in ["meta", "post_autograd", "pre_autograd"]: + registry = global_decomposition_table[type] + + for opo in registry: + if opo not in activate_meta_table: + activate_meta_table[opo] = registry[opo] + + for op_overload, fn in activate_meta_table.items(): + # Don't register meta for HigherOrderOp's decomp. + # We can reconsider this in the future, but in general, + # the way you do a meta for a HigherOrderOp is different from + # OpOverload. + if isinstance(op_overload, torch._ops.HigherOrderOperator): + continue + assert isinstance(op_overload, OpOverload) + + op_overload.py_impl(torch._C.DispatchKey.Meta)(fn) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_overload.name(), "CompositeImplicitAutograd" + ): + # Internally, we shouldn't be registering meta kernels for any operators that + # have CompositeImplicitAutograd kernels. + # Instead, we should be letting those decompositions run, and writing meta kernels + # only for the base operators. 
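+            # (Example, as a sketch: aten::matmul is CompositeImplicitAutograd and
+            # decomposes into mm/bmm, so meta kernels belong on those base ops,
+            # and registering one for matmul itself is rejected below.)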
+ if op_overload in global_decomposition_table["meta"]: + raise RuntimeError( + f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't " + "register meta function for it. Instead, we should let the decomposition run and write " + "meta kernels for the base operators." + ) + elif op_overload.is_view: + # Attempting to register a python meta kernel for a view operator. + # We shouldn't do this, because the output will report as not having aliased storages. + # All view ops have meta kernels in C++ today, so we should use those instead. + pass + elif ( + op_overload.name() + in { + "aten::empty_strided", # causing infinite recursion, test_meta.py + "aten::clone", # causing infinite recursion + "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950 + "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950 + "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950 + "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950 + "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950 + } + ): + pass + else: + if "mkldnn::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn) + elif "mkl::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn) + elif "onednn::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn) + elif "quantized::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_quantized.impl( + op_overload, fn + ) + else: + _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn) + + +activate_meta() diff --git a/phivenv/Lib/site-packages/torch/_namedtensor_internals.py b/phivenv/Lib/site-packages/torch/_namedtensor_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..36d0f34c520b58ada8b3a53d4d6e4fae9dc48da8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_namedtensor_internals.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict + + +""" +This file contains helper functions that implement experimental functionality +for named tensors in python. All of these are experimental, unstable, and +subject to change or deletion. +""" + + +def check_serializing_named_tensor(tensor): + if tensor.has_names(): + raise RuntimeError( + "NYI: Named tensors don't support serialization. Please drop " + "names via `tensor = tensor.rename(None)` before serialization." 
+ ) + + +def build_dim_map(tensor): + """Returns a map of { dim: dim_name } where dim is a name if the dim is named + and the dim index otherwise.""" + return OrderedDict( + [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)] + ) + + +def unzip_namedshape(namedshape): + if isinstance(namedshape, OrderedDict): + namedshape = namedshape.items() + if not hasattr(namedshape, "__iter__") and not isinstance(namedshape, tuple): + raise RuntimeError( + f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}" + ) + if len(namedshape) == 0: + raise RuntimeError("Expected namedshape to non-empty.") + return zip(*namedshape) + + +def namer_api_name(inplace): + if inplace: + return "rename_" + else: + return "rename" + + +def is_ellipsis(item): + return item == Ellipsis or item == "..." + + +def single_ellipsis_index(names, fn_name): + ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] + if len(ellipsis_indices) >= 2: + raise RuntimeError( + f"{fn_name}: More than one Ellipsis ('...') found in names (" + f"{names}). This function supports up to one Ellipsis." + ) + if len(ellipsis_indices) == 1: + return ellipsis_indices[0] + return None + + +def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names): + return names[numel_pre_glob : len(names) - numel_post_glob] + + +def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names): + globbed_names = expand_single_ellipsis( + ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names + ) + return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :] + + +def resolve_ellipsis(names, tensor_names, fn_name): + """ + Expands ... inside `names` to be equal to a list of names from `tensor_names`. + """ + ellipsis_idx = single_ellipsis_index(names, fn_name) + if ellipsis_idx is None: + return names + return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names) + + +def update_names_with_list(tensor, names, inplace): + # Special case for tensor.rename(None) + if len(names) == 1 and names[0] is None: + return tensor._update_names(None, inplace) + + return tensor._update_names( + resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace + ) + + +def update_names_with_mapping(tensor, rename_map, inplace): + dim_map = build_dim_map(tensor) + for old_dim in rename_map.keys(): + new_dim = rename_map[old_dim] + if old_dim in dim_map.keys(): + dim_map[old_dim] = new_dim + else: + raise RuntimeError( + f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim " + f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist" + ) + return tensor._update_names(tuple(dim_map.values()), inplace) + + +def update_names(tensor, names, rename_map, inplace): + """There are two usages: + + tensor.rename(*names) returns a view on tensor with named dims `names`. + `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`, + then it is expanded greedily to be equal to the corresponding names from + `tensor.names`. + + For example, + ``` + >>> # xdoctest: +SKIP + >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> x.rename('...', 'height', 'width').names + ('N', 'C', 'height', 'width') + + >>> # xdoctest: +SKIP + >>> x.rename('batch', '...', 'width').names + ('batch', 'C', 'H', 'width') + + ``` + + tensor.rename(**rename_map) returns a view on tensor that has rename dims + as specified in the mapping `rename_map`. 
+ + For example, + ``` + >>> # xdoctest: +SKIP + >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> x.rename(W='width', H='height').names + ('N', 'C', 'height', 'width') + + ``` + + Finally, tensor.rename has an in-place version called tensor.rename_. + """ + has_names = len(names) > 0 + has_rename_pairs = bool(rename_map) + if has_names and has_rename_pairs: + raise RuntimeError( + f"{namer_api_name(inplace)}: This function takes either positional " + f"args or keyword args, but not both. Use tensor.{namer_api_name(inplace)}(*names) " + f"to name dims and tensor.{namer_api_name(inplace)}(**rename_map) to rename " + "dims." + ) + + # Special case for tensor.rename(*[]), which is valid for a 0 dim tensor. + if not has_names and not has_rename_pairs: + return update_names_with_list(tensor, names, inplace) + + if has_names: + return update_names_with_list(tensor, names, inplace) + return update_names_with_mapping(tensor, rename_map, inplace) diff --git a/phivenv/Lib/site-packages/torch/_ops.py b/phivenv/Lib/site-packages/torch/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3465203098dbec43cc8ecf4c2f762bb4ab479bd8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_ops.py @@ -0,0 +1,1483 @@ +# mypy: allow-untyped-defs +import abc +import contextlib +import ctypes +import importlib +import inspect +import sys +import types +from collections.abc import Iterator +from functools import cached_property +from typing import ( + Any, + Callable, + ClassVar, + final, + Generic, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import Concatenate, ParamSpec, TypeVar + +import torch +import torch.utils._pytree as pytree +from torch import _utils_internal +from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey +from torch._functorch.pyfunctorch import dispatch_functorch, TransformType +from torch.utils._python_dispatch import TorchDispatchMode + + +if TYPE_CHECKING: + from torch._subclasses.functional_tensor import BaseFunctionalizeAPI + + +_T = TypeVar("_T", default=Any) +_P = ParamSpec("_P", default=...) + + +# Query `hasattr` only once. +_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") + + +@contextlib.contextmanager +def dl_open_guard(): + """ + Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a + shared library to load custom operators. + """ + if not _SET_GLOBAL_FLAGS: + yield + return + old_flags = sys.getdlopenflags() + sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) + try: + yield + finally: + sys.setdlopenflags(old_flags) + + +class OperatorBase: + """ + Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator + (which represents Python-only operators that are unrepresentable in TorchScript). + """ + + def __init__(self): + # The dispatch cache precomputes a mapping of dispatch key that the + # dispatcher wants to dispatch to, to an actual implementation of the + # dispatch key. Confusingly, the actual implementation could *also* be a + # dispatch key, but in this case, this refers to the C++ kernel that + # was registered to some dispatch key. Aliases are permitted in the + # latter but not the former; for example, you might lookup the + # entry for AutogradCPU, and this maps you to the Autograd key for + # the generic autograd kernel that works for all devices. Since this + # is the Python dispatcher, you can also put an arbitrary Python + # callable to call instead. 
This handler gets precisely the + # args/kwargs that the operator was __call__'ed with. + # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp + # for use with OpOverload; cache lookup is done entirely from C++ + # for speed. + # TODO: The cache is NOT currently used by HigherOrderOperator, but it should! + self._dispatch_cache: dict[ + DispatchKey, Union[DispatchKey, Callable[..., Any]] + ] = {} + + # This table allows you to override the behavior of a particular + # dispatch key to call a custom Python function, rather than the + # ordinary C++ configured behavior. This is the raison d'etre of + # Python dispatcher: to let you program the dispatcher from Python + # in case you need something unusual, and don't want to clobber + # the existing registrations using the Python operator registration + # API. + self.py_kernels: dict[DispatchKey, Callable[..., Any]] = {} + + # This table allows you to override the behavior of a particular + # operator for a particular TorchDispatchMode. In practice, + # we are using this mostly for ProxyTensorMode. Modes can be + # thought of as an open world extension of dispatch keys, so it + # makes sense that you should be able to register them, the same + # way you can register dispatch keys. + self.python_key_table: dict[ + type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any] + ] = {} + + # This table allows you to override the behavior of functorch + # transformations. NB: this currently only does something for + # HigherOrderOperator + self.functorch_table = {} + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + def has_kernel_for_dispatch_key(self, k): + return k in self.py_kernels + + def has_kernel_for_any_dispatch_key(self, ks): + for k in self.py_kernels: + if not torch._C._dispatch_is_alias_key(k) and ks.has(k): + return True + return False + + def py_impl( + self, + k: Union[ + type[TorchDispatchMode], + type[torch.Tensor], + TransformType, + DispatchKey, + ], + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if inspect.isclass(k) and ( + issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor) + ): + assert k not in self.python_key_table + # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys? + self.python_key_table[k] = fn + self._dispatch_cache.clear() + return fn + + if isinstance(k, TransformType): + assert k not in self.functorch_table + self.functorch_table[k] = fn + return fn + + assert isinstance(k, DispatchKey) + assert k != DispatchKey.Python, ( + "Please register a mode for the DispatchKey.Python key instead." 
+ ) + + if k in self.py_kernels: + raise RuntimeError( + f"Trying to override a python impl for {k} on operator {self.name()}" + ) + self.py_kernels[k] = fn + self._dispatch_cache.clear() + return fn + + return inner + + # Registers an implementation to all **3** variants of functionalization that we have: + # - DispatchKey.Functionalize + # - functorch.TransformType.Functionalize + # - FunctionalTensorMode + # Example: + # @py_functionalize_impl + # def functionalize_rule(ctx, inner_f, *args): + # args_unwrapped = ctx.unwrap_tensors(args) + # with ctx.redispatch_to_next(): + # out = ctx.functionalize(inner_f)(*args_unwrapped) + # return ctx.wrap_tensors(out) + def py_functionalize_impl( + self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T] + ) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]: + from torch._subclasses.functional_tensor import ( + CppFunctionalizeAPI, + FunctionalTensorMode, + FunctorchFunctionalizeAPI, + PythonFunctionalizeAPI, + ) + + # Construct our three flavors of functionalization, + # each of which have slightly different wrap/unwrap/redispatch policies + def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + return fn(CppFunctionalizeAPI(), *args, **kwargs) + + def functionalize_dispatch_mode_fn( + mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + return fn(PythonFunctionalizeAPI(mode), *args, **kwargs) + + def functionalize_functorch_fn( + interpreter, *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) + + self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn) + self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn) + self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn) + + return fn + + def name(self): + raise NotImplementedError + + +# Equivalent to computeDispatchTableEntryWithDebug +def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] + # 1. (Direct) operator registration + if op.has_kernel_for_dispatch_key(k): + return k + # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available + cand = DispatchKey.CompositeExplicitAutogradNonFunctional + if ( + k == DispatchKey.Undefined or is_included_in_alias(k, cand) + ) and op.has_kernel_for_dispatch_key(cand): + return cand + # 2.2 Use CompositeExplicitAutograd kernel if available + cand = DispatchKey.CompositeExplicitAutograd + if ( + k == DispatchKey.Undefined or is_included_in_alias(k, cand) + ) and op.has_kernel_for_dispatch_key(cand): + return cand + has_backend_kernel = op.has_kernel_for_any_dispatch_key( + torch._C._dispatch_get_backend_keyset_from_autograd(k) + ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd) + # 2.3. Use CompositeImplicitAutograd kernel if available + cand = DispatchKey.CompositeImplicitAutogradNestedTensor + if ( + (k != DispatchKey.Undefined and is_included_in_alias(k, cand)) + and op.has_kernel_for_dispatch_key(cand) + and not has_backend_kernel + ): + return cand + cand = DispatchKey.CompositeImplicitAutograd + if ( + k == DispatchKey.Undefined or is_included_in_alias(k, cand) + ) and op.has_kernel_for_dispatch_key(cand): + if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key( + torch._C._dispatch_autogradother_backends + ): + raise RuntimeError("ambiguous autogradother kernel") + elif not has_backend_kernel: + return cand + # 2.4. 
For autograd backend keys, use kernel from DispatchKey::Autograd if available + cand = DispatchKey.Autograd + if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): + return cand + # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available + cand = DispatchKey.FuncTorchBatchedDecomposition + if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): + return cand + # Backend fallback + if torch._C._dispatch_has_backend_fallback(k): + # The dispatch key itself will implicitly route to backend fallback. + # This is probably not great for the pure Python implementation. + return k + raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}") + + +_higher_order_ops: dict[str, "HigherOrderOperator"] = {} + +_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [ + DispatchKey.PythonDispatcher, # type: ignore[attr-defined] + DispatchKey.PythonTLSSnapshot, # type: ignore[attr-defined] + DispatchKey.ADInplaceOrView, + DispatchKey.BackendSelect, + DispatchKey.AutocastCPU, # type: ignore[attr-defined] + DispatchKey.AutocastCUDA, # type: ignore[attr-defined] +] + + +class HigherOrderOperator(OperatorBase, abc.ABC): + # The HigherOrderOperator will appear as torch.ops.higher_order.{name} + # + # If you're creating a new HigherOrderOperator, please do not change the + # default. Adding operators to the global torch.ops namespace is a bad + # practice due to name collisions. + def __init__(self, name, *, cacheable=False): + super().__init__() + if type(self) is HigherOrderOperator: + raise RuntimeError( + "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it." + ) + self._name = name + + # Make _OPNamespace not scream, this whole name based association needs a good hard look + self.__name__ = name + _higher_order_ops[name] = self + self._ns = "higher_order" + self.__module__ = "torch.ops.higher_order" + self._cacheable = cacheable + + self.non_fallthrough_keys = torch._C._dispatch_keyset_full() + + for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS: + self.fallthrough(dispatch_key) + + # [NOTE] We have to register pre-dispatch key implementation + # because sometimes HOP use aot-dispatch tracing to detect certaion + # mutations. This is problematic when we are functionalizing HOP + # during pre-dispatch because when the inner tracer starts, it will see + # that PreDispatch key is still active. In that case, we just redispatch + # it to next key. This is only safe to do when PreDispatch key stack has no + # active modes. 
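+    # `py_impl` below attaches a Python rule per DispatchKey, mode, or tensor
+    # subclass. A typical registration, sketched with a hypothetical `my_hop`
+    # instance:
+    #
+    #   @my_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
+    #   def my_hop_dense(subgraph, *operands):
+    #       return subgraph(*operands)
+    #
+    # Modes and tensor subclasses are registered the same way by passing the
+    # mode or subclass type instead of a DispatchKey.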
+ + def py_impl( + self, + k: Union[ + type[TorchDispatchMode], + type[torch.Tensor], + TransformType, + DispatchKey, + ], + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k): + self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) + return super().py_impl(k) + + def py_autograd_impl( + self, + fn: Callable[_P, _T], + ) -> Callable[_P, _T]: + def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T: + if not torch.is_grad_enabled() or pytree.tree_all_only( + torch.Tensor, + lambda t: not t.requires_grad, # type: ignore[union-attr] + (*args, kwargs), + ): + with torch._C._AutoDispatchBelowAutograd(): + return self(*args, **kwargs) + + from torch._higher_order_ops.utils import _has_gen_schema + + if _has_gen_schema(self): + schema = self.gen_schema(*args, **kwargs) + if any(arg.is_write for arg in schema.arguments): + raise RuntimeError( + f"The {self.name()} HigherOrderOperator does not currently support training " + "with in-place input or buffer mutations " + "If you require this feature, please submit an issue to PyTorch. " + "Alternatively, consider creating your own custom autograd.Function. " + ) + + return fn(*args, **kwargs) + + self.py_impl(DispatchKey.Autograd)(maybe_run_autograd) + + return fn + + @property + def namespace(self): + return self._ns + + @final + def cacheable(self) -> bool: + from torch._functorch.autograd_function import AutogradFunctionApply + + return ( + self._cacheable + or f"{self.__module__}.{self.__name__}" + in torch._inductor.config.unsafe_marked_cacheable_functions + or ( + isinstance(self, AutogradFunctionApply) + and torch._functorch.config.autograd_cache_allow_custom_autograd_functions + ) + ) + + def fallthrough(self, dispatch_key): + self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key) + + # Use positional-only argument to avoid naming collide with custom ops arguments + # that are named "self". + def dispatch(self, /, dispatch_key, *args, **kwargs): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + if dispatch_key in self._dispatch_cache: + kernel = self._dispatch_cache[dispatch_key] + assert not isinstance(kernel, DispatchKey) + return kernel(*args, **kwargs) + + if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode: + return dispatch_functorch(self, args, kwargs) + + if dispatch_key == DispatchKey.Python: + # Keep the following 1:1 with handle_torch_function_no_python_arg_parser + # in torch/csrc/utils/python_arg_parser.cpp + + overloaded_args_list = [] + + def has_python_key(tensor): + return torch._C._dispatch_keys(tensor).has("Python") + + def check_overloaded(arg): + if isinstance(arg, torch.Tensor) and has_python_key(arg): + overloaded_args_list.append(arg) + + for arg in (*args, *kwargs.values()): + check_overloaded(arg) + if isinstance(arg, (list, tuple)): + for a in arg: + check_overloaded(a) + + overloaded_args = tuple(overloaded_args_list) + + # Step 1: dispatch on any user TorchDispatchModes + from torch.utils._python_dispatch import _pop_mode_temporarily + + curr_mode = _get_current_dispatch_mode() + if curr_mode is not None: + if type(curr_mode) in self.python_key_table: + handler = self.python_key_table[type(curr_mode)] + with _pop_mode_temporarily() as mode: + # "natural" calling convention: (mode, *args, **kwargs) + # TODO(rzou): we should support torch_dispatch calling convention too. 
+ result = handler(mode, *args, **kwargs) + else: + raise NotImplementedError( + f"There was no rule registered for HOP {self._name} and mode {curr_mode}. " + f"We recommend filing an issue." + ) + if result is not NotImplemented: + return result + + # Step 2: dispatch on any subclasses + for arg in overloaded_args: + subclass_type = type(arg) + if ( + subclass_type.__torch_dispatch__ + == torch._C._disabled_torch_dispatch_impl + ): + continue + + # In some case, people are using FakeTensor without a FakeTensorMode. + # For example, some sparse arch model has a mix of FakeTensor and real + # tensor for weights during lowering, and ppl tends to run eager evaluation + # on the model without setting up the FakeTensorMode. + # In this case, we pull FakeTensorMode impl. + if subclass_type is torch._subclasses.fake_tensor.FakeTensor: + subclass_type = torch._subclasses.fake_tensor.FakeTensorMode # type: ignore[assignment] + handler = self.python_key_table[subclass_type] + result = handler(arg.fake_mode, *args, **kwargs) # type: ignore[attr-defined] + return result + + if subclass_type in self.python_key_table: + handler = self.python_key_table[subclass_type] + # "natural" calling convention: (*args, **kwargs) + # TODO(rzou): we should support torch_dispatch calling convention too. + result = handler(*args, **kwargs) + else: + raise NotImplementedError( + f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. " + f"We recommend filing an issue." + ) + if result is not NotImplemented: + return result + + # All handlers returned NotImplemented + raise TypeError( + f"Multiple dispatch failed for {self._name}. There was no registered that " + f"did not return NotImplemented. Use HOP.py_impl to register some. " + f"Tried mode: {curr_mode}) and subclasses: " + f"{[type(a) for a in overloaded_args]}" + ) + + functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] + if functionality_key == DispatchKey.PreDispatch: + from torch.utils._python_dispatch import _pop_mode_temporarily + + # The check for Python in the exclude set is so we properly respect `with no_dispatch()` + # calls inside of a mode. + if ( + _len_torch_dispatch_stack_pre_dispatch() > 0 + ) and not torch._C._dispatch_tls_is_dispatch_key_excluded( + DispatchKey.Python + ): + curr_mode = _get_current_dispatch_mode_pre_dispatch() + assert curr_mode is not None, ( + "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode." + ) + assert type(curr_mode) in self.python_key_table, ( + f"Current active mode {curr_mode} not registered" + ) + handler = self.python_key_table[type(curr_mode)] + with _pop_mode_temporarily(functionality_key) as mode: + return handler(mode, *args, **kwargs) + + final_key = resolve_key(self, dispatch_key) + + # This can current fail due to backend fallbacks. You just have to + # register them by hand for HigherOrderOperator. + if final_key not in self.py_kernels: + raise NotImplementedError( + f"could not find kernel for HigherOrderOperator {self._name} " + f"at dispatch key {final_key} (resolved from {dispatch_key})" + ) + + # [NOTE] We shouldn't cache PreDispatch kernel here because depending + # on what modes are active, predispatch behaviour is different. 
+ # Also we do same thing for normal ops: + # See Note [Not Caching Per-Dispatch-Key Mode Handlers] + if dispatch_key != DispatchKey.PreDispatch: + self._dispatch_cache[dispatch_key] = self.py_kernels[final_key] + kernel = self.py_kernels[final_key] + # It's illegal to register DispatchKey to py_kernels, since there's no + # C++ kernel to call into + assert not isinstance(kernel, DispatchKey) + return kernel(*args, **kwargs) + + @abc.abstractmethod + def __call__(self, /, *args, **kwargs): + def wrapper(): + flat_args = _to_flat_tuple(args, kwargs) + if torch.overrides.has_torch_function(flat_args): + return torch.overrides.handle_torch_function( + self, flat_args, *args, **kwargs + ) + + dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) + return self.dispatch( + dispatch_key_set.highestPriorityTypeId(), *args, **kwargs + ) + + return wrapper() + + # NOTE [HigherOrderOprator Schema] + # Each invocation of a HigherOrderOperator (hop) should have its own schema because + # the subgraphs and the arguments can be different even for the same hop. + # + # Each hop should implement its own gen_schema method, which should + # take the same input as the __call__ method and returns a FunctionSchema. + # The schema provides a unified way to check if the hop mutates its inputs, + # which can be useful in implementing optimizations. + # + # If the hop doesn't implement the gen_schema method, + # we expect it to be functional. It should not mutate its inputs and there + # are no input, output aliasing via views or direct referencing. + def gen_schema(self, *args, **kwargs): + raise NotImplementedError( + f"HigherOrderOperator {self._name} does not implement a gen_schema. " + f"This is OK as long as the hop is functional. " + f"e.g. it should not mutate its inputs and there are no input, output aliasing " + f"via views or direct referencing." + ) + + def __str__(self): + return f"{self.name()}" + + def name(self): + return self._name + + +def _to_flat_tuple(args, kwargs): + return pytree.arg_tree_leaves(*args, **kwargs) + + +def _compute_keyset(args, kwargs, non_fallthrough_keys): + tensors = _get_tensors(args, kwargs) + return key_extractor(tensors, non_fallthrough_keys) + + +def _get_tensors(args, kwargs): + flat_all = _to_flat_tuple(args, kwargs) + tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)] + return tuple(tensor_args) + + +# Note - this should maintain identical impl to the C++ dispatcher key extraction logic +# at ATen/core/dispatch/DispatchKeyExtractor.h +def key_extractor(tensors, key_mask): + key_set = torch._C._dispatch_tls_local_include_set() + for tensor in tensors: + key_set = key_set | torch._C._dispatch_keys(tensor) + key_set = key_set - torch._C._dispatch_tls_local_exclude_set() + key_set = key_set & key_mask + return key_set + + +# Mode stack for PreDispatchKey +# it should always have three keys with +# priority given to FunctionalTensorMode and +# then ProxyTorchDispatchMode. It means that +# slot 0 belongs to ProxyTorchDispatchMode and +# slot 1 belongs to FunctionalTensorMode. +# +# SchemaCheckMode is separate from the other 2, +# and is only valid when the stack is empty. +# SchemaCheckMode is for testing purposes, and +# is meant to run in eager mode on concrete inputs, +# checking for incorrect schemas in regards to +# aliasing or mutating ops. 
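# (Editorial sketch, assuming only the behaviour documented in the comment
# above.)  After installing both infra modes, the stack state looks like:
#
#     stack = _ModeStackStateForPreDispatch()
#     stack.set(0, proxy_mode)        # ProxyTorchDispatchMode
#     stack.set(1, functional_mode)   # FunctionalTensorMode
#     stack.count()                   # -> 2
#
# whereas a SchemaCheckMode lives in the separate _schema_check_mode slot and
# may only be installed while both infra slots are empty.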
+class _ModeStackStateForPreDispatch: + def __init__(self): + self.__infra_modes = [None, None] + self._schema_check_mode = None + + def set(self, index, mode): + assert index < len(self.__infra_modes) + self.__infra_modes[index] = mode + + def get(self, index): + assert index < len(self.__infra_modes) + return self.__infra_modes[index] + + def count(self): + return len([i for i in self.__infra_modes if i is not None]) + int( + self._schema_check_mode is not None + ) + + +_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch() + + +def unset_mode_pre_dispatch(mode_key, schema_check=False): + current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch() + assert mode_key is None or mode_key in ( + torch._C._TorchDispatchModeKey.PROXY, + torch._C._TorchDispatchModeKey.FUNCTIONAL, + ) + if schema_check: + assert mode_key is None + + def _unset_mode(): + if mode_key == torch._C._TorchDispatchModeKey.PROXY: + current_mode = current_mode_stack_pre_dispatch.get(0) + mode_stack_state_for_pre_dispatch().set(0, None) + return current_mode + elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL: + current_mode = current_mode_stack_pre_dispatch.get(1) + mode_stack_state_for_pre_dispatch().set(1, None) + return current_mode + else: + current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode + mode_stack_state_for_pre_dispatch()._schema_check_mode = None + return current_mode + + current_mode = _unset_mode() + + new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch() + # When we are unsetting a mode, we need to check if there is + # active mode left on the PreDispatch key. If there is nothing + # active, we need to remove PreDispatch key from local dispatch include + # set. + if new_pre_dispatch_len == 0: + torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False) + + return current_mode + + +def _set_mode_pre_dispatch(mode): + from torch._subclasses.functional_tensor import FunctionalTensorMode + from torch._subclasses.schema_check_mode import SchemaCheckMode + from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode + + assert isinstance( + mode, + ( + FunctionalTensorMode, + ProxyTorchDispatchMode, + SchemaCheckMode, + ), + ) + + previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch() + if isinstance(mode, SchemaCheckMode): + current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode + if previous_mode_stack_len > 0: + raise AssertionError( + "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack" + ) + mode_stack_state_for_pre_dispatch()._schema_check_mode = mode + elif isinstance(mode, FunctionalTensorMode): + current_mode = mode_stack_state_for_pre_dispatch().get(1) + assert current_mode is None + mode_stack_state_for_pre_dispatch().set(1, mode) + else: + current_mode = mode_stack_state_for_pre_dispatch().get(0) + assert current_mode is None + mode_stack_state_for_pre_dispatch().set(0, mode) + + # When we are setting a mode, we need to check if there is + # active mode left on the PreDispatch key. If there was nothing + # active before setting this mode, it means that PreDispatch key + # was turned off. So we need to turn it on again. 
+ if previous_mode_stack_len == 0: + torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True) + + +def _pop_mode_from_pre_dispatch(): + mode_stack = mode_stack_state_for_pre_dispatch() + pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch() + + if pre_dispatch_len == 0: + raise AssertionError("Trying to pop empty mode stack") + + if mode_stack._schema_check_mode is not None: + return unset_mode_pre_dispatch(None, schema_check=True) + if mode_stack.get(1) is not None: + return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL) + if mode_stack.get(0) is not None: + return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY) + + +def _len_torch_dispatch_stack_pre_dispatch(): + return mode_stack_state_for_pre_dispatch().count() + + +def _get_dispatch_mode_pre_dispatch(mode_key): + assert mode_key in ( + torch._C._TorchDispatchModeKey.PROXY, + torch._C._TorchDispatchModeKey.FUNCTIONAL, + ) + if mode_key == torch._C._TorchDispatchModeKey.PROXY: + return mode_stack_state_for_pre_dispatch().get(0) + else: + return mode_stack_state_for_pre_dispatch().get(1) + + +def _get_current_dispatch_mode_pre_dispatch(): + if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None: + return mode_stack_state_for_pre_dispatch()._schema_check_mode + else: + stack_len = mode_stack_state_for_pre_dispatch().count() + if stack_len == 2: + return mode_stack_state_for_pre_dispatch().get(1) + if stack_len == 1: + return ( + mode_stack_state_for_pre_dispatch().get(1) + if mode_stack_state_for_pre_dispatch().get(1) is not None + else mode_stack_state_for_pre_dispatch().get(0) + ) + return None + + +def mode_stack_state_for_pre_dispatch(): + global _mode_stack_state_for_pre_dispatch + return _mode_stack_state_for_pre_dispatch + + +cached_ops: set["OpOverload"] = set() + + +def add_cached_op(op_overload): + global cached_ops + cached_ops.add(op_overload) + + +def reset_cached_ops(): + global cached_ops + cached_ops.clear() + + +def get_cached_ops(): + global cached_ops + return cached_ops + + +# Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object. +# You can obtain an OpOverload object through attribute query on OpOverloadPacket. +class OpOverload(OperatorBase, Generic[_P, _T]): + def __init__( + self, + overloadpacket: "OpOverloadPacket", + op: Callable[_P, _T], + op_dk: Callable[Concatenate[DispatchKey, _P], _T], + schema: torch._C.FunctionSchema, + tags: list[Any], + ) -> None: + super().__init__() + self._op = op + self._op_dk = op_dk + self._schema = schema + self._overloadpacket = overloadpacket + self._tags = tags + self._overloadname = ( + "default" if schema.overload_name == "" else schema.overload_name + ) + if tags: + self._nondeterministic_seeded = torch.Tag.nondeterministic_seeded in tags + self._name = self._schema.name + if schema.overload_name: + self._name += "." + schema.overload_name + self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}" + self.__module__ = overloadpacket.__module__ + op.__module__ = overloadpacket.__module__ + self.__qualname__ = self._name + self.__annotations__ = {} + + # If the OpOverload was constructed from a Library.def in Python. 
+ self._defined_in_python = self.__qualname__ in torch.library._defs + + # Logic replicated from aten/src/ATen/native/MathBitsFallback.h + is_write = None + for a in self._schema.arguments: + if a.alias_info is None: + continue + if is_write is None: + is_write = a.alias_info.is_write + else: + # We will conservatively call mixed mutable/non-mutable + # aliased inputs as NOT a view + is_write = a.alias_info.is_write or is_write + self.is_view = is_write is not None and not is_write + + @cached_property + def _namespace(self) -> str: + return self._schema.name.split("::", maxsplit=1)[0] + + @cached_property + def _opname(self) -> str: + return self._schema.name.split("::", maxsplit=1)[1] + + @cached_property + def _handle(self) -> torch._C._DispatchOperatorHandle: + return torch._C._dispatch_find_schema_or_throw( + self._schema.name, self._schema.overload_name + ) + + # it's a no-op since OpOverload object is immutable and must be unique for a given op overload. + def __deepcopy__(self, memo=None): + return self + + def __repr__(self): + return f"" + + # Use positional-only argument to avoid naming collision with aten ops arguments + # that are named "self". This way, all the aten ops can be called by kwargs. + def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: + return self._op(*args, **kwargs) + + # Use positional-only argument to avoid naming collision with aten ops arguments + # that are named "self". This way, all the aten ops can be called by kwargs. + def redispatch( + self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value] + + def __hash__(self): + return hash(self._op) + + # `my_namespace.my_op_name.overload_name` + def __str__(self): + return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) + + def has_kernel_for_dispatch_key(self, k: DispatchKey) -> bool: + return super().has_kernel_for_dispatch_key( + k + ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k) + + def has_kernel_for_any_dispatch_key(self, ks: torch._C.DispatchKeySet) -> bool: + return torch._C._dispatch_has_kernel_for_any_dispatch_key( + self.name(), ks + ) or super().has_kernel_for_any_dispatch_key(ks) + + @property + def namespace(self) -> str: + return self._namespace + + def _can_decompose(self) -> bool: + dk = DispatchKey.CompositeImplicitAutograd + return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key( + self.name(), dk + ) + + def decompose(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + dk = DispatchKey.CompositeImplicitAutograd + if dk in self.py_kernels: + # NB: This branch is not too necessary anymore, because we can + # apply Python CompositeImplicitAutograd *before* tracing + # using Python dispatcher (also taking advantage of the autograd + # formula). But it's included for completeness + return self.py_kernels[dk](*args, **kwargs) + elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): + return self._op_dk(dk, *args, **kwargs) + else: + return NotImplemented + + # Remove a dispatch key from the dispatch cache. This will force it to get + # recomputed the next time. Does nothing + # WARNING: if you register a dispatch key to py_kernels of an OpOverload, + # calling _del_dispatch on that key is NOT sufficient to apply your change, + # because a single registration may affect MULTIPLE dispatch keys (e.g., + # registering Autograd affects AutogradCPU). 
del_dispatch is to be used + # only if you are specifically modifying how get_dispatch handles a + # particular input 'key'. + def _uncache_dispatch(self, key: DispatchKey) -> None: + self._dispatch_cache.pop(key, None) + + # This implements the pre-computation logic for the Python dispatcher. + def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]: + # This is only called upon a cache miss + assert key not in self._dispatch_cache, f"{self} {key}" + + if key == DispatchKey.Python: + if not isinstance(self, TorchBindOpOverload) and not self.python_key_table: + self._dispatch_cache[key] = key + add_cached_op(self) + return key + + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: + from torch.utils._python_dispatch import _get_current_dispatch_mode + + # TODO: We also need to handle tensor subclasses here + # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. + curr_mode = type(_get_current_dispatch_mode()) + assert curr_mode is not None, ( + "Illegal invocation of dispatch on DispatchKey.Python without a mode." + ) + + if curr_mode not in self.python_key_table: + if isinstance(self, TorchBindOpOverload): + with ( + torch.utils._python_dispatch._pop_mode_temporarily() as mode + ): + return torch._library.utils.handle_dispatch_mode( + mode, self, *args, **kwargs + ) + else: + return self._op_dk(key, *args, **kwargs) + + with torch.utils._python_dispatch._pop_mode_temporarily() as mode: + return self.python_key_table[curr_mode](mode, *args, **kwargs) + + self._dispatch_cache[key] = handler + add_cached_op(self) + return handler + + functionality_key = torch._C._to_functionality_key(key) # type: ignore[attr-defined] + if functionality_key == DispatchKey.PreDispatch: + curr_stack_len = _len_torch_dispatch_stack_pre_dispatch() + # The check for Python in the exclude set is so we properly respect `with no_dispatch()` + # calls inside of a mode. + if ( + curr_stack_len > 0 + and not torch._C._dispatch_tls_is_dispatch_key_excluded( + DispatchKey.Python + ) + ): + + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: + @contextlib.contextmanager + def _temporarily_pop_modes_from_pre_dispatch(): + top_mode = _pop_mode_from_pre_dispatch() + try: + yield top_mode + finally: + _set_mode_pre_dispatch(top_mode) + + with _temporarily_pop_modes_from_pre_dispatch() as curr_mode: + return torch._library.utils.handle_dispatch_mode( + curr_mode, self, *args, **kwargs + ) + + # Note [Not Caching Per-Dispatch-Key Mode Handlers] + # Note that we're not caching this handler. There isn't really a point, since the slow bit + # is the handler itself (in python). + # Also, not caching means that we don't have to reset the cache when any existing + # modes go out of scope (which in of itself takes time to loop through all operators). 
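                # (Editorial note) Concretely, every call that reaches the PreDispatch
                # functionality key re-runs the handler defined above: it pops the top
                # pre-dispatch mode, forwards to torch._library.utils.handle_dispatch_mode
                # with that mode, and then restores it, so there is no stable per-key
                # result that would be worth memoizing in _dispatch_cache.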
+ return handler + + final_key = resolve_key(self, key) + + # See Note [Not Caching Per-Dispatch-Key Mode Handlers] + cache_result = key != DispatchKey.PreDispatch + + # TODO: We could potentially have lots of debugging wrappers against + # dispatch keys; design some general registration mechanism instead of + # having if statement for each of them + if key == DispatchKey.Functionalize: + import torch._dispatch.python as pydispatch + + if pydispatch.CROSSREF_FUNCTIONALIZE: + handler = pydispatch.make_crossref_functionalize(self, final_key) # type: ignore[assignment] + if cache_result: + self._dispatch_cache[key] = handler + add_cached_op(self) + return handler + + r = self.py_kernels.get(final_key, final_key) + if cache_result: + self._dispatch_cache[key] = r + add_cached_op(self) + return r + + def name(self): + return self._name + + @property + def overloadpacket(self): + return self._overloadpacket + + @property + def op(self): + return self._op + + @property + def tags(self): + return self._tags + + # TODO: add more methods to expose information about input and output arguments + + +# TorchBindOpOverload are those custom ops which have at least one overload's +# schema consists of torch.ScriptObject (i.e. custom class) input. +# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python +# when its inputs contain FakeScriptObject in a similar way as higher order ops. +class TorchBindOpOverload(OpOverload[_P, _T]): + def _fallthrough_keys(self) -> list[DispatchKey]: + # TODO: we should be calling the fallback for these, but a fallthrough is almost close + # enough to the fallback in most cases that we care about. + _DEFAULT_FALLTHROUGH_KEYS = [ + DispatchKey.Autograd, + DispatchKey.AutogradCPU, + DispatchKey.AutogradCUDA, + DispatchKey.ADInplaceOrView, + DispatchKey.BackendSelect, + DispatchKey.PythonTLSSnapshot, + DispatchKey.PythonDispatcher, + ] + + def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): + if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key): + return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( + self.name(), key + ) + + return ( + key not in self.py_kernels + or self.py_kernels[key] is torch.library.fallthrough_kernel + ) + + return [ + key + for key in _DEFAULT_FALLTHROUGH_KEYS + if _may_use_fallthrough_instead_of_fallback(key) + ] + + @contextlib.contextmanager + def _register_as_effectful_op_temporarily(self): + from torch._higher_order_ops.effects import ( + _EffectType, + _register_effectful_op, + SIDE_EFFECTS, + ) + + try: + if self not in SIDE_EFFECTS: + _register_effectful_op(self, _EffectType.ORDERED) + yield + finally: + if self in SIDE_EFFECTS: + del SIDE_EFFECTS[self] + + # Use positional-only argument to avoid naming collision with aten ops arguments + # that are named "self". This way, all the aten ops can be called by kwargs. + def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: + if _must_dispatch_in_python(args, kwargs): + # When any inputs are FakeScriptObject, we need to + # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher + # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject. + # + # Note: + # 1. We only register the torchbind op temporarily as effectful op because we only want + # the effect token functionalization logic to be applied during tracing. Otherwise, the behavior + # of the eagerly executing the op might change after tracing. + # 2. 
We don't want to register the op as effectful for all torchbind ops in ctor because this might + # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction. + with self._register_as_effectful_op_temporarily(): + return self._dispatch_in_python( + self._fallthrough_keys(), *args, **kwargs + ) + return self._op(*args, **kwargs) + + def _dispatch_in_python( + self, fallthrough_keys: list[DispatchKey], *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + non_fallthrough_keys = torch._C._dispatch_keyset_full() + for key in fallthrough_keys: + non_fallthrough_keys = non_fallthrough_keys.remove(key) + + dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys) + dispatch_key = dispatch_key_set.highestPriorityTypeId() + + handler = ( + self._get_dispatch(dispatch_key) + if dispatch_key not in self._dispatch_cache + else self._dispatch_cache[dispatch_key] + ) + + if isinstance(handler, DispatchKey): + # fallthrough keys can be registered at runtime via torch.library.impl + # so need to add it to fallthrough_keys and re-dispatch. + if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( + self.name(), dispatch_key + ): + return self._dispatch_in_python( + fallthrough_keys + [dispatch_key], + *args, + **kwargs, + ) + + raise RuntimeError( + f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler}." + f" but no python implementation is found." + f" Please file an issue on this when you encounter this error." + f" This error can happen when you export or compile the model." + f" It can still happpen even if a C++ implementation for {dispatch_key}. " + f" has been registered. That's because FakeScriptObject purely lives in python and cannot work " + f" with a C++ implementation." + ) + + assert isinstance(handler, Callable) # type: ignore[arg-type] + return handler(*args, **kwargs) + + +def _must_dispatch_in_python(args, kwargs): + return pytree.tree_any( + lambda obj: isinstance( + obj, torch._library.fake_class_registry.FakeScriptObject + ), + (args, kwargs), + ) + + +def _has_script_object_arg(schema: torch.FunctionSchema) -> bool: + return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments) + + +# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator +# You can obtain an OpOverload object through attribute query. +class OpOverloadPacket(Generic[_P, _T]): + __file__: ClassVar[str] = "torch.ops" + + def __init__( + self, + qualified_op_name: str, + op_name: str, + op: Callable[_P, _T], + overload_names: list[str], + ) -> None: + # These attributes are accessible on the object through the properties + # defined below but are immutable + self._qualified_op_name = qualified_op_name + self.__name__ = op_name + self._op = op + self._overload_names = overload_names + self._dir: list[str] = [] + self._has_torchbind_op_overload = any( + _has_script_object_arg(schema) for schema in self._schemas.values() + ) + + # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op. 
+ def __deepcopy__(self, memo=None): + return self + + def __repr__(self): + return "".format( + *self._qualified_op_name.split("::") + ) + + def __hash__(self): + return hash(self._op) + + def __str__(self): + return "{}.{}".format(*self._qualified_op_name.split("::")) + + @property + def op(self): + return self._op + + @property + def _schemas(self): + return { + overload_name: torch._C._get_schema(self._qualified_op_name, overload_name) + for overload_name in self._overload_names + } + + def __getattr__(self, key: str) -> OpOverload[_P, _T]: + # ensure that query for dunder attributes that does not exist on + # opoverloadpacket but instead exists on the self._op object does not unnecessarily call + # `_get_operation_overload` (which is an expensive operation). + # This is done to prevent any potential slowdown. This list can be extended + # if there exists other attributes like `__name__` that only exist on self._op and not on the + # opoverloadpacket. + # This is ok since we are guaranteed that an overload name for an aten op can't start with '__' + try: + if key.startswith("__"): + return getattr(self._op, key) + except AttributeError: + # for consistency because it seems weird to + # throw an attribute error with a message containing + # an object name different from the one the attribute + # query was performed on. + raise AttributeError( + f"'{str(self)}' can't have an overload name beginning with '__' and the " + f"underlying op {str(self._op)} has no attribute {key} either." + ) from None + + try: + # This is ok since we are guaranteed that an overload name for an aten op can't be 'default' + use_key = "" if key == "default" else key + # TODO: disallow access to overloads registered by JIT + op_dk_tags = torch._C._get_operation_overload( + self._qualified_op_name, use_key + ) + if op_dk_tags is None: + raise AttributeError( + f"The underlying op of '{str(self)}' has no overload name '{key}'" + ) + + op_, op_dk_, tags = op_dk_tags + schema = torch._C._get_schema(self._qualified_op_name, use_key) + overload: OpOverload[_P, _T] = ( + OpOverload(self, op_, op_dk_, schema, tags) + if not _has_script_object_arg(schema) + else TorchBindOpOverload(self, op_, op_dk_, schema, tags) + ) + # cache the overload object + setattr(self, key, overload) + self._dir.append(key) + return overload + except RuntimeError: + raise AttributeError( + f"The underlying op of '{str(self)}' has no overload name '{key}'" + ) from None + + def __iter__(self) -> Iterator[str]: + return iter(self._dir) + + # Use positional-only argument to avoid naming collision with aten ops arguments + # that are named "self". This way, all the aten ops can be called by kwargs. + def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # overloading __call__ to ensure torch.ops.foo.bar() + # is still callable from JIT + # We save the function ptr as the `op` attribute on + # OpOverloadPacket to access it here. + + # Directly calling OverloadPacket goes into C++, which will check + # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we + # intercept it here and call TorchBindOpverload instead. 
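        # (Editorial illustration; the op name is hypothetical.)  For example, a call
        # like torch.ops.my_ns.my_op(fake_script_obj, x), where fake_script_obj is a
        # torch._library.fake_class_registry.FakeScriptObject, takes the Python branch
        # below, while a call with only real inputs goes straight to self._op in C++.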
+ if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): + return _call_overload_packet_from_python(self, *args, **kwargs) + return self._op(*args, **kwargs) + + # TODO: use this to make a __dir__ + def overloads(self): + return [n if n else "default" for n in self._overload_names] + + +# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp +# _jit_get_operations, which calls _get_operation_for_overload_or_packet. +def _call_overload_packet_from_python( + op: OpOverloadPacket[_P, _T], *args: _P.args, **kwargs: _P.kwargs +) -> _T: + # Re-use the torch function handling logic in cpp + torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet( + op, *args, **kwargs + ) + + if torch_function_called: + return ret + + # The following mirrors getOpWithStack. + # In cpp, we do a schema matching for the arguments, and call ToIValue to + # to check whether the arguments are valid. But need to do similar things here + # and check the schema whether the FakeScriptObject is the corresponding fake class + # of the actual class used in schema. + exceptions = {} + found_op = None + for overload_name in op.overloads(): + op_overload = getattr(op, overload_name) + try: + _ = torch._C._check_schema_allow_fake_script_object( + op_overload._schema, *args, **kwargs + ) + found_op = op_overload + break + except RuntimeError as e: + exceptions[overload_name] = e + + if found_op: + return found_op(*args, **kwargs) + + err_msg = ( + f"Fail to match any TorchBindOverload of {op} with following exceptions:\n" + ) + for key, msg in exceptions.items(): + err_msg += f"Overload name {key}:\n {msg}\n" + raise RuntimeError(err_msg) + + +# Resolution of torch.fn is different from torch.ops.aten.fn +# torch.fn uses the Python argparser, matches with the +# appropriate schema, and calls into the unboxed version of the method +# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. +# JIT creates a stack of all the overloads and then tries to match the +# correct one at runtime and always calls into the boxed version of the method +# Autograd codegen creates VariableType, TracerType, +# inplace or view type and python bindings. +# Aten codegen generates tensor methods for the tensor class. + +# _OpNamespace is a subclass of ModuleType because the torch script +# allows attribute lookups on modules only. Since we want torch.ops.foo.bar() +# to work from script, we need to ensure ops and foo are modules + + +class _OpNamespace(types.ModuleType): + """ + An op namespace to dynamically bind Operators into Python. + + Say a user has created a custom Operator called "my_namespace::my_op". To + call this op, the user will write torch.ops.my_namespace.my_op(...). + At startup, this operation will not yet be bound into Python. Instead, the + following sequence of magic tricks will occur: + 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method + on the `torch.ops` object, which will create a new `_OpNamespace` + object called `my_namespace` and set it as an attribute on the `ops` + object. + 2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on + the `my_namespace` object, which will retrieve the operation via + `torch.get_operation`, a function bound from C++, and then in a similar + fashion bind this new object onto the `my_namespace` object. + 3. 
`torch.ops.my_namespace.my_op(...)` then calls this new operation + and subsequent accesses will incur no further lookup (the namespace and + operation will already exist). + """ + + __file__ = "torch.ops" + + def __init__(self, name: str) -> None: + super().__init__("torch.ops." + name) + self.name = name + self._dir: list[str] = [] + + def __iter__(self) -> Iterator[str]: + return iter(self._dir) + + def __getattr__(self, op_name: str) -> OpOverloadPacket: + if op_name in ("__origin__", "__self__"): + raise AttributeError( + f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'" + ) + + # Get the op `my_namespace::my_op` if available. This will also check + # for overloads and raise an exception if there are more than one. + namespace_name = self.name + qualified_op_name = f"{namespace_name}::{op_name}" + module_name = self.__module__ + "." + namespace_name + + try: + op, overload_names = _get_packet(qualified_op_name, module_name) + if op is None: + raise AttributeError( + f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" + ) + except RuntimeError as e: + # Turn this into AttributeError so getattr(obj, key, default) + # works (this is called by TorchScript with __origin__) + raise AttributeError( + f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" + ) from e + + op.__module__ = module_name + opoverloadpacket = OpOverloadPacket( + qualified_op_name, op_name, op, overload_names + ) + opoverloadpacket.__module__ = self.__module__ + "." + namespace_name + # cache the opoverloadpacket to ensure that each op corresponds to + # a unique OpOverloadPacket object + setattr(self, op_name, opoverloadpacket) + self._dir.append(op_name) + return opoverloadpacket + + +def _get_packet(qualname, op_module): + op, overload_names = torch._C._jit_get_operation(qualname) + if op is not None: + # let the script frontend know that op is identical to the builtin op + # with qualified_op_name + torch.jit._builtins._register_builtin(op, qualname) + op.__module__ = op_module + return op, overload_names + + +def _refresh_packet(packet): + op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__) + assert op is not None + packet._op = op + packet._overload_names = overload_names + + +class _HigherOrderNamespace(types.ModuleType): + __file__ = "torch.ops" + + def __init__(self) -> None: + super().__init__("torch.ops.higher_order") + self._dir: list[str] = [] + + def __iter__(self) -> Iterator[str]: + return iter(self._dir) + + def __getattr__(self, name: str) -> HigherOrderOperator: + # Following _OpNamespace.__getattr__, we cache the op on this object. + op = _higher_order_ops.get(name, None) + if op is None: + raise AttributeError( + f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'" + ) + setattr(self, name, op) + self._dir.append(name) + return op + + +class _Ops(types.ModuleType): + __file__ = "_ops.py" + + def __init__(self): + super().__init__("torch.ops") + self.loaded_libraries = set() + self.higher_order = _HigherOrderNamespace() + self._dir = [] + + def __getattr__(self, name: str) -> _OpNamespace: + # Here we are creating `torch.ops.my_namespace` + namespace = _OpNamespace(name) + setattr(self, name, namespace) + self._dir.append(name) + return namespace + + def __iter__(self) -> Iterator[str]: + return iter(self._dir) + + def import_module(self, module): + """ + Imports a Python module that has torch.library registrations. 
+ + Generally, to extend PyTorch with custom operators, a user will + create a Python module whose import triggers registration of + the custom operators via a torch.ops.load_library call or a call + to one or more torch.library.* APIs. + + It is unexpected for Python modules to have side effects, so some + linters and formatters will complain. Use this API to import Python + modules that contain these torch.library side effects. + + Args: + module (str): The name of the Python module to import + + """ + importlib.import_module(module) + + def load_library(self, path): + """ + Loads a shared library from the given path into the current process. + + The library being loaded may run global initialization code to register + custom operators with the PyTorch JIT runtime. This allows dynamically + loading custom operators. For this, you should compile your operator + and the static registration code into a shared library object, and then + call ``torch.ops.load_library('path/to/libcustom.so')`` to load the + shared object. + + After the library is loaded, it is added to the + ``torch.ops.loaded_libraries`` attribute, a set that may be inspected + for the paths of all libraries loaded using this function. + + Args: + path (str): A path to a shared library to load. + """ + if torch._running_with_deploy(): + return + + path = _utils_internal.resolve_library_path(path) + with dl_open_guard(): + # Import the shared library into the process, thus running its + # static (global) initialization code in order to register custom + # operators with the JIT. + ctypes.CDLL(path) + self.loaded_libraries.add(path) + + +# The ops "namespace" +ops: _Ops = _Ops() diff --git a/phivenv/Lib/site-packages/torch/_python_dispatcher.py b/phivenv/Lib/site-packages/torch/_python_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..931205a88f43e43d4c993395f3609af1324c6e4b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_python_dispatcher.py @@ -0,0 +1,182 @@ +# mypy: allow-untyped-defs +import re + +import torch._C as C + + +""" +PythonDispatcher class is a thin python-binding to C++ dispatcher and it +is designed to show how dispatcher precompute works. In particular, +it shows for a certain op `foo`, what the computed dispatch table looks +like after user register their kernels to certains dispatch keys. + +In the real C++ dispatcher we support many dispatch keys for different +functionalities. For simplicity PythonDispatcher only supports dispatch +keys for a single example of each use case. These use cases are listed below: + +- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference & + autograd kernel in pytorch core library. + E.g. CPU, CUDA +- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific + inference kernels, but they share the same autograd kernel specified in AutogradOther. + E.g. FPGA, SparseCsrCPU +- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd + kernel defined in pytorch core library. Backend owner is responsible for registering both + inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support. + E.g. XLA, XPU, MPS +- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc. + Kernels registered to this key MUST work for inference for all backends. +- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther. 
+ Kernels registered to this key MUST work for autograd for all backends. +- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd + Kernels registered to this key MUST work for both inference + autograd for all backends. + +Note we only allow registrations to alias keys inside pytorch core library. E.g +you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd +kernel from torch-xla extension, instead you should upstream the kernel into +pytorch/pytorch repo so that it's available for all backends and continuously +tested even without the extension. + +Usage: + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"]) + print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend. + # For more debugging information + # print(dispatcher.keys()) + # print(dispatcher.registrations()) + # print(dispatcher.rawRegistrations()) + # print(dispatcher.rawDispatchTable()) +PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table. +This file only provides the simplified API for developers, relevant test code is located in +test/test_dispatch.py +""" + + +class PythonDispatcher: + namespace = "__test__" + name = "foo" + # fmt: off + runtime_keys = [ + "CPU", "AutogradCPU", + "FPGA", "AutogradOther", + "XLA", "AutogradXLA", + "Lazy", "AutogradLazy", + ] + # fmt: on + alias_keys = [ + "CompositeExplicitAutograd", + "Autograd", + "CompositeImplicitAutograd", + ] + supported_keys = runtime_keys + alias_keys + + def __init__(self) -> None: + C._dispatch_check_invariants(self.name) # type: ignore[attr-defined] + self.ref = C._dispatch_library("FRAGMENT", self.namespace, "") + self.ref.def_("foo(Tensor x) -> Tensor") + + """ + Returns a list of dispatch keys supported by PythonDispatcher. + You can register kernels to these keys. + """ + + def keys(self): + return self.supported_keys + + """ + Register kernels to the target dispatchKeys. + dispatchKeys(list[str]): a list of dispatch keys that you want to register + your own kernel. Note that you don't need to write the kernel yourself in + this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is + automatically generated and registered. + """ + + def register(self, dispatchKeys): + # Overridden is not supported and triggers a warning in C++ dispatcher. + if len(set(dispatchKeys)) != len(dispatchKeys): + raise RuntimeError( + f"Overridden is not allowed but found duplicates in {dispatchKeys}." + ) + # We currently forbid this in codegen instead of C++ dispatcher. + if ( + "CompositeImplicitAutograd" in dispatchKeys + and "CompositeExplicitAutograd" in dispatchKeys + ): + raise RuntimeError( + "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed." + ) + for key in dispatchKeys: + if key not in self.supported_keys: + raise RuntimeError( + f"{key} is not supported, please select a dispatch key in {self.supported_keys}." + ) + self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key) + + """ + Helper function to format (key, kernel). + """ + + def _format_line(self, key, kernel): + return f"{key:<15} {kernel}\n" + + """ + Helper function to print a table header. + """ + + def _format_header(self, header): + s = f""" +{header} +""" + s += self._format_line("key", "kernel") + s += "---------------------------\n" + return s + + """ + Returns raw output of all registration info for debugging only. 
+ Use registrations() for a simplified version. + """ + + def rawRegistrations(self): + return C._dispatch_dump(f"{self.namespace}::{self.name}") # type: ignore[attr-defined] + + """ + Returns raw output of computed dispatch table for debugging only. + Use dispatchTable() for a simplified version. + """ + + def rawDispatchTable(self): + return C._dispatch_dump_table(f"{self.namespace}::{self.name}") # type: ignore[attr-defined] + + """ + Returns a table(str) including all the registrations from users. + Note this includes registrations to both runtime keys and alias keys. + """ + + def registrations(self): + output = self._format_header("Registered Kernels") + state = self.rawRegistrations() + state_entries = state.split("\n") + for line in state_entries: + first = line.split(":")[0] + if any(first.startswith(k) for k in self.supported_keys): + kernel = line.split("::")[0].split(" ")[1] + output += self._format_line(first, kernel) + return output + + """ + Returns the computed dispatch table(str). Note this only include + runtime keys, registrations to alias keys have been decoded to their + mapped runtime keys. + """ + + def dispatchTable(self): + output = self._format_header("Computed Dispatch Table") + table = self.rawDispatchTable() + table_entries = table.split("\n") + regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)") + for line in table_entries: + k = line.split(":")[0] + if k in self.runtime_keys: + entry = regex.sub("[", line) + output += self._format_line(k, entry.split(": ")[1]) + return output diff --git a/phivenv/Lib/site-packages/torch/_size_docs.py b/phivenv/Lib/site-packages/torch/_size_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7622c547ba6f96a03b118669347c391497df52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_size_docs.py @@ -0,0 +1,39 @@ +"""Adds docstrings to torch.Size functions""" + +import torch._C +from torch._C import _add_docstr as add_docstr + + +def add_docstr_all(method: str, docstr: str) -> None: + add_docstr(getattr(torch._C.Size, method), docstr) + + +add_docstr_all( + "numel", + """ +numel() -> int + +Returns the number of elements a :class:`torch.Tensor` with the given size would contain. + +More formally, for a tensor ``x = tensor.ones(10, 10)`` with size ``s = torch.Size([10, 10])``, +``x.numel() == x.size().numel() == s.numel() == 100`` holds true. + +Example:: + + >>> x=torch.ones(10, 10) + >>> s=x.size() + >>> s + torch.Size([10, 10]) + >>> s.numel() + 100 + >>> x.numel() == s.numel() + True + + +.. warning:: + + This function does not return the number of dimensions described by :class:`torch.Size`, but instead the number + of elements a :class:`torch.Tensor` with that size would contain. + +""", +) diff --git a/phivenv/Lib/site-packages/torch/_sources.py b/phivenv/Lib/site-packages/torch/_sources.py new file mode 100644 index 0000000000000000000000000000000000000000..26350a054b3ce7dc2ee73f91aa62b93778f37c24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_sources.py @@ -0,0 +1,138 @@ +# mypy: allow-untyped-defs +import ast +import functools +import inspect +from textwrap import dedent +from typing import Any, NamedTuple, Optional + +from torch._C import ErrorReport +from torch._C._jit_tree_views import SourceRangeFactory + + +def get_source_lines_and_file( + obj: Any, + error_msg: Optional[str] = None, +) -> tuple[list[str], int, Optional[str]]: + """ + Wrapper around inspect.getsourcelines and inspect.getsourcefile. 
+ + Returns: (sourcelines, file_lino, filename) + """ + filename = None # in case getsourcefile throws + try: + filename = inspect.getsourcefile(obj) + sourcelines, file_lineno = inspect.getsourcelines(obj) + except OSError as e: + msg = ( + f"Can't get source for {obj}. TorchScript requires source access in " + "order to carry out compilation, make sure original .py files are " + "available." + ) + if error_msg: + msg += "\n" + error_msg + raise OSError(msg) from e + + return sourcelines, file_lineno, filename + + +def normalize_source_lines(sourcelines: list[str]) -> list[str]: + """ + This helper function accepts a list of source lines. It finds the + indentation level of the function definition (`def`), then it indents + all lines in the function body to a point at or greater than that + level. This allows for comments and continued string literals that + are at a lower indentation than the rest of the code. + Args: + sourcelines: function source code, separated into lines by + the '\n' character + Returns: + A list of source lines that have been correctly aligned + """ + + def remove_prefix(text, prefix): + return text[text.startswith(prefix) and len(prefix) :] + + # Find the line and line number containing the function definition + idx = None + for i, l in enumerate(sourcelines): + if l.lstrip().startswith("def"): + idx = i + break + + # This will happen when the function is a lambda- we won't find "def" anywhere in the source + # lines in that case. Currently trying to JIT compile a lambda will throw an error up in + # `parse_def()`, but we might want to handle this case in the future. + if idx is None: + return sourcelines + + # Get a string representing the amount of leading whitespace + fn_def = sourcelines[idx] + whitespace = fn_def.split("def")[0] + + # Add this leading whitespace to all lines before and after the `def` + aligned_prefix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx] + ] + aligned_suffix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :] + ] + + # Put it together again + aligned_prefix.append(fn_def) + return aligned_prefix + aligned_suffix + + +# Thin wrapper around SourceRangeFactory to store extra metadata +# about the function-to-be-compiled. 
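# (Editorial note) make_source_context below wraps this constructor with
# functools.cache, so repeated compilations of the same
# (source, filename, file_lineno, ...) argument tuple reuse a single
# SourceContext instance instead of rebuilding the range factory each time.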
+class SourceContext(SourceRangeFactory): + def __init__( + self, + source, + filename, + file_lineno, + leading_whitespace_len, + uses_true_division=True, + funcname=None, + ): + super().__init__(source, filename, file_lineno, leading_whitespace_len) + self.uses_true_division = uses_true_division + self.filename = filename + self.funcname = funcname + + +@functools.cache +def make_source_context(*args): + return SourceContext(*args) + + +def fake_range(): + return SourceContext("", None, 0, 0).make_raw_range(0, 1) + + +class ParsedDef(NamedTuple): + ast: ast.Module + ctx: SourceContext + source: str + filename: Optional[str] + file_lineno: int + + +def parse_def(fn): + sourcelines, file_lineno, filename = get_source_lines_and_file( + fn, ErrorReport.call_stack() + ) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + dedent_src = dedent(source) + py_ast = ast.parse(dedent_src) + if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): + raise RuntimeError( + f"Expected a single top-level function: {filename}:{file_lineno}" + ) + leading_whitespace_len = len(source.split("\n", 1)[0]) - len( + dedent_src.split("\n", 1)[0] + ) + ctx = make_source_context( + source, filename, file_lineno, leading_whitespace_len, True, fn.__name__ + ) + return ParsedDef(py_ast, ctx, source, filename, file_lineno) diff --git a/phivenv/Lib/site-packages/torch/_storage_docs.py b/phivenv/Lib/site-packages/torch/_storage_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7082a3db282339bb6c2cf4d2f0144d4bc79a64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_storage_docs.py @@ -0,0 +1,42 @@ +# mypy: allow-untyped-defs +"""Adds docstrings to Storage functions""" + +import torch._C +from torch._C import _add_docstr as add_docstr + + +storage_classes = ["StorageBase"] + + +def add_docstr_all(method, docstr): + for cls_name in storage_classes: + cls = getattr(torch._C, cls_name) + try: + add_docstr(getattr(cls, method), docstr) + except AttributeError: + pass + + +add_docstr_all( + "from_file", + """ +from_file(filename, shared=False, nbytes=0) -> Storage + +Creates a CPU storage backed by a memory-mapped file. + +If ``shared`` is ``True``, then memory is shared between all processes. +All changes are written to the file. If ``shared`` is ``False``, then the changes on +the storage do not affect the file. + +``nbytes`` is the number of bytes of storage. If ``shared`` is ``False``, +then the file must contain at least ``nbytes`` bytes. If ``shared`` is +``True`` the file will be created if needed. (Note that for ``UntypedStorage`` +this argument differs from that of ``TypedStorage.from_file``) + +Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + nbytes (int): number of bytes of storage +""", +) diff --git a/phivenv/Lib/site-packages/torch/_streambase.py b/phivenv/Lib/site-packages/torch/_streambase.py new file mode 100644 index 0000000000000000000000000000000000000000..864c367df6cf0e28b677809983955056a9b687a7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_streambase.py @@ -0,0 +1,20 @@ +from typing_extensions import deprecated + +import torch + + +# Preserved only for BC reasons +@deprecated( + "`torch._streambase._StreamBase` is deprecated. 
Please use `torch.Stream` instead.", + category=FutureWarning, +) +class _StreamBase(torch.Stream): + pass + + +@deprecated( + "`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.", + category=FutureWarning, +) +class _EventBase(torch.Event): + pass diff --git a/phivenv/Lib/site-packages/torch/_tensor.py b/phivenv/Lib/site-packages/torch/_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..16835d14e9bad70c1d4034e95d7850f10174a85f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_tensor.py @@ -0,0 +1,1796 @@ +# mypy: allow-untyped-defs +import copyreg +import enum +import functools +import warnings +from collections import OrderedDict +from copy import deepcopy +from numbers import Number +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch._C as _C +from torch._namedtensor_internals import ( + check_serializing_named_tensor, + is_ellipsis, + resolve_ellipsis, + single_ellipsis_index, + unzip_namedshape, + update_names, +) +from torch.overrides import ( + get_default_nowrap_functions, + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) + + +def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): + assigned = functools.WRAPPER_ASSIGNMENTS + + @functools.wraps(f, assigned=assigned) + def wrapped(*args, **kwargs): + try: + # See https://github.com/pytorch/pytorch/issues/75462 + if has_torch_function(args): + return handle_torch_function(wrapped, args, *args, **kwargs) + return f(*args, **kwargs) + except TypeError: + return NotImplemented + + return wrapped + + +# Should not be used, this is kept only for BC of loading old serialized Tensor subclasses +def _rebuild_from_type(func, type, args, dict): + if type is Tensor: + return func(*args) + + ret = func(*args).as_subclass(type) + ret.__dict__ = dict + return ret + + +def _rebuild_from_type_v2(func, new_type, args, state): + ret = func(*args) + if type(ret) is not new_type: + ret = ret.as_subclass(new_type) + # Tensor does define __setstate__ even though it doesn't define + # __getstate__. So only use __setstate__ if it is NOT the one defined + # on Tensor + if ( + getattr(ret.__class__, "__setstate__", Tensor.__setstate__) + is not Tensor.__setstate__ + ): + ret.__setstate__(state) + else: + ret = torch._utils._set_obj_state(ret, state) + return ret + + +def _dtype_to_typestr(dtype): + # CUDA devices are little-endian and tensors are stored in native byte + # order. 1-byte entries are endian-agnostic. + return { + torch.complex64: " torch.TypedStorage + + Returns the underlying :class:`TypedStorage`. + + .. warning:: + + :class:`TypedStorage` is deprecated. It will be removed in the future, and + :class:`UntypedStorage` will be the only storage class. To access the + :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`. 
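
        Example (illustrative addition, not from the original docstring)::

            >>> t = torch.ones(2, 3)            # six float32 elements
            >>> t.untyped_storage().nbytes()    # 6 elements * 4 bytes each
            24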
+ """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.storage, (self,), self) + + torch.storage._warn_typed_storage_removal(stacklevel=2) + return self._typed_storage() + + # For internal use only, to avoid raising deprecation warning + def _typed_storage(self): + untyped_storage = self.untyped_storage() + return torch.TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ) + + def _reduce_ex_internal(self, proto): + check_serializing_named_tensor(self) + + from torch.utils.hooks import warn_if_has_hooks + + # See Note [Don't serialize hooks] + warn_if_has_hooks(self) + backward_hooks: dict[Any, Any] = OrderedDict() + + skip_data = torch.serialization._serialization_tls.skip_data + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) + + if self.device.type in ["xla", "maia"] or ( + not torch._C._has_storage(self) + and self.device.type == torch._C._get_privateuse1_backend_name() + ): + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) + cpu_tensor = self.cpu() + return ( + torch._utils._rebuild_device_tensor_from_cpu_tensor, + (cpu_tensor, self.dtype, str(self.device), self.requires_grad), + ) + # Legacy comment that does not hold anymore. + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. + # We considered a few options: + # 1. CPU tensor can't be used here. + # Otherwise in torch.load CPU storage is reconstructed with randomly + # initialized data, moved onto backend device, and then storage is updated + # to the serialized content. This works perfectly for CPU/CUDA but not these backends; + # their tensors are disconnected with storage so they don't get the update. + # 2. Python list is not a good fit due to performance reason. + # `tolist()` converts every single element in the tensor into python objects + # and serialize them one by one. + if self.device.type in ["mtia"]: + # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't + # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, + # this would reconstruct the BFloat16 tensor from numpy. + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) + numpy_tensor = ( + self.cpu().numpy() + if self.dtype != torch.bfloat16 + else self.cpu().to(torch.float32).numpy() + ) + return ( + torch._utils._rebuild_device_tensor_from_numpy, + (numpy_tensor, self.dtype, str(self.device), self.requires_grad), + ) + if self.device.type == "meta": + # NB: This implementation BREAKS storage sharing. Current + # hypothesis is that no one cares for meta tensors. 
+ if skip_data: + warnings.warn( + "Serializing tensors on the meta device under skip_data context manager is a no-op" + ) + arg_meta = ( + self.dtype, + tuple(self.size()), + self.stride(), + self.requires_grad, + ) + return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) + if self.is_quantized: + if skip_data: + raise RuntimeError( + "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" + ) + # quantizer_params can be different type based on torch attribute + quantizer_params: Union[ + tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int] + ] + if self.qscheme() == torch.per_tensor_affine: + quantizer_params = ( + torch.per_tensor_affine, + self.q_scale(), + self.q_zero_point(), + ) + elif self.qscheme() in ( + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ): + # convert scales and zero points to tuple to avoid recursive calls + # when/if we get multi-axis quantized tensors in the future, the shape + # is recoverable from the main tensor shape + quantizer_params = ( + torch.per_channel_affine, + self.q_per_channel_scales(), + self.q_per_channel_zero_points(), + self.q_per_channel_axis(), + ) + else: + raise RuntimeError( + f"Serialization is not supported for tensors of type {self.qscheme()}" + ) + # TODO: Once we decide to break serialization FC, no longer + # need to wrap with TypedStorage + args_qtensor = ( + torch.storage.TypedStorage( + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, + ), + self.storage_offset(), + tuple(self.size()), + self.stride(), + quantizer_params, + self.requires_grad, + backward_hooks, + ) + return (torch._utils._rebuild_qtensor, args_qtensor) + elif self.is_sparse: + if self.layout == torch.sparse_coo: + args_sparse = ( + self.layout, + (self._indices(), self._values(), self.size(), self.is_coalesced()), + ) + else: + raise NotImplementedError( + f"sparse tensor __reduce_ex__ for layout `{self.layout}`" + ) + return (torch._utils._rebuild_sparse_tensor, args_sparse) + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + if self.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = ( + self.crow_indices(), + self.col_indices(), + ) + else: + compressed_indices, plain_indices = ( + self.ccol_indices(), + self.row_indices(), + ) + args_sparse_compressed = ( + self.layout, + ( + compressed_indices, + plain_indices, + self.values(), + self.size(), + ), + ) + return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) + elif self.is_nested: + if skip_data: + raise RuntimeError( + "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" + ) + args_nested = ( + # NB: values() currently returns the storage as a buffer in an unsafe way. + # Ideally, we'd use a private API for this instead. TODO: Switch to this if + # we ever get around to adding it. 
+ self.values(), + self._nested_tensor_size(), + self._nested_tensor_strides(), + self._nested_tensor_storage_offsets(), + ) + return (torch._utils._rebuild_nested_tensor, args_nested) + elif ( + type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + and ( + isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) + or ( + not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and self.data_ptr() == 0 + ) + ) + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + elif ( + type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + and ( + isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and not (skip_data and materialize_fake_tensors) + ) + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + else: + v3_dtypes = torch.storage._new_dtypes() + if self.dtype in v3_dtypes: + rebuild_func = torch._utils._rebuild_tensor_v3 + storage = self.untyped_storage() + else: + # TODO: Once we decide to break serialization FC, no longer + # need to wrap with TypedStorage + rebuild_func = torch._utils._rebuild_tensor_v2 # type: ignore[assignment] + storage = torch.storage.TypedStorage( + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, + ) # type: ignore[assignment] + + # TODO: remove hasattr, it's a hack to support versions of torch that + # don't have _subclasses + if ( + hasattr(torch, "_subclasses") + and isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and skip_data + ): + storage._fake_device = self.device + + args = ( + storage, + self.storage_offset(), + tuple(self.size()), + self.stride(), + self.requires_grad, + backward_hooks, + ) # previously was self._backward_hooks + + if isinstance(storage, torch.storage.UntypedStorage): + args = args + (self.dtype,) # type: ignore[assignment] + + metadata = torch._utils.get_tensor_metadata(self) + if metadata: + args = args + (metadata,) # type: ignore[assignment] + + return (rebuild_func, args) + + def __setstate__(self, state): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__setstate__, (self,), self, state) + # Warning: this method is NOT called when you torch.load() a tensor; + # that is managed by _rebuild_tensor_v2 + if not self.is_leaf: + raise RuntimeError("__setstate__ can be only called on leaf Tensors") + if len(state) == 4: + # legacy serialization of Tensor + self.set_(*state) + return + elif len(state) == 5: + # legacy serialization of Variable + self.data = state[0] + state = (state[3], state[4], state[2]) + # The setting of _backward_hooks is expected to be a no-op. + # See Note [Don't serialize hooks] + self.requires_grad, _, self._backward_hooks = state + + def __repr__(self, *, tensor_contents=None): + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.__repr__, (self,), self, tensor_contents=tensor_contents + ) + # All strings are unicode in Python 3. 
+ return torch._tensor_str._str(self, tensor_contents=tensor_contents) + + def backward( + self, gradient=None, retain_graph=None, create_graph=False, inputs=None + ): + r"""Computes the gradient of current tensor wrt graph leaves. + + The graph is differentiated using the chain rule. If the tensor is + non-scalar (i.e. its data has more than one element) and requires + gradient, the function additionally requires specifying a ``gradient``. + It should be a tensor of matching type and shape, that represents + the gradient of the differentiated function w.r.t. ``self``. + + This function accumulates gradients in the leaves - you might need to zero + ``.grad`` attributes or set them to ``None`` before calling it. + See :ref:`Default gradient layouts` + for details on the memory layout of accumulated gradients. + + .. note:: + + If you run any forward ops, create ``gradient``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + .. note:: + + When ``inputs`` are provided and a given input is not a leaf, + the current implementation will call its grad_fn (though it is not strictly needed to get this gradients). + It is an implementation detail on which the user should not rely. + See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + + Args: + gradient (Tensor, optional): The gradient of the function + being differentiated w.r.t. ``self``. + This argument can be omitted if ``self`` is a scalar. Defaults to ``None``. + retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be freed; + If ``True``, it will be retained. The default is ``None``, in which case the value is inferred from ``create_graph`` + (i.e., the graph is retained only when higher-order derivative tracking is requested). Note that in nearly all cases + setting this option to True is not needed and often can be worked around in a much more efficient way. + create_graph (bool, optional): If ``True``, graph of the derivative will + be constructed, allowing to compute higher order derivative + products. Defaults to ``False``. + inputs (Sequence[Tensor], optional): Inputs w.r.t. which the gradient will be + accumulated into ``.grad``. All other tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were + used to compute the :attr:`tensors`. Defaults to ``None``. + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.backward, + (self,), + self, + gradient=gradient, + retain_graph=retain_graph, + create_graph=create_graph, + inputs=inputs, + ) + torch.autograd.backward( + self, gradient, retain_graph, create_graph, inputs=inputs + ) + + def register_hook(self, hook): + r"""Registers a backward hook. + + The hook will be called every time a gradient with respect to the + Tensor is computed. The hook should have the following signature:: + + hook(grad) -> Tensor or None + + + The hook should not modify its argument, but it can optionally return + a new gradient which will be used in place of :attr:`grad`. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. 
+ + Example:: + + >>> v = torch.tensor([0., 0., 0.], requires_grad=True) + >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient + >>> v.backward(torch.tensor([1., 2., 3.])) + >>> v.grad + + 2 + 4 + 6 + [torch.FloatTensor of size (3,)] + + >>> h.remove() # removes the hook + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.register_hook, (self,), self, hook) + if not self.requires_grad: + raise RuntimeError( + "cannot register a hook on a tensor that doesn't require gradient" + ) + if self._backward_hooks is None: + self._backward_hooks = OrderedDict() + if self.grad_fn is not None: + self.grad_fn._register_hook_dict(self) + + from torch.utils.hooks import RemovableHandle + + handle = RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle + + def register_post_accumulate_grad_hook(self, hook): + r"""Registers a backward hook that runs after grad accumulation. + + The hook will be called after all gradients for a tensor have been accumulated, + meaning that the .grad field has been updated on that tensor. The post + accumulate grad hook is ONLY applicable for leaf tensors (tensors without a + .grad_fn field). Registering this hook on a non-leaf tensor will error! + + The hook should have the following signature:: + + hook(param: Tensor) -> None + + Note that, unlike other autograd hooks, this hook operates on the tensor + that requires grad and not the grad itself. The hook can in-place modify + and access its Tensor argument, including its .grad field. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. Since + this hook runs during the backward pass, it will run in no_grad mode (unless + create_graph is True). You can use torch.enable_grad() to re-enable autograd + within the hook if you need it. + + Example:: + + >>> v = torch.tensor([0., 0., 0.], requires_grad=True) + >>> lr = 0.01 + >>> # simulate a simple SGD update + >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) + >>> v.backward(torch.tensor([1., 2., 3.])) + >>> v + tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) + + >>> h.remove() # removes the hook + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.register_post_accumulate_grad_hook, (self,), self, hook + ) + if not self.requires_grad: + raise RuntimeError( + "cannot register a hook on a tensor that doesn't require gradient" + ) + if self.grad_fn is not None: + raise RuntimeError( + "post accumulate grad hooks cannot be registered on non-leaf tensors" + ) + if self._post_accumulate_grad_hooks is None: + self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict() + + from torch.utils.hooks import RemovableHandle + + handle = RemovableHandle(self._post_accumulate_grad_hooks) + self._post_accumulate_grad_hooks[handle.id] = hook + return handle + + def reinforce(self, reward): + def trim(str): + return "\n".join([line.strip() for line in str.split("\n")]) + + raise RuntimeError( + trim( + r"""reinforce() was removed. + Use torch.distributions instead. 
+ See https://pytorch.org/docs/main/distributions.html + + Instead of: + + probs = policy_network(state) + action = probs.multinomial() + next_state, reward = env.step(action) + action.reinforce(reward) + action.backward() + + Use: + + probs = policy_network(state) + # NOTE: categorical is equivalent to what used to be called multinomial + m = torch.distributions.Categorical(probs) + action = m.sample() + next_state, reward = env.step(action) + loss = -m.log_prob(action) * reward + loss.backward() + """ + ) + ) + + detach = _C._add_docstr( + _C.TensorBase.detach, + r""" + Returns a new Tensor, detached from the current graph. + + The result will never require gradient. + + This method also affects forward mode AD gradients and the result will never + have forward mode AD gradients. + + .. note:: + + Returned Tensor shares the same storage with the original one. + In-place modifications on either of them will be seen, and may trigger + errors in correctness checks. + """, + ) + + detach_ = _C._add_docstr( + _C.TensorBase.detach_, + r""" + Detaches the Tensor from the graph that created it, making it a leaf. + Views cannot be detached in-place. + + This method also affects forward mode AD gradients and the result will never + have forward mode AD gradients. + """, + ) + + def is_shared(self): + r"""Checks if tensor is in shared memory. + + This is always ``True`` for CUDA tensors. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.is_shared, (self,), self) + return self._typed_storage()._is_shared() + + def share_memory_(self): + r"""Moves the underlying storage to shared memory. + + This is a no-op if the underlying storage is already in shared memory + and for CUDA tensors. Tensors in shared memory cannot be resized. + + See :meth:`torch.UntypedStorage.share_memory_` for more details. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.share_memory_, (self,), self) + self._typed_storage()._share_memory_() + return self + + def module_load(self, other, assign=False): + r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`. + + Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. + + It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the + value in the state dictionary with the corresponding key, this method defines + how ``other`` is remapped before being swapped with ``self`` via + :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`. + + .. note:: + This method should always return a new object that is not ``self`` or ``other``. + For example, the default implementation returns ``self.copy_(other).detach()`` + if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``. 
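+
+        Example (an illustrative sketch added for this document, using plain
+        tensors rather than real module parameters)::
+
+            >>> target = torch.zeros(2)
+            >>> incoming = torch.ones(2)
+            >>> target.module_load(incoming)              # copies into ``target``, then detaches
+            tensor([1., 1.])
+            >>> target.module_load(incoming, assign=True) # just detaches ``incoming``
+            tensor([1., 1.])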
+ + Args: + other (Tensor): value in state dict with key corresponding to ``self`` + assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict` + + """ + if has_torch_function_variadic(self, other): + return handle_torch_function( + Tensor.module_load, (self, other), self, other, assign=assign + ) + + if assign: + return other.detach() + else: + return self.copy_(other).detach() + + def __reversed__(self): + r"""Reverses the tensor along dimension 0.""" + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__reversed__, (self,), self) + if self.dim() == 0: + return self + else: + return self.flip(0) + + def norm( + self, + p: Optional[Union[float, str]] = "fro", + dim=None, + keepdim=False, + dtype=None, + ): + r"""See :func:`torch.norm`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype + ) + return torch.norm(self, p, dim, keepdim, dtype=dtype) + + def solve(self, other): + from torch._linalg_utils import solve + + return solve(self, other) + + def lstsq(self, other): + from torch._linalg_utils import lstsq + + return lstsq(self, other) + + def eig(self, eigenvectors=False): + from torch._linalg_utils import eig + + return eig(self, eigenvectors=eigenvectors) + + def symeig(self, eigenvectors=False): + from torch._linalg_utils import _symeig + + return _symeig(self, eigenvectors=eigenvectors) + + def lu(self, pivot=True, get_infos=False): + r"""See :func:`torch.lu`""" + # If get_infos is True, then we don't need to check for errors and vice versa + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos + ) + + LU, pivots, infos = torch._lu_with_info( + self, pivot=pivot, check_errors=(not get_infos) + ) + if get_infos: + return LU, pivots, infos + else: + return LU, pivots + + def stft( + self, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, + ): + r"""See :func:`torch.stft` + + .. warning:: + This function changed signature at version 0.4.1. Calling with + the previous signature may cause error or return incorrect result. 
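+
+        Example (an illustrative sketch added for this document; the shape shown
+        follows from ``n_fft=64``, ``hop_length=16`` and ``center=True``)::
+
+            >>> x = torch.randn(400)
+            >>> window = torch.hann_window(64)
+            >>> spec = x.stft(n_fft=64, hop_length=16, window=window, return_complex=True)
+            >>> spec.shape
+            torch.Size([33, 26])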
+ """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.stft, + (self,), + self, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + return_complex=return_complex, + align_to_window=align_to_window, + ) + return torch.stft( + self, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + normalized, + onesided, + return_complex=return_complex, + align_to_window=align_to_window, + ) + + def istft( + self, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex: bool = False, + ): + r"""See :func:`torch.istft`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.istft, + (self,), + self, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + return_complex=return_complex, + ) + return torch.istft( + self, + n_fft, + hop_length, + win_length, + window, + center, + normalized, + onesided, + length, + return_complex=return_complex, + ) + + def resize(self, *sizes): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.resize, (self,), self, *sizes) + warnings.warn("non-inplace resize is deprecated") + from torch.autograd._functions import Resize + + return Resize.apply(self, sizes) + + def resize_as(self, tensor): + if has_torch_function_variadic(self, tensor): + return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor) + warnings.warn("non-inplace resize_as is deprecated") + from torch.autograd._functions import Resize + + return Resize.apply(self, tensor.size()) + + def split(self, split_size, dim=0): + r"""See :func:`torch.split`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.split, (self,), self, split_size, dim=dim + ) + if isinstance(split_size, Tensor): + try: + split_size = int(split_size) + except ValueError: + pass + + if isinstance(split_size, (int, torch.SymInt)): + return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined] + else: + return torch._VF.split_with_sizes(self, split_size, dim) + + def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): + r"""Returns the unique elements of the input tensor. + + See :func:`torch.unique` + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.unique, + (self,), + self, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + return torch.unique( + self, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + + def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): + r"""Eliminates all but the first element from every consecutive group of equivalent elements. 
+ + See :func:`torch.unique_consecutive` + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.unique_consecutive, + (self,), + self, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + return torch.unique_consecutive( + self, return_inverse=return_inverse, return_counts=return_counts, dim=dim + ) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rsub__(self, other): + return _C._VariableFunctions.rsub(self, other) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rdiv__(self, other): + return self.reciprocal() * other + + __rtruediv__ = __rdiv__ + __itruediv__ = _C.TensorBase.__idiv__ + + __pow__ = cast( + Callable[ + ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]], + "Tensor", + ], + _handle_torch_function_and_wrap_type_error_to_not_implemented( + _C.TensorBase.pow + ), + ) + __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( + _C.TensorBase.pow_ + ) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rmod__(self, other): + return torch.remainder(other, self) + + def __format__(self, format_spec): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__format__, (self,), self, format_spec) + if self.dim() == 0 and not self.is_meta and type(self) is Tensor: + # Use detach() here to avoid the warning when converting a scalar Tensor that + # requires gradients to a python number. It is ok for formatting. + return self.detach().item().__format__(format_spec) + return object.__format__(self, format_spec) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rpow__(self, other): + return torch.pow(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __floordiv__(self, other): + return torch.floor_divide(self, other) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rfloordiv__(self, other): + return torch.floor_divide(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rlshift__(self, other): + return torch.bitwise_left_shift(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rrshift__(self, other): + return torch.bitwise_right_shift(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rmatmul__(self, other): + return torch.matmul(other, self) + + __pos__ = _C.TensorBase.positive + __neg__ = _C.TensorBase.neg + __abs__ = _C.TensorBase.abs + + def __len__(self): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__len__, (self,), self) + if self.dim() == 0: + raise TypeError("len() of a 0-d tensor") + if torch._C._get_tracing_state(): + warnings.warn( + "Using len to get tensor shape might cause the trace to be incorrect. " + "Recommended usage would be tensor.shape[0]. " + "Passing a tensor of different shape might lead to errors or silently give " + "incorrect results.", + category=torch.jit.TracerWarning, + stacklevel=2, + ) + return self.shape[0] + + def __iter__(self): + # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a + # generator and don't eagerly perform all the indexes. This could + # save us work, and also helps keep trace ordering deterministic + # (e.g., if you zip(*hiddens), the eager map will force all the + # indexes of hiddens[0] before hiddens[1], while the generator + # map will interleave them.) 
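+        # Editorial sketch (not upstream code): iteration yields views along dim 0,
+        # equivalent to ``iter(self.unbind(0))``, e.g.
+        #   rows = list(torch.arange(6).reshape(2, 3))
+        #   # rows == [tensor([0, 1, 2]), tensor([3, 4, 5])]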
+ # NB: We have intentionally skipped __torch_function__ dispatch here. + # See gh-54457 + if self.dim() == 0: + raise TypeError("iteration over a 0-d tensor") + if torch._C._get_tracing_state(): + warnings.warn( + "Iterating over a tensor might cause the trace to be incorrect. " + "Passing a tensor of different shape won't change the number of " + "iterations executed (and might lead to errors or silently give " + "incorrect results).", + category=torch.jit.TracerWarning, + stacklevel=2, + ) + return iter(self.unbind(0)) + + def __hash__(self): + # Do NOT handle __torch_function__ here as user's default + # implementation that handle most functions will most likely do it wrong. + # It can be easily overridden by defining this method on the user + # subclass if needed. + return id(self) + + def __dir__(self): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dir__, (self,), self) + tensor_methods = dir(self.__class__) + tensor_methods.remove("volatile") # deprecated + attrs = list(self.__dict__.keys()) + keys = tensor_methods + attrs + + # property only available dense, cuda tensors + if (not self.is_cuda) or self.is_sparse: + keys.remove("__cuda_array_interface__") + + return sorted(keys) + + # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray` + __array_priority__ = 1000 # prefer Tensor ops over numpy ones + + def __array__(self, dtype=None): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype) + if dtype is None: + return self.numpy() + else: + return self.numpy().astype(dtype, copy=False) + + # Wrap Numpy array again in a suitable tensor when done, to support e.g. + # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` + def __array_wrap__(self, array): + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.__array_wrap__, (self,), self, array=array + ) + if array.dtype == bool: + # Workaround, torch has no built-in bool tensor + array = array.astype("uint8") + return torch.from_numpy(array) + + def __contains__(self, element: Any, /) -> bool: + r"""Check if `element` is present in tensor + + Args: + element (Tensor or scalar): element to be checked + for presence in current tensor" + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__contains__, (self,), self, element) + if isinstance( + element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool) + ): + # type hint doesn't understand the __contains__ result array + return bool((element == self).any().item()) # type: ignore[union-attr] + + raise RuntimeError( + f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}." + ) + + @property + def __cuda_array_interface__(self): + """Array view description for cuda tensors. + + See: + https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html + """ + if has_torch_function_unary(self): + # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 + return handle_torch_function( + Tensor.__cuda_array_interface__.__get__, # type: ignore[attr-defined] + (self,), + self, + ) + + # raise AttributeError for unsupported tensors, so that + # hasattr(cpu_tensor, "__cuda_array_interface__") is False. + if not self.is_cuda: + raise AttributeError( + f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} " + "If CUDA data is required use tensor.cuda() to copy tensor to device memory." 
+ ) + + if self.is_sparse: + raise AttributeError( + f"Can't get __cuda_array_interface__ on sparse type: {self.type()} " + "Use Tensor.to_dense() to convert to a dense tensor first." + ) + + # RuntimeError, matching tensor.__array__() behavior. + if self.requires_grad: + raise RuntimeError( + "Can't get __cuda_array_interface__ on Variable that requires grad. " + "If gradients aren't required, use var.detach() to get Variable that doesn't require grad." + ) + + typestr = _dtype_to_typestr(self.dtype) + itemsize = self.element_size() + shape = tuple(self.shape) + if self.is_contiguous(): + # __cuda_array_interface__ v2 requires the strides to be omitted + # (either not set or set to None) for C-contiguous arrays. + strides = None + else: + strides = tuple(s * itemsize for s in self.stride()) + data_ptr = self.data_ptr() if self.numel() > 0 else 0 + data = (data_ptr, False) # read-only is false + + return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2) + + def storage_type(self): + r"""storage_type() -> type + + Returns the type of the underlying storage. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.storage_type, (self,), self) + + torch.storage._warn_typed_storage_removal() + + return self._typed_storage()._get_legacy_storage_class() + + def refine_names(self, *names): + r"""Refines the dimension names of :attr:`self` according to :attr:`names`. + + Refining is a special case of renaming that "lifts" unnamed dimensions. + A ``None`` dim can be refined to have any name; a named dim can only be + refined to have the same name. + + Because named tensors can coexist with unnamed tensors, refining names + gives a nice way to write named-tensor-aware code that works with both + named and unnamed tensors. + + :attr:`names` may contain up to one Ellipsis (``...``). + The Ellipsis is expanded greedily; it is expanded in-place to fill + :attr:`names` to the same length as ``self.dim()`` using names from the + corresponding indices of ``self.names``. + + Python 2 does not support Ellipsis but one may use a string literal + instead (``'...'``). + + Args: + names (iterable of str): The desired names of the output tensor. May + contain up to one Ellipsis. + + Examples:: + + >>> imgs = torch.randn(32, 3, 128, 128) + >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') + >>> named_imgs.names + ('N', 'C', 'H', 'W') + + >>> tensor = torch.randn(2, 3, 5, 7, 11) + >>> tensor = tensor.refine_names('A', ..., 'B', 'C') + >>> tensor.names + ('A', None, None, 'B', 'C') + + .. warning:: + The named tensor API is experimental and subject to change. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.refine_names, (self,), self, *names) + names = resolve_ellipsis(names, self.names, "refine_names") + return super().refine_names(names) + + def align_to(self, *names): + r"""Permutes the dimensions of the :attr:`self` tensor to match the order + specified in :attr:`names`, adding size-one dims for any new names. + + All of the dims of :attr:`self` must be named in order to use this method. + The resulting tensor is a view on the original tensor. + + All dimension names of :attr:`self` must be present in :attr:`names`. + :attr:`names` may contain additional names that are not in ``self.names``; + the output tensor has a size-one dimension for each of those new names. + + :attr:`names` may contain up to one Ellipsis (``...``). 
+ The Ellipsis is expanded to be equal to all dimension names of :attr:`self` + that are not mentioned in :attr:`names`, in the order that they appear + in :attr:`self`. + + Python 2 does not support Ellipsis but one may use a string literal + instead (``'...'``). + + Args: + names (iterable of str): The desired dimension ordering of the + output tensor. May contain up to one Ellipsis that is expanded + to all unmentioned dim names of :attr:`self`. + + Examples:: + + >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) + >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') + + # Move the F and E dims to the front while keeping the rest in order + >>> named_tensor.align_to('F', 'E', ...) + + .. warning:: + The named tensor API is experimental and subject to change. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.align_to, (self,), self, *names) + ellipsis_idx = single_ellipsis_index(names, "align_to") + if ellipsis_idx is None: + return super().align_to(names) + return super().align_to( + [name for name in names if not is_ellipsis(name)], ellipsis_idx + ) + + def unflatten(self, dim, sizes): # type: ignore[override] + r""" + unflatten(dim, sizes) -> Tensor + + See :func:`torch.unflatten`. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes) + + if not sizes: + raise RuntimeError("unflatten: sizes must be non-empty") + + names = None + if isinstance(sizes, OrderedDict) or ( + isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list)) + ): + names, sizes = unzip_namedshape(sizes) + return super().unflatten(dim, sizes, names) + else: + return super().unflatten(dim, sizes) + + def rename_(self, *names, **rename_map): + """In-place version of :meth:`~Tensor.rename`.""" + + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.rename_, (self,), self, *names, **rename_map + ) + + # Note [rename_ / rename API] + # The Python API for these is different from the C++ API. In Python: + # 1) tensor.rename(*names) takes a vararglist of names + # 2) tensor.rename(**rename_map) takes a map of names to rename. + # C++ is static, making it difficult to implement similar behavior. + return update_names(self, names, rename_map, inplace=True) + + def rename(self, *names, **rename_map): + """Renames dimension names of :attr:`self`. + + There are two main usages: + + ``self.rename(**rename_map)`` returns a view on tensor that has dims + renamed as specified in the mapping :attr:`rename_map`. + + ``self.rename(*names)`` returns a view on tensor, renaming all + dimensions positionally using :attr:`names`. + Use ``self.rename(None)`` to drop names on a tensor. + + One cannot specify both positional args :attr:`names` and keyword args + :attr:`rename_map`. + + Examples:: + + >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> renamed_imgs = imgs.rename(N='batch', C='channels') + >>> renamed_imgs.names + ('batch', 'channels', 'H', 'W') + + >>> renamed_imgs = imgs.rename(None) + >>> renamed_imgs.names + (None, None, None, None) + + >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width') + >>> renamed_imgs.names + ('batch', 'channel', 'height', 'width') + + .. warning:: + The named tensor API is experimental and subject to change. 
+ + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.rename, (self,), self, *names, **rename_map + ) + + # See Note [rename_ / rename API] + return update_names(self, names, rename_map, inplace=False) + + def to_sparse_coo(self): + """Convert a tensor to :ref:`coordinate format `. + + Examples:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_coo() + >>> sparse._nnz() + 25 + + """ + return self.to_sparse() + + def dim_order( + self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False + ): + """ + dim_order(ambiguity_check=False) -> tuple + + Returns the uniquely determined tuple of int describing the dim order or + physical layout of :attr:`self`. + + The dim order represents how dimensions are laid out in memory of dense tensors, + starting from the outermost to the innermost dimension. + + Note that the dim order may not always be uniquely determined. + If `ambiguity_check` is True, this function raises a RuntimeError when the dim order cannot be uniquely determined; + If `ambiguity_check` is a list of memory formats, this function raises a RuntimeError when tensor can not be interpreted + into exactly one of the given memory formats, or it cannot be uniquely determined. + If `ambiguity_check` is False, it will return one of legal dim order(s) without checking its uniqueness. + Otherwise, it will raise TypeError. + + Args: + ambiguity_check (bool or List[torch.memory_format]): The check method for ambiguity of dim order. + + Examples:: + + >>> torch.empty((2, 3, 5, 7)).dim_order() + (0, 1, 2, 3) + >>> torch.empty((2, 3, 5, 7)).transpose(1, 2).dim_order() + (0, 2, 1, 3) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order() + (0, 2, 3, 1) + >>> torch.empty((1, 2, 3, 4)).dim_order() + (0, 1, 2, 3) + >>> try: + ... torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check=True) + ... except RuntimeError as e: + ... print(e) + The tensor does not have unique dim order, or cannot map to exact one of the given memory formats. + >>> torch.empty((1, 2, 3, 4)).dim_order( + ... ambiguity_check=[torch.contiguous_format, torch.channels_last] + ... ) # It can be mapped to contiguous format + (0, 1, 2, 3) + >>> try: + ... torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL") + ... except TypeError as e: + ... print(e) + The ambiguity_check argument must be a bool or a list of memory formats. + + .. warning:: + The dim_order tensor API is experimental and subject to change. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.dim_order, (self,), self) + + if self.is_sparse: + raise AttributeError( + f"Can't get dim order on sparse type: {self.type()} " + "Use Tensor.to_dense() to convert to a dense tensor first." + ) + + # Sanity check ambiguity_check data types + if not isinstance(ambiguity_check, bool): + if not isinstance(ambiguity_check, list): + raise TypeError( + "The ambiguity_check argument must be a bool or a list of memory formats." + ) + for memory_format in ambiguity_check: + if not isinstance(memory_format, torch.memory_format): + raise TypeError( + "The ambiguity_check argument must be a bool or a list of memory formats." + ) + + def invalid_unique_memory_format(tensor, valid_memory_formats): + """ + Returns True if the tensor cannot be uniquely mapped to any of the given memory formats, False otherwise. 
+ """ + + n_legality = 0 + + for memory_format in valid_memory_formats: + if tensor.is_contiguous(memory_format=memory_format): + n_legality += 1 + + return n_legality != 1 + + def has_multiple_dim_order(tensor): + """ + Returns True if there're multiple legal dim orders for given tensor, False otherwise. + + The tensor is considered to have multiple legal dim orders if either of the following conditions is met: + + * Singleton Dimensions: There's at least one singleteon dimension in the tensor. + Since their size is 1, they don't affect the memory offset (stride * index + is zero because index is always zero). Therefore, they can be placed anywhere + in the dimension order without changing how data is accessed. + * Same strides: Strides reflect how the tensor is stored in memory. + If any two dimensions have the same stride, swapping these dimensions won't + change how data is accessed, leading to multiple correct dimension orders. + """ + + sizes = tensor.size() + strides = tensor.stride() + + # Check if there are any duplicate strides + has_duplicate_strides = any( + earlier == later for earlier, later in zip(strides, strides[1:]) + ) + + # Check if there are any singleton dimensions + has_singleton_dims = any(size == 1 for size in sizes) + + return has_duplicate_strides or has_singleton_dims + + valid_memory_formats = ( + ambiguity_check if isinstance(ambiguity_check, list) else [] + ) + check_multiple_dim_order = ( + ambiguity_check if isinstance(ambiguity_check, bool) else True + ) + + if ( + check_multiple_dim_order and has_multiple_dim_order(self) + ) and invalid_unique_memory_format(self, valid_memory_formats): + raise RuntimeError( + "The tensor does not have unique dim order, or cannot map to exact one of the given memory formats." + ) + + import torch._prims_common as utils + + return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self)) + + def _update_names(self, names, inplace): + if has_torch_function_unary(self): + return handle_torch_function( + Tensor._update_names, (self,), self, names, inplace + ) + + # See Note [rename_ / rename API] + if inplace: + return super().rename_(names) + else: + return super().rename(names) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """ + This __torch_function__ implementation wraps subclasses such that + methods called on subclasses return a subclass instance instead of + a ``torch.Tensor`` instance. + + One corollary to this is that you need coverage for torch.Tensor + methods if implementing __torch_function__ for subclasses. + + We recommend always calling ``super().__torch_function__`` as the base + case when doing the above. + + While not mandatory, we recommend making `__torch_function__` a classmethod. + """ + if kwargs is None: + kwargs = {} + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + with _C.DisableTorchFunctionSubclass(): + ret = func(*args, **kwargs) + if func in get_default_nowrap_functions(): + return ret + else: + return _convert(ret, cls) + + __torch_dispatch__ = _C._disabled_torch_dispatch_impl + + def __dlpack__(self, stream=None): + """ + Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ + of the current tensor to be exported to other libraries. + + This function will be called from the `from_dlpack` method + of the library that will consume the capsule. `from_dlpack` passes the current + stream to this method as part of the specification. 
+ + Args: + stream (integer or None): An optional Python integer representing a + pointer to a CUDA stream. The current stream is synchronized with + this stream before the capsule is created, and since the capsule + shares its storage with the tensor this make it safe to access from + both streams. If None or -1 is passed then no synchronization is performed. + If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for + synchronization. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) + + # DLPack capsules can't capture all of PyTorch's semantics, + # so we prohibit exporting tensors that would lose their properties like + # requires_grad and having the conjugate bit set. + if self.requires_grad: + raise RuntimeError( + "Can't export tensors that require gradient, use tensor.detach()" + ) + if self.is_conj(): + raise RuntimeError("Can't export tensors with the conjugate bit set") + if self.layout != torch.strided: + raise RuntimeError( + "Can't export tensors with layout other than torch.strided" + ) + + if stream is not None and type(stream) is not int: + # Stream pointers in CUDA/ROCm are uniquely numbered and can + # be retrieved from their integer value. + raise TypeError("stream must be ``int`` or ``none``") + elif stream is not None and stream != -1: + if self.device.type == "cuda": + # NB: This logic handles the special case values for default + # streams and must be kept in sync with from_dlpack in + # torch/utils/dlpack.py + if stream == 1 and torch.version.hip is None: + stream = torch.cuda.default_stream() + elif stream == 0 and torch.version.hip is not None: + stream = torch.cuda.default_stream() + else: + stream = torch.cuda.ExternalStream(stream) + # Only synchronize on different streams + sync_stream = torch.cuda.current_stream() + if stream != sync_stream: + event = torch.cuda.Event() + event.record(sync_stream) + stream.wait_event(event) + if self.device.type == "xla": + import torch_xla + import torch_xla.utils.dlpack as xla_dlpack + + if ( + len(torch_xla.real_devices()) <= 0 + or "cuda" not in torch_xla.real_devices()[0].lower() + ): + raise RuntimeError( + "Can't export to dlpack an XLA tensor that is not on CUDA." 
+ ) + return xla_dlpack.to_dlpack(self) + return torch.to_dlpack(self) + + def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dlpack_device__, (self,), self) + + from torch.utils.dlpack import DLDeviceType + + device = self.device + idx = device.index if device.index is not None else 0 + torch_device_type = device.type + if torch_device_type == "cuda" and torch.version.hip is not None: + device_type = DLDeviceType.kDLROCM + elif torch_device_type == "cpu" and self.is_pinned(): + device_type = DLDeviceType.kDLCPUPinned + elif torch_device_type == "cuda": + device_type = DLDeviceType.kDLGPU + elif torch_device_type == "cpu": + device_type = DLDeviceType.kDLCPU + elif torch_device_type == "xpu": + device_type = DLDeviceType.kDLOneAPI + elif self.device.type == "privateuse1": + device_type = DLDeviceType.kDLExtDev + elif torch_device_type == "xla": + import torch_xla + + if ( + len(torch_xla.real_devices()) <= 0 + or "cuda" not in torch_xla.real_devices()[0].lower() + ): + raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") + + device_type = DLDeviceType.kDLGPU + else: + raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") + return (device_type, idx) + + __module__ = "torch" + + +def _convert(ret, cls): + if cls is Tensor: + return ret + + if isinstance(ret, Tensor) and not isinstance(ret, cls): + ret = ret.as_subclass(cls) + + if isinstance(ret, (tuple, list)): + # Also handles things like namedtuples + ret = type(ret)(_convert(r, cls) for r in ret) + + return ret diff --git a/phivenv/Lib/site-packages/torch/_tensor_docs.py b/phivenv/Lib/site-packages/torch/_tensor_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..1114c1fd02e01f2feae7976699fa598897ba7756 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_tensor_docs.py @@ -0,0 +1,6971 @@ +# mypy: allow-untyped-defs +"""Adds docstrings to Tensor functions""" + +import torch._C +from torch._C import _add_docstr as add_docstr +from torch._torch_docs import parse_kwargs, reproducibility_notes + + +def add_docstr_all(method: str, docstr: str) -> None: + add_docstr(getattr(torch._C.TensorBase, method), docstr) + + +common_args = parse_kwargs( + """ + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. +""" +) + +new_common_args = parse_kwargs( + """ + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. +""" +) + +add_docstr_all( + "new_tensor", + """ +new_tensor(data, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a new Tensor with :attr:`data` as the tensor data. 
+By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +.. warning:: + + :func:`new_tensor` always copies :attr:`data`. If you have a Tensor + ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_` + or :func:`torch.Tensor.detach`. + If you have a numpy array and want to avoid a copy, use + :func:`torch.from_numpy`. + +.. warning:: + + When data is a tensor `x`, :func:`new_tensor()` reads out 'the data' from whatever it is passed, + and constructs a leaf variable. Therefore ``tensor.new_tensor(x)`` is equivalent to ``x.detach().clone()`` + and ``tensor.new_tensor(x, requires_grad=True)`` is equivalent to ``x.detach().clone().requires_grad_(True)``. + The equivalents using ``detach()`` and ``clone()`` are recommended. + +Args: + data (array_like): The returned Tensor copies :attr:`data`. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones((2,), dtype=torch.int8) + >>> data = [[0, 1], [2, 3]] + >>> tensor.new_tensor(data) + tensor([[ 0, 1], + [ 2, 3]], dtype=torch.int8) + +""".format(**new_common_args), +) + +add_docstr_all( + "new_full", + """ +new_full(size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with :attr:`fill_value`. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + fill_value (scalar): the number to fill the output tensor with. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones((2,), dtype=torch.float64) + >>> tensor.new_full((3, 4), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416]], dtype=torch.float64) + +""".format(**new_common_args), +) + +add_docstr_all( + "new_empty", + """ +new_empty(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with uninitialized data. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty((2, 3)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + +""".format(**new_common_args), +) + +add_docstr_all( + "new_empty_strided", + """ +new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with +uninitialized data. By default, the returned Tensor has the same +:class:`torch.dtype` and :class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. 
+ +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty_strided((2, 3), (3, 1)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + +""".format(**new_common_args), +) + +add_docstr_all( + "new_ones", + """ +new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with ``1``. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + +""".format(**new_common_args), +) + +add_docstr_all( + "new_zeros", + """ +new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with ``0``. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.tensor((), dtype=torch.float64) + >>> tensor.new_zeros((2, 3)) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]], dtype=torch.float64) + +""".format(**new_common_args), +) + +add_docstr_all( + "abs", + r""" +abs() -> Tensor + +See :func:`torch.abs` +""", +) + +add_docstr_all( + "abs_", + r""" +abs_() -> Tensor + +In-place version of :meth:`~Tensor.abs` +""", +) + +add_docstr_all( + "absolute", + r""" +absolute() -> Tensor + +Alias for :func:`abs` +""", +) + +add_docstr_all( + "absolute_", + r""" +absolute_() -> Tensor + +In-place version of :meth:`~Tensor.absolute` +Alias for :func:`abs_` +""", +) + +add_docstr_all( + "acos", + r""" +acos() -> Tensor + +See :func:`torch.acos` +""", +) + +add_docstr_all( + "acos_", + r""" +acos_() -> Tensor + +In-place version of :meth:`~Tensor.acos` +""", +) + +add_docstr_all( + "arccos", + r""" +arccos() -> Tensor + +See :func:`torch.arccos` +""", +) + +add_docstr_all( + "arccos_", + r""" +arccos_() -> Tensor + +In-place version of :meth:`~Tensor.arccos` +""", +) + +add_docstr_all( + "acosh", + r""" +acosh() -> Tensor + +See :func:`torch.acosh` +""", +) + +add_docstr_all( + "acosh_", + r""" +acosh_() -> Tensor + +In-place version of :meth:`~Tensor.acosh` +""", +) + +add_docstr_all( + "arccosh", + r""" +acosh() -> Tensor + +See :func:`torch.arccosh` +""", +) + +add_docstr_all( + "arccosh_", + r""" +acosh_() -> Tensor + +In-place version of :meth:`~Tensor.arccosh` +""", +) + +add_docstr_all( + "add", + r""" +add(other, *, alpha=1) -> Tensor + +Add a scalar or tensor to :attr:`self` tensor. If both :attr:`alpha` +and :attr:`other` are specified, each element of :attr:`other` is scaled by +:attr:`alpha` before being used. 
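+
+Example (an illustrative sketch added for this document)::
+
+    >>> a = torch.tensor([1., 2.])
+    >>> a.add(torch.tensor([10., 20.]), alpha=2)
+    tensor([21., 42.])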
+ +When :attr:`other` is a tensor, the shape of :attr:`other` must be +:ref:`broadcastable ` with the shape of the underlying +tensor + +See :func:`torch.add` +""", +) + +add_docstr_all( + "add_", + r""" +add_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.add` +""", +) + +add_docstr_all( + "addbmm", + r""" +addbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addbmm` +""", +) + +add_docstr_all( + "addbmm_", + r""" +addbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addbmm` +""", +) + +add_docstr_all( + "addcdiv", + r""" +addcdiv(tensor1, tensor2, *, value=1) -> Tensor + +See :func:`torch.addcdiv` +""", +) + +add_docstr_all( + "addcdiv_", + r""" +addcdiv_(tensor1, tensor2, *, value=1) -> Tensor + +In-place version of :meth:`~Tensor.addcdiv` +""", +) + +add_docstr_all( + "addcmul", + r""" +addcmul(tensor1, tensor2, *, value=1) -> Tensor + +See :func:`torch.addcmul` +""", +) + +add_docstr_all( + "addcmul_", + r""" +addcmul_(tensor1, tensor2, *, value=1) -> Tensor + +In-place version of :meth:`~Tensor.addcmul` +""", +) + +add_docstr_all( + "addmm", + r""" +addmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addmm` +""", +) + +add_docstr_all( + "addmm_", + r""" +addmm_(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addmm` +""", +) + +add_docstr_all( + "addmv", + r""" +addmv(mat, vec, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addmv` +""", +) + +add_docstr_all( + "addmv_", + r""" +addmv_(mat, vec, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addmv` +""", +) + +add_docstr_all( + "sspaddmm", + r""" +sspaddmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.sspaddmm` +""", +) + +add_docstr_all( + "smm", + r""" +smm(mat) -> Tensor + +See :func:`torch.smm` +""", +) + +add_docstr_all( + "addr", + r""" +addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addr` +""", +) + +add_docstr_all( + "addr_", + r""" +addr_(vec1, vec2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addr` +""", +) + +add_docstr_all( + "align_as", + r""" +align_as(other) -> Tensor + +Permutes the dimensions of the :attr:`self` tensor to match the dimension order +in the :attr:`other` tensor, adding size-one dims for any new names. + +This operation is useful for explicit broadcasting by names (see examples). + +All of the dims of :attr:`self` must be named in order to use this method. +The resulting tensor is a view on the original tensor. + +All dimension names of :attr:`self` must be present in ``other.names``. +:attr:`other` may contain named dimensions that are not in ``self.names``; +the output tensor has a size-one dimension for each of those new names. + +To align a tensor to a specific order, use :meth:`~Tensor.align_to`. 
+ +Examples:: + + # Example 1: Applying a mask + >>> mask = torch.randint(2, [127, 128], dtype=torch.bool).refine_names('W', 'H') + >>> imgs = torch.randn(32, 128, 127, 3, names=('N', 'H', 'W', 'C')) + >>> imgs.masked_fill_(mask.align_as(imgs), 0) + + + # Example 2: Applying a per-channel-scale + >>> def scale_channels(input, scale): + >>> scale = scale.refine_names('C') + >>> return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names=('C',)) + >>> imgs = torch.rand(32, 128, 128, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(32, num_channels, 128, 128, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 128, 128, 128, names=('N', 'C', 'H', 'W', 'D')) + + # scale_channels is agnostic to the dimension order of the input + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) + +.. warning:: + The named tensor API is experimental and subject to change. + +""", +) + +add_docstr_all( + "all", + r""" +all(dim=None, keepdim=False) -> Tensor + +See :func:`torch.all` +""", +) + +add_docstr_all( + "allclose", + r""" +allclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + +See :func:`torch.allclose` +""", +) + +add_docstr_all( + "angle", + r""" +angle() -> Tensor + +See :func:`torch.angle` +""", +) + +add_docstr_all( + "any", + r""" +any(dim=None, keepdim=False) -> Tensor + +See :func:`torch.any` +""", +) + +add_docstr_all( + "apply_", + r""" +apply_(callable) -> Tensor + +Applies the function :attr:`callable` to each element in the tensor, replacing +each element with the value returned by :attr:`callable`. + +.. note:: + + This function only works with CPU tensors and should not be used in code + sections that require high performance. 
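+
+Example (an illustrative sketch added for this document)::
+
+    >>> t = torch.tensor([1.0, 2.0, 3.0])
+    >>> t.apply_(lambda x: x * x)
+    tensor([1., 4., 9.])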
+""", +) + +add_docstr_all( + "asin", + r""" +asin() -> Tensor + +See :func:`torch.asin` +""", +) + +add_docstr_all( + "asin_", + r""" +asin_() -> Tensor + +In-place version of :meth:`~Tensor.asin` +""", +) + +add_docstr_all( + "arcsin", + r""" +arcsin() -> Tensor + +See :func:`torch.arcsin` +""", +) + +add_docstr_all( + "arcsin_", + r""" +arcsin_() -> Tensor + +In-place version of :meth:`~Tensor.arcsin` +""", +) + +add_docstr_all( + "asinh", + r""" +asinh() -> Tensor + +See :func:`torch.asinh` +""", +) + +add_docstr_all( + "asinh_", + r""" +asinh_() -> Tensor + +In-place version of :meth:`~Tensor.asinh` +""", +) + +add_docstr_all( + "arcsinh", + r""" +arcsinh() -> Tensor + +See :func:`torch.arcsinh` +""", +) + +add_docstr_all( + "arcsinh_", + r""" +arcsinh_() -> Tensor + +In-place version of :meth:`~Tensor.arcsinh` +""", +) + +add_docstr_all( + "as_strided", + r""" +as_strided(size, stride, storage_offset=None) -> Tensor + +See :func:`torch.as_strided` +""", +) + +add_docstr_all( + "as_strided_", + r""" +as_strided_(size, stride, storage_offset=None) -> Tensor + +In-place version of :meth:`~Tensor.as_strided` +""", +) + +add_docstr_all( + "atan", + r""" +atan() -> Tensor + +See :func:`torch.atan` +""", +) + +add_docstr_all( + "atan_", + r""" +atan_() -> Tensor + +In-place version of :meth:`~Tensor.atan` +""", +) + +add_docstr_all( + "arctan", + r""" +arctan() -> Tensor + +See :func:`torch.arctan` +""", +) + +add_docstr_all( + "arctan_", + r""" +arctan_() -> Tensor + +In-place version of :meth:`~Tensor.arctan` +""", +) + +add_docstr_all( + "atan2", + r""" +atan2(other) -> Tensor + +See :func:`torch.atan2` +""", +) + +add_docstr_all( + "atan2_", + r""" +atan2_(other) -> Tensor + +In-place version of :meth:`~Tensor.atan2` +""", +) + +add_docstr_all( + "arctan2", + r""" +arctan2(other) -> Tensor + +See :func:`torch.arctan2` +""", +) + +add_docstr_all( + "arctan2_", + r""" +atan2_(other) -> Tensor + +In-place version of :meth:`~Tensor.arctan2` +""", +) + +add_docstr_all( + "atanh", + r""" +atanh() -> Tensor + +See :func:`torch.atanh` +""", +) + +add_docstr_all( + "atanh_", + r""" +atanh_(other) -> Tensor + +In-place version of :meth:`~Tensor.atanh` +""", +) + +add_docstr_all( + "arctanh", + r""" +arctanh() -> Tensor + +See :func:`torch.arctanh` +""", +) + +add_docstr_all( + "arctanh_", + r""" +arctanh_(other) -> Tensor + +In-place version of :meth:`~Tensor.arctanh` +""", +) + +add_docstr_all( + "baddbmm", + r""" +baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.baddbmm` +""", +) + +add_docstr_all( + "baddbmm_", + r""" +baddbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.baddbmm` +""", +) + +add_docstr_all( + "bernoulli", + r""" +bernoulli(*, generator=None) -> Tensor + +Returns a result tensor where each :math:`\texttt{result[i]}` is independently +sampled from :math:`\text{Bernoulli}(\texttt{self[i]})`. :attr:`self` must have +floating point ``dtype``, and the result will have the same ``dtype``. + +See :func:`torch.bernoulli` +""", +) + +add_docstr_all( + "bernoulli_", + r""" +bernoulli_(p=0.5, *, generator=None) -> Tensor + +Fills each location of :attr:`self` with an independent sample from +:math:`\text{Bernoulli}(\texttt{p})`. :attr:`self` can have integral +``dtype``. + +:attr:`p` should either be a scalar or tensor containing probabilities to be +used for drawing the binary random number. 
+ +If it is a tensor, the :math:`\text{i}^{th}` element of :attr:`self` tensor +will be set to a value sampled from +:math:`\text{Bernoulli}(\texttt{p\_tensor[i]})`. In this case `p` must have +floating point ``dtype``. + +See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` +""", +) + +add_docstr_all( + "bincount", + r""" +bincount(weights=None, minlength=0) -> Tensor + +See :func:`torch.bincount` +""", +) + +add_docstr_all( + "bitwise_not", + r""" +bitwise_not() -> Tensor + +See :func:`torch.bitwise_not` +""", +) + +add_docstr_all( + "bitwise_not_", + r""" +bitwise_not_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_not` +""", +) + +add_docstr_all( + "bitwise_and", + r""" +bitwise_and() -> Tensor + +See :func:`torch.bitwise_and` +""", +) + +add_docstr_all( + "bitwise_and_", + r""" +bitwise_and_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_and` +""", +) + +add_docstr_all( + "bitwise_or", + r""" +bitwise_or() -> Tensor + +See :func:`torch.bitwise_or` +""", +) + +add_docstr_all( + "bitwise_or_", + r""" +bitwise_or_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_or` +""", +) + +add_docstr_all( + "bitwise_xor", + r""" +bitwise_xor() -> Tensor + +See :func:`torch.bitwise_xor` +""", +) + +add_docstr_all( + "bitwise_xor_", + r""" +bitwise_xor_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_xor` +""", +) + +add_docstr_all( + "bitwise_left_shift", + r""" +bitwise_left_shift(other) -> Tensor + +See :func:`torch.bitwise_left_shift` +""", +) + +add_docstr_all( + "bitwise_left_shift_", + r""" +bitwise_left_shift_(other) -> Tensor + +In-place version of :meth:`~Tensor.bitwise_left_shift` +""", +) + +add_docstr_all( + "bitwise_right_shift", + r""" +bitwise_right_shift(other) -> Tensor + +See :func:`torch.bitwise_right_shift` +""", +) + +add_docstr_all( + "bitwise_right_shift_", + r""" +bitwise_right_shift_(other) -> Tensor + +In-place version of :meth:`~Tensor.bitwise_right_shift` +""", +) + +add_docstr_all( + "broadcast_to", + r""" +broadcast_to(shape) -> Tensor + +See :func:`torch.broadcast_to`. +""", +) + +add_docstr_all( + "logical_and", + r""" +logical_and() -> Tensor + +See :func:`torch.logical_and` +""", +) + +add_docstr_all( + "logical_and_", + r""" +logical_and_() -> Tensor + +In-place version of :meth:`~Tensor.logical_and` +""", +) + +add_docstr_all( + "logical_not", + r""" +logical_not() -> Tensor + +See :func:`torch.logical_not` +""", +) + +add_docstr_all( + "logical_not_", + r""" +logical_not_() -> Tensor + +In-place version of :meth:`~Tensor.logical_not` +""", +) + +add_docstr_all( + "logical_or", + r""" +logical_or() -> Tensor + +See :func:`torch.logical_or` +""", +) + +add_docstr_all( + "logical_or_", + r""" +logical_or_() -> Tensor + +In-place version of :meth:`~Tensor.logical_or` +""", +) + +add_docstr_all( + "logical_xor", + r""" +logical_xor() -> Tensor + +See :func:`torch.logical_xor` +""", +) + +add_docstr_all( + "logical_xor_", + r""" +logical_xor_() -> Tensor + +In-place version of :meth:`~Tensor.logical_xor` +""", +) + +add_docstr_all( + "bmm", + r""" +bmm(batch2) -> Tensor + +See :func:`torch.bmm` +""", +) + +add_docstr_all( + "cauchy_", + r""" +cauchy_(median=0, sigma=1, *, generator=None) -> Tensor + +Fills the tensor with numbers drawn from the Cauchy distribution: + +.. math:: + + f(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - \text{median})^2 + \sigma^2} + +.. note:: + Sigma (:math:`\sigma`) is used to denote the scale parameter in Cauchy distribution. 
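+
+A short usage sketch (the drawn values are random, so no output is shown)::
+
+    >>> torch.empty(3).cauchy_()           # median=0, sigma=1 by default
+    >>> torch.empty(3).cauchy_(sigma=2.0)  # a wider scale parameter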
+""", +) + +add_docstr_all( + "ceil", + r""" +ceil() -> Tensor + +See :func:`torch.ceil` +""", +) + +add_docstr_all( + "ceil_", + r""" +ceil_() -> Tensor + +In-place version of :meth:`~Tensor.ceil` +""", +) + +add_docstr_all( + "cholesky", + r""" +cholesky(upper=False) -> Tensor + +See :func:`torch.cholesky` +""", +) + +add_docstr_all( + "cholesky_solve", + r""" +cholesky_solve(input2, upper=False) -> Tensor + +See :func:`torch.cholesky_solve` +""", +) + +add_docstr_all( + "cholesky_inverse", + r""" +cholesky_inverse(upper=False) -> Tensor + +See :func:`torch.cholesky_inverse` +""", +) + +add_docstr_all( + "clamp", + r""" +clamp(min=None, max=None) -> Tensor + +See :func:`torch.clamp` +""", +) + +add_docstr_all( + "clamp_", + r""" +clamp_(min=None, max=None) -> Tensor + +In-place version of :meth:`~Tensor.clamp` +""", +) + +add_docstr_all( + "clip", + r""" +clip(min=None, max=None) -> Tensor + +Alias for :meth:`~Tensor.clamp`. +""", +) + +add_docstr_all( + "clip_", + r""" +clip_(min=None, max=None) -> Tensor + +Alias for :meth:`~Tensor.clamp_`. +""", +) + +add_docstr_all( + "clone", + r""" +clone(*, memory_format=torch.preserve_format) -> Tensor + +See :func:`torch.clone` +""".format(**common_args), +) + +add_docstr_all( + "coalesce", + r""" +coalesce() -> Tensor + +Returns a coalesced copy of :attr:`self` if :attr:`self` is an +:ref:`uncoalesced tensor `. + +Returns :attr:`self` if :attr:`self` is a coalesced tensor. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. +""", +) + +add_docstr_all( + "contiguous", + r""" +contiguous(memory_format=torch.contiguous_format) -> Tensor + +Returns a contiguous in memory tensor containing the same data as :attr:`self` tensor. If +:attr:`self` tensor is already in the specified memory format, this function returns the +:attr:`self` tensor. + +Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. +""", +) + +add_docstr_all( + "copy_", + r""" +copy_(src, non_blocking=False) -> Tensor + +Copies the elements from :attr:`src` into :attr:`self` tensor and returns +:attr:`self`. + +The :attr:`src` tensor must be :ref:`broadcastable ` +with the :attr:`self` tensor. It may be of a different data type or reside on a +different device. + +Args: + src (Tensor): the source tensor to copy from + non_blocking (bool): if ``True`` and this copy is between CPU and GPU, + the copy may occur asynchronously with respect to the host. For other + cases, this argument has no effect. 
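+
+Example (a small sketch; ``src`` is broadcast to the shape of :attr:`self`)::
+
+    >>> a = torch.zeros(2, 3)
+    >>> a.copy_(torch.tensor([1., 2., 3.]))
+    tensor([[1., 2., 3.],
+            [1., 2., 3.]])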
+""", +) + +add_docstr_all( + "conj", + r""" +conj() -> Tensor + +See :func:`torch.conj` +""", +) + +add_docstr_all( + "conj_physical", + r""" +conj_physical() -> Tensor + +See :func:`torch.conj_physical` +""", +) + +add_docstr_all( + "conj_physical_", + r""" +conj_physical_() -> Tensor + +In-place version of :meth:`~Tensor.conj_physical` +""", +) + +add_docstr_all( + "resolve_conj", + r""" +resolve_conj() -> Tensor + +See :func:`torch.resolve_conj` +""", +) + +add_docstr_all( + "resolve_neg", + r""" +resolve_neg() -> Tensor + +See :func:`torch.resolve_neg` +""", +) + +add_docstr_all( + "copysign", + r""" +copysign(other) -> Tensor + +See :func:`torch.copysign` +""", +) + +add_docstr_all( + "copysign_", + r""" +copysign_(other) -> Tensor + +In-place version of :meth:`~Tensor.copysign` +""", +) + +add_docstr_all( + "cos", + r""" +cos() -> Tensor + +See :func:`torch.cos` +""", +) + +add_docstr_all( + "cos_", + r""" +cos_() -> Tensor + +In-place version of :meth:`~Tensor.cos` +""", +) + +add_docstr_all( + "cosh", + r""" +cosh() -> Tensor + +See :func:`torch.cosh` +""", +) + +add_docstr_all( + "cosh_", + r""" +cosh_() -> Tensor + +In-place version of :meth:`~Tensor.cosh` +""", +) + +add_docstr_all( + "cpu", + r""" +cpu(memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in CPU memory. + +If this object is already in CPU memory, +then no copy is performed and the original object is returned. + +Args: + {memory_format} + +""".format(**common_args), +) + +add_docstr_all( + "count_nonzero", + r""" +count_nonzero(dim=None) -> Tensor + +See :func:`torch.count_nonzero` +""", +) + +add_docstr_all( + "cov", + r""" +cov(*, correction=1, fweights=None, aweights=None) -> Tensor + +See :func:`torch.cov` +""", +) + +add_docstr_all( + "corrcoef", + r""" +corrcoef() -> Tensor + +See :func:`torch.corrcoef` +""", +) + +add_docstr_all( + "cross", + r""" +cross(other, dim=None) -> Tensor + +See :func:`torch.cross` +""", +) + +add_docstr_all( + "cuda", + r""" +cuda(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in CUDA memory. + +If this object is already in CUDA memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination GPU device. + Defaults to the current CUDA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "mtia", + r""" +mtia(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in MTIA memory. + +If this object is already in MTIA memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination MTIA device. + Defaults to the current MTIA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "ipu", + r""" +ipu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in IPU memory. 
+ +If this object is already in IPU memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination IPU device. + Defaults to the current IPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "xpu", + r""" +xpu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in XPU memory. + +If this object is already in XPU memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination XPU device. + Defaults to the current XPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "logcumsumexp", + r""" +logcumsumexp(dim) -> Tensor + +See :func:`torch.logcumsumexp` +""", +) + +add_docstr_all( + "cummax", + r""" +cummax(dim) -> (Tensor, Tensor) + +See :func:`torch.cummax` +""", +) + +add_docstr_all( + "cummin", + r""" +cummin(dim) -> (Tensor, Tensor) + +See :func:`torch.cummin` +""", +) + +add_docstr_all( + "cumprod", + r""" +cumprod(dim, dtype=None) -> Tensor + +See :func:`torch.cumprod` +""", +) + +add_docstr_all( + "cumprod_", + r""" +cumprod_(dim, dtype=None) -> Tensor + +In-place version of :meth:`~Tensor.cumprod` +""", +) + +add_docstr_all( + "cumsum", + r""" +cumsum(dim, dtype=None) -> Tensor + +See :func:`torch.cumsum` +""", +) + +add_docstr_all( + "cumsum_", + r""" +cumsum_(dim, dtype=None) -> Tensor + +In-place version of :meth:`~Tensor.cumsum` +""", +) + +add_docstr_all( + "data_ptr", + r""" +data_ptr() -> int + +Returns the address of the first element of :attr:`self` tensor. +""", +) + +add_docstr_all( + "dequantize", + r""" +dequantize() -> Tensor + +Given a quantized Tensor, dequantize it and return the dequantized float Tensor. +""", +) + +add_docstr_all( + "dense_dim", + r""" +dense_dim() -> int + +Return the number of dense dimensions in a :ref:`sparse tensor ` :attr:`self`. + +.. note:: + Returns ``len(self.shape)`` if :attr:`self` is not a sparse tensor. + +See also :meth:`Tensor.sparse_dim` and :ref:`hybrid tensors `. +""", +) + +add_docstr_all( + "diag", + r""" +diag(diagonal=0) -> Tensor + +See :func:`torch.diag` +""", +) + +add_docstr_all( + "diag_embed", + r""" +diag_embed(offset=0, dim1=-2, dim2=-1) -> Tensor + +See :func:`torch.diag_embed` +""", +) + +add_docstr_all( + "diagflat", + r""" +diagflat(offset=0) -> Tensor + +See :func:`torch.diagflat` +""", +) + +add_docstr_all( + "diagonal", + r""" +diagonal(offset=0, dim1=0, dim2=1) -> Tensor + +See :func:`torch.diagonal` +""", +) + +add_docstr_all( + "diagonal_scatter", + r""" +diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor + +See :func:`torch.diagonal_scatter` +""", +) + +add_docstr_all( + "as_strided_scatter", + r""" +as_strided_scatter(src, size, stride, storage_offset=None) -> Tensor + +See :func:`torch.as_strided_scatter` +""", +) + +add_docstr_all( + "fill_diagonal_", + r""" +fill_diagonal_(fill_value, wrap=False) -> Tensor + +Fill the main diagonal of a tensor that has at least 2-dimensions. 
+When dims>2, all dimensions of input must be of equal length. +This function modifies the input tensor in-place, and returns the input tensor. + +Arguments: + fill_value (Scalar): the fill value + wrap (bool): the diagonal 'wrapped' after N columns for tall matrices. + +Example:: + + >>> a = torch.zeros(3, 3) + >>> a.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + >>> b = torch.zeros(7, 3) + >>> b.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> c = torch.zeros(7, 3) + >>> c.fill_diagonal_(5, wrap=True) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + +""", +) + +add_docstr_all( + "floor_divide", + r""" +floor_divide(value) -> Tensor + +See :func:`torch.floor_divide` +""", +) + +add_docstr_all( + "floor_divide_", + r""" +floor_divide_(value) -> Tensor + +In-place version of :meth:`~Tensor.floor_divide` +""", +) + +add_docstr_all( + "diff", + r""" +diff(n=1, dim=-1, prepend=None, append=None) -> Tensor + +See :func:`torch.diff` +""", +) + +add_docstr_all( + "digamma", + r""" +digamma() -> Tensor + +See :func:`torch.digamma` +""", +) + +add_docstr_all( + "digamma_", + r""" +digamma_() -> Tensor + +In-place version of :meth:`~Tensor.digamma` +""", +) + +add_docstr_all( + "dim", + r""" +dim() -> int + +Returns the number of dimensions of :attr:`self` tensor. +""", +) + +add_docstr_all( + "dist", + r""" +dist(other, p=2) -> Tensor + +See :func:`torch.dist` +""", +) + +add_docstr_all( + "div", + r""" +div(value, *, rounding_mode=None) -> Tensor + +See :func:`torch.div` +""", +) + +add_docstr_all( + "div_", + r""" +div_(value, *, rounding_mode=None) -> Tensor + +In-place version of :meth:`~Tensor.div` +""", +) + +add_docstr_all( + "divide", + r""" +divide(value, *, rounding_mode=None) -> Tensor + +See :func:`torch.divide` +""", +) + +add_docstr_all( + "divide_", + r""" +divide_(value, *, rounding_mode=None) -> Tensor + +In-place version of :meth:`~Tensor.divide` +""", +) + +add_docstr_all( + "dot", + r""" +dot(other) -> Tensor + +See :func:`torch.dot` +""", +) + +add_docstr_all( + "element_size", + r""" +element_size() -> int + +Returns the size in bytes of an individual element. 
+ +Example:: + + >>> torch.tensor([]).element_size() + 4 + >>> torch.tensor([], dtype=torch.uint8).element_size() + 1 + +""", +) + +add_docstr_all( + "eq", + r""" +eq(other) -> Tensor + +See :func:`torch.eq` +""", +) + +add_docstr_all( + "eq_", + r""" +eq_(other) -> Tensor + +In-place version of :meth:`~Tensor.eq` +""", +) + +add_docstr_all( + "equal", + r""" +equal(other) -> bool + +See :func:`torch.equal` +""", +) + +add_docstr_all( + "erf", + r""" +erf() -> Tensor + +See :func:`torch.erf` +""", +) + +add_docstr_all( + "erf_", + r""" +erf_() -> Tensor + +In-place version of :meth:`~Tensor.erf` +""", +) + +add_docstr_all( + "erfc", + r""" +erfc() -> Tensor + +See :func:`torch.erfc` +""", +) + +add_docstr_all( + "erfc_", + r""" +erfc_() -> Tensor + +In-place version of :meth:`~Tensor.erfc` +""", +) + +add_docstr_all( + "erfinv", + r""" +erfinv() -> Tensor + +See :func:`torch.erfinv` +""", +) + +add_docstr_all( + "erfinv_", + r""" +erfinv_() -> Tensor + +In-place version of :meth:`~Tensor.erfinv` +""", +) + +add_docstr_all( + "exp", + r""" +exp() -> Tensor + +See :func:`torch.exp` +""", +) + +add_docstr_all( + "exp_", + r""" +exp_() -> Tensor + +In-place version of :meth:`~Tensor.exp` +""", +) + +add_docstr_all( + "exp2", + r""" +exp2() -> Tensor + +See :func:`torch.exp2` +""", +) + +add_docstr_all( + "exp2_", + r""" +exp2_() -> Tensor + +In-place version of :meth:`~Tensor.exp2` +""", +) + +add_docstr_all( + "expm1", + r""" +expm1() -> Tensor + +See :func:`torch.expm1` +""", +) + +add_docstr_all( + "expm1_", + r""" +expm1_() -> Tensor + +In-place version of :meth:`~Tensor.expm1` +""", +) + +add_docstr_all( + "exponential_", + r""" +exponential_(lambd=1, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with elements drawn from the PDF (probability density function): + +.. math:: + + f(x) = \lambda e^{-\lambda x}, x > 0 + +.. note:: + In probability theory, exponential distribution is supported on interval [0, :math:`\inf`) (i.e., :math:`x >= 0`) + implying that zero can be sampled from the exponential distribution. + However, :func:`torch.Tensor.exponential_` does not sample zero, + which means that its actual support is the interval (0, :math:`\inf`). + + Note that :func:`torch.distributions.exponential.Exponential` is supported on the interval [0, :math:`\inf`) and can sample zero. +""", +) + +add_docstr_all( + "fill_", + r""" +fill_(value) -> Tensor + +Fills :attr:`self` tensor with the specified value. 
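+
+Example (illustrative; any value convertible to the tensor's ``dtype`` works)::
+
+    >>> torch.zeros(2, 3).fill_(5)
+    tensor([[5., 5., 5.],
+            [5., 5., 5.]])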
+""", +) + +add_docstr_all( + "floor", + r""" +floor() -> Tensor + +See :func:`torch.floor` +""", +) + +add_docstr_all( + "flip", + r""" +flip(dims) -> Tensor + +See :func:`torch.flip` +""", +) + +add_docstr_all( + "fliplr", + r""" +fliplr() -> Tensor + +See :func:`torch.fliplr` +""", +) + +add_docstr_all( + "flipud", + r""" +flipud() -> Tensor + +See :func:`torch.flipud` +""", +) + +add_docstr_all( + "roll", + r""" +roll(shifts, dims) -> Tensor + +See :func:`torch.roll` +""", +) + +add_docstr_all( + "floor_", + r""" +floor_() -> Tensor + +In-place version of :meth:`~Tensor.floor` +""", +) + +add_docstr_all( + "fmod", + r""" +fmod(divisor) -> Tensor + +See :func:`torch.fmod` +""", +) + +add_docstr_all( + "fmod_", + r""" +fmod_(divisor) -> Tensor + +In-place version of :meth:`~Tensor.fmod` +""", +) + +add_docstr_all( + "frac", + r""" +frac() -> Tensor + +See :func:`torch.frac` +""", +) + +add_docstr_all( + "frac_", + r""" +frac_() -> Tensor + +In-place version of :meth:`~Tensor.frac` +""", +) + +add_docstr_all( + "frexp", + r""" +frexp(input) -> (Tensor mantissa, Tensor exponent) + +See :func:`torch.frexp` +""", +) + +add_docstr_all( + "flatten", + r""" +flatten(start_dim=0, end_dim=-1) -> Tensor + +See :func:`torch.flatten` +""", +) + +add_docstr_all( + "gather", + r""" +gather(dim, index) -> Tensor + +See :func:`torch.gather` +""", +) + +add_docstr_all( + "gcd", + r""" +gcd(other) -> Tensor + +See :func:`torch.gcd` +""", +) + +add_docstr_all( + "gcd_", + r""" +gcd_(other) -> Tensor + +In-place version of :meth:`~Tensor.gcd` +""", +) + +add_docstr_all( + "ge", + r""" +ge(other) -> Tensor + +See :func:`torch.ge`. +""", +) + +add_docstr_all( + "ge_", + r""" +ge_(other) -> Tensor + +In-place version of :meth:`~Tensor.ge`. +""", +) + +add_docstr_all( + "greater_equal", + r""" +greater_equal(other) -> Tensor + +See :func:`torch.greater_equal`. +""", +) + +add_docstr_all( + "greater_equal_", + r""" +greater_equal_(other) -> Tensor + +In-place version of :meth:`~Tensor.greater_equal`. +""", +) + +add_docstr_all( + "geometric_", + r""" +geometric_(p, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with elements drawn from the geometric distribution: + +.. math:: + + P(X=k) = (1 - p)^{k - 1} p, k = 1, 2, ... + +.. note:: + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`, whereas + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`. +""", +) + +add_docstr_all( + "geqrf", + r""" +geqrf() -> (Tensor, Tensor) + +See :func:`torch.geqrf` +""", +) + +add_docstr_all( + "ger", + r""" +ger(vec2) -> Tensor + +See :func:`torch.ger` +""", +) + +add_docstr_all( + "inner", + r""" +inner(other) -> Tensor + +See :func:`torch.inner`. +""", +) + +add_docstr_all( + "outer", + r""" +outer(vec2) -> Tensor + +See :func:`torch.outer`. 
+""", +) + +add_docstr_all( + "hypot", + r""" +hypot(other) -> Tensor + +See :func:`torch.hypot` +""", +) + +add_docstr_all( + "hypot_", + r""" +hypot_(other) -> Tensor + +In-place version of :meth:`~Tensor.hypot` +""", +) + +add_docstr_all( + "i0", + r""" +i0() -> Tensor + +See :func:`torch.i0` +""", +) + +add_docstr_all( + "i0_", + r""" +i0_() -> Tensor + +In-place version of :meth:`~Tensor.i0` +""", +) + +add_docstr_all( + "igamma", + r""" +igamma(other) -> Tensor + +See :func:`torch.igamma` +""", +) + +add_docstr_all( + "igamma_", + r""" +igamma_(other) -> Tensor + +In-place version of :meth:`~Tensor.igamma` +""", +) + +add_docstr_all( + "igammac", + r""" +igammac(other) -> Tensor +See :func:`torch.igammac` +""", +) + +add_docstr_all( + "igammac_", + r""" +igammac_(other) -> Tensor +In-place version of :meth:`~Tensor.igammac` +""", +) + +add_docstr_all( + "indices", + r""" +indices() -> Tensor + +Return the indices tensor of a :ref:`sparse COO tensor `. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See also :meth:`Tensor.values`. + +.. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. +""", +) + +add_docstr_all( + "get_device", + r""" +get_device() -> Device ordinal (Integer) + +For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. +For CPU tensors, this function returns `-1`. + +Example:: + + >>> x = torch.randn(3, 4, 5, device='cuda:0') + >>> x.get_device() + 0 + >>> x.cpu().get_device() + -1 +""", +) + +add_docstr_all( + "values", + r""" +values() -> Tensor + +Return the values tensor of a :ref:`sparse COO tensor `. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See also :meth:`Tensor.indices`. + +.. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. +""", +) + +add_docstr_all( + "gt", + r""" +gt(other) -> Tensor + +See :func:`torch.gt`. +""", +) + +add_docstr_all( + "gt_", + r""" +gt_(other) -> Tensor + +In-place version of :meth:`~Tensor.gt`. +""", +) + +add_docstr_all( + "greater", + r""" +greater(other) -> Tensor + +See :func:`torch.greater`. +""", +) + +add_docstr_all( + "greater_", + r""" +greater_(other) -> Tensor + +In-place version of :meth:`~Tensor.greater`. +""", +) + +add_docstr_all( + "has_names", + r""" +Is ``True`` if any of this tensor's dimensions are named. Otherwise, is ``False``. +""", +) + +add_docstr_all( + "hardshrink", + r""" +hardshrink(lambd=0.5) -> Tensor + +See :func:`torch.nn.functional.hardshrink` +""", +) + +add_docstr_all( + "heaviside", + r""" +heaviside(values) -> Tensor + +See :func:`torch.heaviside` +""", +) + +add_docstr_all( + "heaviside_", + r""" +heaviside_(values) -> Tensor + +In-place version of :meth:`~Tensor.heaviside` +""", +) + +add_docstr_all( + "histc", + r""" +histc(bins=100, min=0, max=0) -> Tensor + +See :func:`torch.histc` +""", +) + +add_docstr_all( + "histogram", + r""" +histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) + +See :func:`torch.histogram` +""", +) + +add_docstr_all( + "index_add_", + r""" +index_add_(dim, index, source, *, alpha=1) -> Tensor + +Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self` +tensor by adding to the indices in the order given in :attr:`index`. For example, +if ``dim == 0``, ``index[i] == j``, and ``alpha=-1``, then the ``i``\ th row of +``source`` is subtracted from the ``j``\ th row of :attr:`self`. 
+ +The :attr:`dim`\ th dimension of ``source`` must have the same size as the +length of :attr:`index` (which must be a vector), and all other dimensions must +match :attr:`self`, or an error will be raised. + +For a 3-D tensor the output is given as:: + + self[index[i], :, :] += alpha * src[i, :, :] # if dim == 0 + self[:, index[i], :] += alpha * src[:, i, :] # if dim == 1 + self[:, :, index[i]] += alpha * src[:, :, i] # if dim == 2 + +Note: + {forward_reproducibility_note} + +Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (Tensor): the tensor containing values to add + +Keyword args: + alpha (Number): the scalar multiplier for ``source`` + +Example:: + + >>> x = torch.ones(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_add_(0, index, t) + tensor([[ 2., 3., 4.], + [ 1., 1., 1.], + [ 8., 9., 10.], + [ 1., 1., 1.], + [ 5., 6., 7.]]) + >>> x.index_add_(0, index, t, alpha=-1) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) +""".format(**reproducibility_notes), +) + +add_docstr_all( + "index_copy_", + r""" +index_copy_(dim, index, tensor) -> Tensor + +Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting +the indices in the order given in :attr:`index`. For example, if ``dim == 0`` +and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the +``j``\ th row of :attr:`self`. + +The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the +length of :attr:`index` (which must be a vector), and all other dimensions must +match :attr:`self`, or an error will be raised. + +.. note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. + +Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`tensor` to select from + tensor (Tensor): the tensor containing values to copy + +Example:: + + >>> x = torch.zeros(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_copy_(0, index, t) + tensor([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 7., 8., 9.], + [ 0., 0., 0.], + [ 4., 5., 6.]]) +""", +) + +add_docstr_all( + "index_fill_", + r""" +index_fill_(dim, index, value) -> Tensor + +Fills the elements of the :attr:`self` tensor with value :attr:`value` by +selecting the indices in the order given in :attr:`index`. + +Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + +Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) +""", +) + +add_docstr_all( + "index_put_", + r""" +index_put_(indices, values, accumulate=False) -> Tensor + +Puts values from the tensor :attr:`values` into the tensor :attr:`self` using +the indices specified in :attr:`indices` (which is a tuple of Tensors). The +expression ``tensor.index_put_(indices, values)`` is equivalent to +``tensor[indices] = values``. Returns :attr:`self`. 
+ +If :attr:`accumulate` is ``True``, the elements in :attr:`values` are added to +:attr:`self`. If accumulate is ``False``, the behavior is undefined if indices +contain duplicate elements. + +Args: + indices (tuple of LongTensor): tensors used to index into `self`. + values (Tensor): tensor of same dtype as `self`. + accumulate (bool): whether to accumulate into self +""", +) + +add_docstr_all( + "index_put", + r""" +index_put(indices, values, accumulate=False) -> Tensor + +Out-place version of :meth:`~Tensor.index_put_`. +""", +) + +add_docstr_all( + "index_reduce_", + r""" +index_reduce_(dim, index, source, reduce, *, include_self=True) -> Tensor + +Accumulate the elements of ``source`` into the :attr:`self` +tensor by accumulating to the indices in the order given in :attr:`index` +using the reduction given by the ``reduce`` argument. For example, if ``dim == 0``, +``index[i] == j``, ``reduce == prod`` and ``include_self == True`` then the ``i``\ th +row of ``source`` is multiplied by the ``j``\ th row of :attr:`self`. If +:obj:`include_self="True"`, the values in the :attr:`self` tensor are included +in the reduction, otherwise, rows in the :attr:`self` tensor that are accumulated +to are treated as if they were filled with the reduction identites. + +The :attr:`dim`\ th dimension of ``source`` must have the same size as the +length of :attr:`index` (which must be a vector), and all other dimensions must +match :attr:`self`, or an error will be raised. + +For a 3-D tensor with :obj:`reduce="prod"` and :obj:`include_self=True` the +output is given as:: + + self[index[i], :, :] *= src[i, :, :] # if dim == 0 + self[:, index[i], :] *= src[:, i, :] # if dim == 1 + self[:, :, index[i]] *= src[:, :, i] # if dim == 2 + +Note: + {forward_reproducibility_note} + +.. note:: + + This function only supports floating point tensors. + +.. warning:: + + This function is in beta and may change in the near future. + +Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (FloatTensor): the tensor containing values to accumulate + reduce (str): the reduction operation to apply + (:obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + +Keyword args: + include_self (bool): whether the elements from the ``self`` tensor are + included in the reduction + +Example:: + + >>> x = torch.empty(5, 3).fill_(2) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2, 0]) + >>> x.index_reduce_(0, index, t, 'prod') + tensor([[20., 44., 72.], + [ 2., 2., 2.], + [14., 16., 18.], + [ 2., 2., 2.], + [ 8., 10., 12.]]) + >>> x = torch.empty(5, 3).fill_(2) + >>> x.index_reduce_(0, index, t, 'prod', include_self=False) + tensor([[10., 22., 36.], + [ 2., 2., 2.], + [ 7., 8., 9.], + [ 2., 2., 2.], + [ 4., 5., 6.]]) +""".format(**reproducibility_notes), +) + +add_docstr_all( + "index_select", + r""" +index_select(dim, index) -> Tensor + +See :func:`torch.index_select` +""", +) + +add_docstr_all( + "sparse_mask", + r""" +sparse_mask(mask) -> Tensor + +Returns a new :ref:`sparse tensor ` with values from a +strided tensor :attr:`self` filtered by the indices of the sparse +tensor :attr:`mask`. The values of :attr:`mask` sparse tensor are +ignored. :attr:`self` and :attr:`mask` tensors must have the same +shape. + +.. note:: + + The returned sparse tensor might contain duplicate values if :attr:`mask` + is not coalesced. 
It is therefore advisable to pass ``mask.coalesce()`` + if such behavior is not desired. + +.. note:: + + The returned sparse tensor has the same indices as the sparse tensor + :attr:`mask`, even when the corresponding values in :attr:`self` are + zeros. + +Args: + mask (Tensor): a sparse tensor whose indices are used as a filter + +Example:: + + >>> nse = 5 + >>> dims = (5, 5, 2, 2) + >>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)), + ... torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse) + >>> V = torch.randn(nse, dims[2], dims[3]) + >>> S = torch.sparse_coo_tensor(I, V, dims).coalesce() + >>> D = torch.randn(dims) + >>> D.sparse_mask(S) + tensor(indices=tensor([[0, 0, 0, 2], + [0, 1, 4, 3]]), + values=tensor([[[ 1.6550, 0.2397], + [-0.1611, -0.0779]], + + [[ 0.2326, -1.0558], + [ 1.4711, 1.9678]], + + [[-0.5138, -0.0411], + [ 1.9417, 0.5158]], + + [[ 0.0793, 0.0036], + [-0.2569, -0.1055]]]), + size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo) +""", +) + +add_docstr_all( + "inverse", + r""" +inverse() -> Tensor + +See :func:`torch.inverse` +""", +) + +add_docstr_all( + "isnan", + r""" +isnan() -> Tensor + +See :func:`torch.isnan` +""", +) + +add_docstr_all( + "isinf", + r""" +isinf() -> Tensor + +See :func:`torch.isinf` +""", +) + +add_docstr_all( + "isposinf", + r""" +isposinf() -> Tensor + +See :func:`torch.isposinf` +""", +) + +add_docstr_all( + "isneginf", + r""" +isneginf() -> Tensor + +See :func:`torch.isneginf` +""", +) + +add_docstr_all( + "isfinite", + r""" +isfinite() -> Tensor + +See :func:`torch.isfinite` +""", +) + +add_docstr_all( + "isclose", + r""" +isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + +See :func:`torch.isclose` +""", +) + +add_docstr_all( + "isreal", + r""" +isreal() -> Tensor + +See :func:`torch.isreal` +""", +) + +add_docstr_all( + "is_coalesced", + r""" +is_coalesced() -> bool + +Returns ``True`` if :attr:`self` is a :ref:`sparse COO tensor +` that is coalesced, ``False`` otherwise. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See :meth:`coalesce` and :ref:`uncoalesced tensors `. +""", +) + +add_docstr_all( + "is_contiguous", + r""" +is_contiguous(memory_format=torch.contiguous_format) -> bool + +Returns True if :attr:`self` tensor is contiguous in memory in the order specified +by memory format. + +Args: + memory_format (:class:`torch.memory_format`, optional): Specifies memory allocation + order. Default: ``torch.contiguous_format``. +""", +) + +add_docstr_all( + "is_pinned", + r""" +Returns true if this tensor resides in pinned memory. +By default, the device pinned memory on will be the current :ref:`accelerator`. +""", +) + +add_docstr_all( + "is_floating_point", + r""" +is_floating_point() -> bool + +Returns True if the data type of :attr:`self` is a floating point data type. +""", +) + +add_docstr_all( + "is_complex", + r""" +is_complex() -> bool + +Returns True if the data type of :attr:`self` is a complex data type. +""", +) + +add_docstr_all( + "is_inference", + r""" +is_inference() -> bool + +See :func:`torch.is_inference` +""", +) + +add_docstr_all( + "is_conj", + r""" +is_conj() -> bool + +Returns True if the conjugate bit of :attr:`self` is set to true. +""", +) + +add_docstr_all( + "is_neg", + r""" +is_neg() -> bool + +Returns True if the negative bit of :attr:`self` is set to true. +""", +) + +add_docstr_all( + "is_signed", + r""" +is_signed() -> bool + +Returns True if the data type of :attr:`self` is a signed data type. 
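+
+For example, floating point dtypes are signed while ``torch.uint8`` is not::
+
+    >>> torch.tensor([1., 2.]).is_signed()
+    True
+    >>> torch.tensor([1, 2], dtype=torch.uint8).is_signed()
+    False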
+""", +) + +add_docstr_all( + "is_set_to", + r""" +is_set_to(tensor) -> bool + +Returns True if both tensors are pointing to the exact same memory (same +storage, offset, size and stride). +""", +) + +add_docstr_all( + "item", + r""" +item() -> number + +Returns the value of this tensor as a standard Python number. This only works +for tensors with one element. For other cases, see :meth:`~Tensor.tolist`. + +This operation is not differentiable. + +Example:: + + >>> x = torch.tensor([1.0]) + >>> x.item() + 1.0 + +""", +) + +add_docstr_all( + "kron", + r""" +kron(other) -> Tensor + +See :func:`torch.kron` +""", +) + +add_docstr_all( + "kthvalue", + r""" +kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.kthvalue` +""", +) + +add_docstr_all( + "ldexp", + r""" +ldexp(other) -> Tensor + +See :func:`torch.ldexp` +""", +) + +add_docstr_all( + "ldexp_", + r""" +ldexp_(other) -> Tensor + +In-place version of :meth:`~Tensor.ldexp` +""", +) + +add_docstr_all( + "lcm", + r""" +lcm(other) -> Tensor + +See :func:`torch.lcm` +""", +) + +add_docstr_all( + "lcm_", + r""" +lcm_(other) -> Tensor + +In-place version of :meth:`~Tensor.lcm` +""", +) + +add_docstr_all( + "le", + r""" +le(other) -> Tensor + +See :func:`torch.le`. +""", +) + +add_docstr_all( + "le_", + r""" +le_(other) -> Tensor + +In-place version of :meth:`~Tensor.le`. +""", +) + +add_docstr_all( + "less_equal", + r""" +less_equal(other) -> Tensor + +See :func:`torch.less_equal`. +""", +) + +add_docstr_all( + "less_equal_", + r""" +less_equal_(other) -> Tensor + +In-place version of :meth:`~Tensor.less_equal`. +""", +) + +add_docstr_all( + "lerp", + r""" +lerp(end, weight) -> Tensor + +See :func:`torch.lerp` +""", +) + +add_docstr_all( + "lerp_", + r""" +lerp_(end, weight) -> Tensor + +In-place version of :meth:`~Tensor.lerp` +""", +) + +add_docstr_all( + "lgamma", + r""" +lgamma() -> Tensor + +See :func:`torch.lgamma` +""", +) + +add_docstr_all( + "lgamma_", + r""" +lgamma_() -> Tensor + +In-place version of :meth:`~Tensor.lgamma` +""", +) + +add_docstr_all( + "log", + r""" +log() -> Tensor + +See :func:`torch.log` +""", +) + +add_docstr_all( + "log_", + r""" +log_() -> Tensor + +In-place version of :meth:`~Tensor.log` +""", +) + +add_docstr_all( + "log10", + r""" +log10() -> Tensor + +See :func:`torch.log10` +""", +) + +add_docstr_all( + "log10_", + r""" +log10_() -> Tensor + +In-place version of :meth:`~Tensor.log10` +""", +) + +add_docstr_all( + "log1p", + r""" +log1p() -> Tensor + +See :func:`torch.log1p` +""", +) + +add_docstr_all( + "log1p_", + r""" +log1p_() -> Tensor + +In-place version of :meth:`~Tensor.log1p` +""", +) + +add_docstr_all( + "log2", + r""" +log2() -> Tensor + +See :func:`torch.log2` +""", +) + +add_docstr_all( + "log2_", + r""" +log2_() -> Tensor + +In-place version of :meth:`~Tensor.log2` +""", +) + +add_docstr_all( + "logaddexp", + r""" +logaddexp(other) -> Tensor + +See :func:`torch.logaddexp` +""", +) + +add_docstr_all( + "logaddexp2", + r""" +logaddexp2(other) -> Tensor + +See :func:`torch.logaddexp2` +""", +) + +add_docstr_all( + "log_normal_", + r""" +log_normal_(mean=1, std=2, *, generator=None) + +Fills :attr:`self` tensor with numbers samples from the log-normal distribution +parameterized by the given mean :math:`\mu` and standard deviation +:math:`\sigma`. Note that :attr:`mean` and :attr:`std` are the mean and +standard deviation of the underlying normal distribution, and not of the +returned distribution: + +.. 
math:: + + f(x) = \dfrac{1}{x \sigma \sqrt{2\pi}}\ e^{-\frac{(\ln x - \mu)^2}{2\sigma^2}} +""", +) + +add_docstr_all( + "logsumexp", + r""" +logsumexp(dim, keepdim=False) -> Tensor + +See :func:`torch.logsumexp` +""", +) + +add_docstr_all( + "lt", + r""" +lt(other) -> Tensor + +See :func:`torch.lt`. +""", +) + +add_docstr_all( + "lt_", + r""" +lt_(other) -> Tensor + +In-place version of :meth:`~Tensor.lt`. +""", +) + +add_docstr_all( + "less", + r""" +lt(other) -> Tensor + +See :func:`torch.less`. +""", +) + +add_docstr_all( + "less_", + r""" +less_(other) -> Tensor + +In-place version of :meth:`~Tensor.less`. +""", +) + +add_docstr_all( + "lu_solve", + r""" +lu_solve(LU_data, LU_pivots) -> Tensor + +See :func:`torch.lu_solve` +""", +) + +add_docstr_all( + "map_", + r""" +map_(tensor, callable) + +Applies :attr:`callable` for each element in :attr:`self` tensor and the given +:attr:`tensor` and stores the results in :attr:`self` tensor. :attr:`self` tensor and +the given :attr:`tensor` must be :ref:`broadcastable `. + +The :attr:`callable` should have the signature:: + + def callable(a, b) -> number +""", +) + +add_docstr_all( + "masked_scatter_", + r""" +masked_scatter_(mask, source) + +Copies elements from :attr:`source` into :attr:`self` tensor at positions where +the :attr:`mask` is True. Elements from :attr:`source` are copied into :attr:`self` +starting at position 0 of :attr:`source` and continuing in order one-by-one for each +occurrence of :attr:`mask` being True. +The shape of :attr:`mask` must be :ref:`broadcastable ` +with the shape of the underlying tensor. The :attr:`source` should have at least +as many elements as the number of ones in :attr:`mask`. + +Args: + mask (BoolTensor): the boolean mask + source (Tensor): the tensor to copy from + +.. note:: + + The :attr:`mask` operates on the :attr:`self` tensor, not on the given + :attr:`source` tensor. + +Example: + + >>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) + >>> mask = torch.tensor( + ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], + ... dtype=torch.bool, + ... ) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter_(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + +""", +) + +add_docstr_all( + "masked_fill_", + r""" +masked_fill_(mask, value) + +Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is +True. The shape of :attr:`mask` must be +:ref:`broadcastable ` with the shape of the underlying +tensor. + +Args: + mask (BoolTensor): the boolean mask + value (float): the value to fill in with +""", +) + +add_docstr_all( + "masked_select", + r""" +masked_select(mask) -> Tensor + +See :func:`torch.masked_select` +""", +) + +add_docstr_all( + "matrix_power", + r""" +matrix_power(n) -> Tensor + +.. note:: :meth:`~Tensor.matrix_power` is deprecated, use :func:`torch.linalg.matrix_power` instead. 
+ +Alias for :func:`torch.linalg.matrix_power` +""", +) + +add_docstr_all( + "matrix_exp", + r""" +matrix_exp() -> Tensor + +See :func:`torch.matrix_exp` +""", +) + +add_docstr_all( + "max", + r""" +max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + +See :func:`torch.max` +""", +) + +add_docstr_all( + "amax", + r""" +amax(dim=None, keepdim=False) -> Tensor + +See :func:`torch.amax` +""", +) + +add_docstr_all( + "maximum", + r""" +maximum(other) -> Tensor + +See :func:`torch.maximum` +""", +) + +add_docstr_all( + "fmax", + r""" +fmax(other) -> Tensor + +See :func:`torch.fmax` +""", +) + +add_docstr_all( + "argmax", + r""" +argmax(dim=None, keepdim=False) -> LongTensor + +See :func:`torch.argmax` +""", +) + +add_docstr_all( + "argwhere", + r""" +argwhere() -> Tensor + +See :func:`torch.argwhere` +""", +) + +add_docstr_all( + "mean", + r""" +mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + +See :func:`torch.mean` +""", +) + +add_docstr_all( + "nanmean", + r""" +nanmean(dim=None, keepdim=False, *, dtype=None) -> Tensor + +See :func:`torch.nanmean` +""", +) + +add_docstr_all( + "median", + r""" +median(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.median` +""", +) + +add_docstr_all( + "nanmedian", + r""" +nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.nanmedian` +""", +) + +add_docstr_all( + "min", + r""" +min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + +See :func:`torch.min` +""", +) + +add_docstr_all( + "amin", + r""" +amin(dim=None, keepdim=False) -> Tensor + +See :func:`torch.amin` +""", +) + +add_docstr_all( + "minimum", + r""" +minimum(other) -> Tensor + +See :func:`torch.minimum` +""", +) + +add_docstr_all( + "aminmax", + r""" +aminmax(*, dim=None, keepdim=False) -> (Tensor min, Tensor max) + +See :func:`torch.aminmax` +""", +) + +add_docstr_all( + "fmin", + r""" +fmin(other) -> Tensor + +See :func:`torch.fmin` +""", +) + +add_docstr_all( + "argmin", + r""" +argmin(dim=None, keepdim=False) -> LongTensor + +See :func:`torch.argmin` +""", +) + +add_docstr_all( + "mm", + r""" +mm(mat2) -> Tensor + +See :func:`torch.mm` +""", +) + +add_docstr_all( + "mode", + r""" +mode(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.mode` +""", +) + +add_docstr_all( + "movedim", + r""" +movedim(source, destination) -> Tensor + +See :func:`torch.movedim` +""", +) + +add_docstr_all( + "moveaxis", + r""" +moveaxis(source, destination) -> Tensor + +See :func:`torch.moveaxis` +""", +) + +add_docstr_all( + "mul", + r""" +mul(value) -> Tensor + +See :func:`torch.mul`. +""", +) + +add_docstr_all( + "mul_", + r""" +mul_(value) -> Tensor + +In-place version of :meth:`~Tensor.mul`. +""", +) + +add_docstr_all( + "multiply", + r""" +multiply(value) -> Tensor + +See :func:`torch.multiply`. +""", +) + +add_docstr_all( + "multiply_", + r""" +multiply_(value) -> Tensor + +In-place version of :meth:`~Tensor.multiply`. +""", +) + +add_docstr_all( + "multinomial", + r""" +multinomial(num_samples, replacement=False, *, generator=None) -> Tensor + +See :func:`torch.multinomial` +""", +) + +add_docstr_all( + "mv", + r""" +mv(vec) -> Tensor + +See :func:`torch.mv` +""", +) + +add_docstr_all( + "mvlgamma", + r""" +mvlgamma(p) -> Tensor + +See :func:`torch.mvlgamma` +""", +) + +add_docstr_all( + "mvlgamma_", + r""" +mvlgamma_(p) -> Tensor + +In-place version of :meth:`~Tensor.mvlgamma` +""", +) + +add_docstr_all( + "narrow", + r""" +narrow(dimension, start, length) -> Tensor + +See :func:`torch.narrow`. 
+""", +) + +add_docstr_all( + "narrow_copy", + r""" +narrow_copy(dimension, start, length) -> Tensor + +See :func:`torch.narrow_copy`. +""", +) + +add_docstr_all( + "ndimension", + r""" +ndimension() -> int + +Alias for :meth:`~Tensor.dim()` +""", +) + +add_docstr_all( + "nan_to_num", + r""" +nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor + +See :func:`torch.nan_to_num`. +""", +) + +add_docstr_all( + "nan_to_num_", + r""" +nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor + +In-place version of :meth:`~Tensor.nan_to_num`. +""", +) + +add_docstr_all( + "ne", + r""" +ne(other) -> Tensor + +See :func:`torch.ne`. +""", +) + +add_docstr_all( + "ne_", + r""" +ne_(other) -> Tensor + +In-place version of :meth:`~Tensor.ne`. +""", +) + +add_docstr_all( + "not_equal", + r""" +not_equal(other) -> Tensor + +See :func:`torch.not_equal`. +""", +) + +add_docstr_all( + "not_equal_", + r""" +not_equal_(other) -> Tensor + +In-place version of :meth:`~Tensor.not_equal`. +""", +) + +add_docstr_all( + "neg", + r""" +neg() -> Tensor + +See :func:`torch.neg` +""", +) + +add_docstr_all( + "negative", + r""" +negative() -> Tensor + +See :func:`torch.negative` +""", +) + +add_docstr_all( + "neg_", + r""" +neg_() -> Tensor + +In-place version of :meth:`~Tensor.neg` +""", +) + +add_docstr_all( + "negative_", + r""" +negative_() -> Tensor + +In-place version of :meth:`~Tensor.negative` +""", +) + +add_docstr_all( + "nelement", + r""" +nelement() -> int + +Alias for :meth:`~Tensor.numel` +""", +) + +add_docstr_all( + "nextafter", + r""" +nextafter(other) -> Tensor +See :func:`torch.nextafter` +""", +) + +add_docstr_all( + "nextafter_", + r""" +nextafter_(other) -> Tensor +In-place version of :meth:`~Tensor.nextafter` +""", +) + +add_docstr_all( + "nonzero", + r""" +nonzero() -> LongTensor + +See :func:`torch.nonzero` +""", +) + +add_docstr_all( + "nonzero_static", + r""" +nonzero_static(input, *, size, fill_value=-1) -> Tensor + +Returns a 2-D tensor where each row is the index for a non-zero value. +The returned Tensor has the same `torch.dtype` as `torch.nonzero()`. + +Args: + input (Tensor): the input tensor to count non-zero elements. + +Keyword args: + size (int): the size of non-zero elements expected to be included in the out + tensor. Pad the out tensor with `fill_value` if the `size` is larger + than total number of non-zero elements, truncate out tensor if `size` + is smaller. The size must be a non-negative integer. + fill_value (int): the value to fill the output tensor with when `size` is larger + than the total number of non-zero elements. Default is `-1` to represent + invalid index. 
+ +Example: + + # Example 1: Padding + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 4 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([[ 0, 0], + [ 1, 0], + [ 1, 1], + [ -1, -1]], dtype=torch.int64) + + # Example 2: Truncating + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([[ 0, 0], + [ 1, 0]], dtype=torch.int64) + + # Example 3: 0 size + >>> input_tensor = torch.tensor([10]) + >>> static_size = 0 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([], size=(0, 1), dtype=torch.int64) + + # Example 4: 0 rank input + >>> input_tensor = torch.tensor(10) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([], size=(2, 0), dtype=torch.int64) +""", +) + +add_docstr_all( + "norm", + r""" +norm(p=2, dim=None, keepdim=False) -> Tensor + +See :func:`torch.norm` +""", +) + +add_docstr_all( + "normal_", + r""" +normal_(mean=0, std=1, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with elements samples from the normal distribution +parameterized by :attr:`mean` and :attr:`std`. +""", +) + +add_docstr_all( + "numel", + r""" +numel() -> int + +See :func:`torch.numel` +""", +) + +add_docstr_all( + "numpy", + r""" +numpy(*, force=False) -> numpy.ndarray + +Returns the tensor as a NumPy :class:`ndarray`. + +If :attr:`force` is ``False`` (the default), the conversion +is performed only if the tensor is on the CPU, does not require grad, +does not have its conjugate bit set, and is a dtype and layout that +NumPy supports. The returned ndarray and the tensor will share their +storage, so changes to the tensor will be reflected in the ndarray +and vice versa. + +If :attr:`force` is ``True`` this is equivalent to +calling ``t.detach().cpu().resolve_conj().resolve_neg().numpy()``. +If the tensor isn't on the CPU or the conjugate or negative bit is set, +the tensor won't share its storage with the returned ndarray. +Setting :attr:`force` to ``True`` can be a useful shorthand. + +Args: + force (bool): if ``True``, the ndarray may be a copy of the tensor + instead of always sharing memory, defaults to ``False``. 
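+
+Example (a brief sketch of the default zero-copy behavior on CPU)::
+
+    >>> t = torch.tensor([1.0, 2.0, 3.0])
+    >>> a = t.numpy()
+    >>> a[0] = 10.0   # the tensor sees the change because storage is shared
+    >>> t
+    tensor([10.,  2.,  3.])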
+""", +) + +add_docstr_all( + "orgqr", + r""" +orgqr(input2) -> Tensor + +See :func:`torch.orgqr` +""", +) + +add_docstr_all( + "ormqr", + r""" +ormqr(input2, input3, left=True, transpose=False) -> Tensor + +See :func:`torch.ormqr` +""", +) + +add_docstr_all( + "permute", + r""" +permute(*dims) -> Tensor + +See :func:`torch.permute` +""", +) + +add_docstr_all( + "polygamma", + r""" +polygamma(n) -> Tensor + +See :func:`torch.polygamma` +""", +) + +add_docstr_all( + "polygamma_", + r""" +polygamma_(n) -> Tensor + +In-place version of :meth:`~Tensor.polygamma` +""", +) + +add_docstr_all( + "positive", + r""" +positive() -> Tensor + +See :func:`torch.positive` +""", +) + +add_docstr_all( + "pow", + r""" +pow(exponent) -> Tensor + +See :func:`torch.pow` +""", +) + +add_docstr_all( + "pow_", + r""" +pow_(exponent) -> Tensor + +In-place version of :meth:`~Tensor.pow` +""", +) + +add_docstr_all( + "float_power", + r""" +float_power(exponent) -> Tensor + +See :func:`torch.float_power` +""", +) + +add_docstr_all( + "float_power_", + r""" +float_power_(exponent) -> Tensor + +In-place version of :meth:`~Tensor.float_power` +""", +) + +add_docstr_all( + "prod", + r""" +prod(dim=None, keepdim=False, dtype=None) -> Tensor + +See :func:`torch.prod` +""", +) + +add_docstr_all( + "put_", + r""" +put_(index, source, accumulate=False) -> Tensor + +Copies the elements from :attr:`source` into the positions specified by +:attr:`index`. For the purpose of indexing, the :attr:`self` tensor is treated as if +it were a 1-D tensor. + +:attr:`index` and :attr:`source` need to have the same number of elements, but not necessarily +the same shape. + +If :attr:`accumulate` is ``True``, the elements in :attr:`source` are added to +:attr:`self`. If accumulate is ``False``, the behavior is undefined if :attr:`index` +contain duplicate elements. + +Args: + index (LongTensor): the indices into self + source (Tensor): the tensor containing values to copy from + accumulate (bool): whether to accumulate into self + +Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> src.put_(torch.tensor([1, 3]), torch.tensor([9, 10])) + tensor([[ 4, 9, 5], + [ 10, 7, 8]]) +""", +) + +add_docstr_all( + "put", + r""" +put(input, index, source, accumulate=False) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.put_`. +`input` corresponds to `self` in :meth:`torch.Tensor.put_`. +""", +) + +add_docstr_all( + "qr", + r""" +qr(some=True) -> (Tensor, Tensor) + +See :func:`torch.qr` +""", +) + +add_docstr_all( + "qscheme", + r""" +qscheme() -> torch.qscheme + +Returns the quantization scheme of a given QTensor. +""", +) + +add_docstr_all( + "quantile", + r""" +quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + +See :func:`torch.quantile` +""", +) + +add_docstr_all( + "nanquantile", + r""" +nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + +See :func:`torch.nanquantile` +""", +) + +add_docstr_all( + "q_scale", + r""" +q_scale() -> float + +Given a Tensor quantized by linear(affine) quantization, +returns the scale of the underlying quantizer(). +""", +) + +add_docstr_all( + "q_zero_point", + r""" +q_zero_point() -> int + +Given a Tensor quantized by linear(affine) quantization, +returns the zero_point of the underlying quantizer(). +""", +) + +add_docstr_all( + "q_per_channel_scales", + r""" +q_per_channel_scales() -> Tensor + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns a Tensor of scales of the underlying quantizer. 
It has the number of +elements that matches the corresponding dimensions (from q_per_channel_axis) of +the tensor. +""", +) + +add_docstr_all( + "q_per_channel_zero_points", + r""" +q_per_channel_zero_points() -> Tensor + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns a tensor of zero_points of the underlying quantizer. It has the number of +elements that matches the corresponding dimensions (from q_per_channel_axis) of +the tensor. +""", +) + +add_docstr_all( + "q_per_channel_axis", + r""" +q_per_channel_axis() -> int + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns the index of dimension on which per-channel quantization is applied. +""", +) + +add_docstr_all( + "random_", + r""" +random_(from=0, to=None, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with numbers sampled from the discrete uniform +distribution over ``[from, to - 1]``. If not specified, the values are usually +only bounded by :attr:`self` tensor's data type. However, for floating point +types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every +value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` +will be uniform in ``[0, 2^53]``. +""", +) + +add_docstr_all( + "rad2deg", + r""" +rad2deg() -> Tensor + +See :func:`torch.rad2deg` +""", +) + +add_docstr_all( + "rad2deg_", + r""" +rad2deg_() -> Tensor + +In-place version of :meth:`~Tensor.rad2deg` +""", +) + +add_docstr_all( + "deg2rad", + r""" +deg2rad() -> Tensor + +See :func:`torch.deg2rad` +""", +) + +add_docstr_all( + "deg2rad_", + r""" +deg2rad_() -> Tensor + +In-place version of :meth:`~Tensor.deg2rad` +""", +) + +add_docstr_all( + "ravel", + r""" +ravel() -> Tensor + +see :func:`torch.ravel` +""", +) + +add_docstr_all( + "reciprocal", + r""" +reciprocal() -> Tensor + +See :func:`torch.reciprocal` +""", +) + +add_docstr_all( + "reciprocal_", + r""" +reciprocal_() -> Tensor + +In-place version of :meth:`~Tensor.reciprocal` +""", +) + +add_docstr_all( + "record_stream", + r""" +record_stream(stream) + +Marks the tensor as having been used by this stream. When the tensor +is deallocated, ensure the tensor memory is not reused for another tensor +until all work queued on :attr:`stream` at the time of deallocation is +complete. + +.. note:: + + The caching allocator is aware of only the stream where a tensor was + allocated. Due to the awareness, it already correctly manages the life + cycle of tensors on only one stream. But if a tensor is used on a stream + different from the stream of origin, the allocator might reuse the memory + unexpectedly. Calling this method lets the allocator know which streams + have used the tensor. + +.. warning:: + + This method is most suitable for use cases where you are providing a + function that created a tensor on a side stream, and want users to be able + to make use of the tensor without having to think carefully about stream + safety when making use of them. These safety guarantees come at some + performance and predictability cost (analogous to the tradeoff between GC + and manual memory management), so if you are in a situation where + you manage the full lifetime of your tensors, you may consider instead + manually managing CUDA events so that calling this method is not necessary. 
+ In particular, when you call this method, on later allocations the + allocator will poll the recorded stream to see if all operations have + completed yet; you can potentially race with side stream computation and + non-deterministically reuse or fail to reuse memory for an allocation. + + You can safely use tensors allocated on side streams without + :meth:`~Tensor.record_stream`; you must manually ensure that + any non-creation stream uses of a tensor are synced back to the creation + stream before you deallocate the tensor. As the CUDA caching allocator + guarantees that the memory will only be reused with the same creation stream, + this is sufficient to ensure that writes to future reallocations of the + memory will be delayed until non-creation stream uses are done. + (Counterintuitively, you may observe that on the CPU side we have already + reallocated the tensor, even though CUDA kernels on the old tensor are + still in progress. This is fine, because CUDA operations on the new + tensor will appropriately wait for the old operations to complete, as they + are all on the same stream.) + + Concretely, this looks like this:: + + with torch.cuda.stream(s0): + x = torch.zeros(N) + + s1.wait_stream(s0) + with torch.cuda.stream(s1): + y = some_comm_op(x) + + ... some compute on s0 ... + + # synchronize creation stream s0 to side stream s1 + # before deallocating x + s0.wait_stream(s1) + del x + + Note that some discretion is required when deciding when to perform + ``s0.wait_stream(s1)``. In particular, if we were to wait immediately + after ``some_comm_op``, there wouldn't be any point in having the side + stream; it would be equivalent to have run ``some_comm_op`` on ``s0``. + Instead, the synchronization must be placed at some appropriate, later + point in time where you expect the side stream ``s1`` to have finished + work. This location is typically identified via profiling, e.g., using + Chrome traces produced + :meth:`torch.autograd.profiler.profile.export_chrome_trace`. If you + place the wait too early, work on s0 will block until ``s1`` has finished, + preventing further overlapping of communication and computation. If you + place the wait too late, you will use more memory than is strictly + necessary (as you are keeping ``x`` live for longer.) For a concrete + example of how this guidance can be applied in practice, see this post: + `FSDP and CUDACachingAllocator + `_. +""", +) + +add_docstr_all( + "remainder", + r""" +remainder(divisor) -> Tensor + +See :func:`torch.remainder` +""", +) + +add_docstr_all( + "remainder_", + r""" +remainder_(divisor) -> Tensor + +In-place version of :meth:`~Tensor.remainder` +""", +) + +add_docstr_all( + "renorm", + r""" +renorm(p, dim, maxnorm) -> Tensor + +See :func:`torch.renorm` +""", +) + +add_docstr_all( + "renorm_", + r""" +renorm_(p, dim, maxnorm) -> Tensor + +In-place version of :meth:`~Tensor.renorm` +""", +) + +add_docstr_all( + "repeat", + r""" +repeat(*repeats) -> Tensor + +Repeats this tensor along the specified dimensions. + +Unlike :meth:`~Tensor.expand`, this function copies the tensor's data. + +.. warning:: + + :meth:`~Tensor.repeat` behaves differently from + `numpy.repeat `_, + but is more similar to + `numpy.tile `_. + For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`. 
+ +Args: + repeat (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensor along each dimension + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat(4, 2) + tensor([[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]]) + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) +""", +) + +add_docstr_all( + "repeat_interleave", + r""" +repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor + +See :func:`torch.repeat_interleave`. +""", +) + +add_docstr_all( + "requires_grad_", + r""" +requires_grad_(requires_grad=True) -> Tensor + +Change if autograd should record operations on this tensor: sets this tensor's +:attr:`requires_grad` attribute in-place. Returns this tensor. + +:func:`requires_grad_`'s main use case is to tell autograd to begin recording +operations on a Tensor ``tensor``. If ``tensor`` has ``requires_grad=False`` +(because it was obtained through a DataLoader, or required preprocessing or +initialization), ``tensor.requires_grad_()`` makes it so that autograd will +begin to record operations on ``tensor``. + +Args: + requires_grad (bool): If autograd should record operations on this tensor. + Default: ``True``. + +Example:: + + >>> # Let's say we want to preprocess some saved weights and use + >>> # the result as new weights. + >>> saved_weights = [0.1, 0.2, 0.3, 0.25] + >>> loaded_weights = torch.tensor(saved_weights) + >>> weights = preprocess(loaded_weights) # some function + >>> weights + tensor([-0.5503, 0.4926, -2.1158, -0.8303]) + + >>> # Now, start to record operations done to weights + >>> weights.requires_grad_() + >>> out = weights.pow(2).sum() + >>> out.backward() + >>> weights.grad + tensor([-1.1007, 0.9853, -4.2316, -1.6606]) + +""", +) + +add_docstr_all( + "reshape", + r""" +reshape(*shape) -> Tensor + +Returns a tensor with the same data and number of elements as :attr:`self` +but with the specified shape. This method returns a view if :attr:`shape` is +compatible with the current shape. See :meth:`torch.Tensor.view` on when it is +possible to return a view. + +See :func:`torch.reshape` + +Args: + shape (tuple of ints or int...): the desired shape + +""", +) + +add_docstr_all( + "reshape_as", + r""" +reshape_as(other) -> Tensor + +Returns this tensor as the same shape as :attr:`other`. +``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``. +This method returns a view if ``other.sizes()`` is compatible with the current +shape. See :meth:`torch.Tensor.view` on when it is possible to return a view. + +Please see :meth:`reshape` for more information about ``reshape``. + +Args: + other (:class:`torch.Tensor`): The result tensor has the same shape + as :attr:`other`. +""", +) + +add_docstr_all( + "resize_", + r""" +resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor + +Resizes :attr:`self` tensor to the specified size. If the number of elements is +larger than the current storage size, then the underlying storage is resized +to fit the new number of elements. If the number of elements is smaller, the +underlying storage is not changed. Existing elements are preserved but any new +memory is uninitialized. + +.. warning:: + + This is a low-level method. The storage is reinterpreted as C-contiguous, + ignoring the current strides (unless the target size equals the current + size, in which case the tensor is left unchanged). 
For most purposes, you + will instead want to use :meth:`~Tensor.view()`, which checks for + contiguity, or :meth:`~Tensor.reshape()`, which copies data if needed. To + change the size in-place with custom strides, see :meth:`~Tensor.set_()`. + +.. note:: + + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, new elements are initialized to prevent nondeterministic behavior + from using the result as an input to an operation. Floating point and + complex values are set to NaN, and integer values are set to the maximum + value. + +Args: + sizes (torch.Size or int...): the desired size + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``sizes``. + +Example:: + + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> x.resize_(2, 2) + tensor([[ 1, 2], + [ 3, 4]]) +""", +) + +add_docstr_all( + "resize_as_", + r""" +resize_as_(tensor, memory_format=torch.contiguous_format) -> Tensor + +Resizes the :attr:`self` tensor to be the same size as the specified +:attr:`tensor`. This is equivalent to ``self.resize_(tensor.size())``. + +Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``tensor.size()``. + +""", +) + +add_docstr_all( + "rot90", + r""" +rot90(k, dims) -> Tensor + +See :func:`torch.rot90` +""", +) + +add_docstr_all( + "round", + r""" +round(decimals=0) -> Tensor + +See :func:`torch.round` +""", +) + +add_docstr_all( + "round_", + r""" +round_(decimals=0) -> Tensor + +In-place version of :meth:`~Tensor.round` +""", +) + +add_docstr_all( + "rsqrt", + r""" +rsqrt() -> Tensor + +See :func:`torch.rsqrt` +""", +) + +add_docstr_all( + "rsqrt_", + r""" +rsqrt_() -> Tensor + +In-place version of :meth:`~Tensor.rsqrt` +""", +) + +add_docstr_all( + "scatter_", + r""" +scatter_(dim, index, src, *, reduce=None) -> Tensor + +Writes all values from the tensor :attr:`src` into :attr:`self` at the indices +specified in the :attr:`index` tensor. For each value in :attr:`src`, its output +index is specified by its index in :attr:`src` for ``dimension != dim`` and by +the corresponding value in :attr:`index` for ``dimension = dim``. + +For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + +This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + +:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have +the same number of dimensions. It is also required that +``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that +``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. +Note that ``index`` and ``src`` do not broadcast. + +Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be +between ``0`` and ``self.size(dim) - 1`` inclusive. + +.. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + +.. 
note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + +Additionally accepts an optional :attr:`reduce` argument that allows +specification of an optional reduction operation, which is applied to all +values in the tensor :attr:`src` into :attr:`self` at the indices +specified in the :attr:`index`. For each value in :attr:`src`, the reduction +operation is applied to an index in :attr:`self` which is specified by +its index in :attr:`src` for ``dimension != dim`` and by the corresponding +value in :attr:`index` for ``dimension = dim``. + +Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` +is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + +Reducing with the addition operation is the same as using +:meth:`~torch.Tensor.scatter_add_`. + +.. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + +Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + +Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + +.. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + +Writes the value from :attr:`value` into :attr:`self` at the indices +specified in the :attr:`index` tensor. This operation is equivalent to the previous version, +with the :attr:`src` tensor filled entirely with :attr:`value`. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + +Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. 
+ +Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) +""", +) + +add_docstr_all( + "scatter_add_", + r""" +scatter_add_(dim, index, src) -> Tensor + +Adds all values from the tensor :attr:`src` into :attr:`self` at the indices +specified in the :attr:`index` tensor in a similar fashion as +:meth:`~torch.Tensor.scatter_`. For each value in :attr:`src`, it is added to +an index in :attr:`self` which is specified by its index in :attr:`src` +for ``dimension != dim`` and by the corresponding value in :attr:`index` for +``dimension = dim``. + +For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + +:attr:`self`, :attr:`index` and :attr:`src` should have same number of +dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all +dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions +``d != dim``. Note that ``index`` and ``src`` do not broadcast. + +Note: + {forward_reproducibility_note} + +.. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and add, can be + either empty or of the same dimensionality as ``src``. When empty, the + operation returns ``self`` unchanged. + src (Tensor): the source elements to scatter and add + +Example:: + + >>> src = torch.ones((2, 5)) + >>> index = torch.tensor([[0, 1, 2, 0, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[1., 0., 0., 1., 1.], + [0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.]]) + >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[2., 0., 0., 1., 1.], + [0., 2., 0., 0., 0.], + [0., 0., 2., 1., 1.]]) + +""".format(**reproducibility_notes), +) + +add_docstr_all( + "scatter_reduce_", + r""" +scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor + +Reduces all values from the :attr:`src` tensor to the indices specified in +the :attr:`index` tensor in the :attr:`self` tensor using the applied reduction +defined via the :attr:`reduce` argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, +:obj:`"amax"`, :obj:`"amin"`). For each value in :attr:`src`, it is reduced to an +index in :attr:`self` which is specified by its index in :attr:`src` for +``dimension != dim`` and by the corresponding value in :attr:`index` for +``dimension = dim``. If :obj:`include_self="True"`, the values in the :attr:`self` +tensor are included in the reduction. + +:attr:`self`, :attr:`index` and :attr:`src` should all have +the same number of dimensions. It is also required that +``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that +``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. +Note that ``index`` and ``src`` do not broadcast. + +For a 3-D tensor with :obj:`reduce="sum"` and :obj:`include_self=True` the +output is given as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + +Note: + {forward_reproducibility_note} + +.. 
note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + +.. warning:: + + This function is in beta and may change in the near future. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and reduce. + src (Tensor): the source elements to scatter and reduce + reduce (str): the reduction operation to apply for non-unique indices + (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + include_self (bool): whether elements from the :attr:`self` tensor are + included in the reduction + +Example:: + + >>> src = torch.tensor([1., 2., 3., 4., 5., 6.]) + >>> index = torch.tensor([0, 1, 0, 1, 2, 1]) + >>> input = torch.tensor([1., 2., 3., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum") + tensor([5., 14., 8., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False) + tensor([4., 12., 5., 4.]) + >>> input2 = torch.tensor([5., 4., 3., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax") + tensor([5., 6., 5., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False) + tensor([3., 6., 5., 2.]) + + +""".format(**reproducibility_notes), +) + +add_docstr_all( + "select", + r""" +select(dim, index) -> Tensor + +See :func:`torch.select` +""", +) + +add_docstr_all( + "select_scatter", + r""" +select_scatter(src, dim, index) -> Tensor + +See :func:`torch.select_scatter` +""", +) + +add_docstr_all( + "slice_scatter", + r""" +slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor + +See :func:`torch.slice_scatter` +""", +) + +add_docstr_all( + "set_", + r""" +set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor + +Sets the underlying storage, size, and strides. If :attr:`source` is a tensor, +:attr:`self` tensor will share the same storage and have the same size and +strides as :attr:`source`. Changes to elements in one tensor will be reflected +in the other. + +If :attr:`source` is a :class:`~torch.Storage`, the method sets the underlying +storage, offset, size, and stride. + +Args: + source (Tensor or Storage): the tensor or storage to use + storage_offset (int, optional): the offset in the storage + size (torch.Size, optional): the desired size. Defaults to the size of the source. + stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. 
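+
+Example (a minimal sketch of the storage sharing described above)::
+
+    >>> src = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
+    >>> t = torch.empty(0)
+    >>> t.set_(src)  # t now shares src's storage, size and strides
+    tensor([[1., 2., 3.],
+            [4., 5., 6.]])
+    >>> t[0, 0] = 7.
+    >>> src[0, 0]
+    tensor(7.)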
+""", +) + +add_docstr_all( + "sigmoid", + r""" +sigmoid() -> Tensor + +See :func:`torch.sigmoid` +""", +) + +add_docstr_all( + "sigmoid_", + r""" +sigmoid_() -> Tensor + +In-place version of :meth:`~Tensor.sigmoid` +""", +) + +add_docstr_all( + "logit", + r""" +logit() -> Tensor + +See :func:`torch.logit` +""", +) + +add_docstr_all( + "logit_", + r""" +logit_() -> Tensor + +In-place version of :meth:`~Tensor.logit` +""", +) + +add_docstr_all( + "sign", + r""" +sign() -> Tensor + +See :func:`torch.sign` +""", +) + +add_docstr_all( + "sign_", + r""" +sign_() -> Tensor + +In-place version of :meth:`~Tensor.sign` +""", +) + +add_docstr_all( + "signbit", + r""" +signbit() -> Tensor + +See :func:`torch.signbit` +""", +) + +add_docstr_all( + "sgn", + r""" +sgn() -> Tensor + +See :func:`torch.sgn` +""", +) + +add_docstr_all( + "sgn_", + r""" +sgn_() -> Tensor + +In-place version of :meth:`~Tensor.sgn` +""", +) + +add_docstr_all( + "sin", + r""" +sin() -> Tensor + +See :func:`torch.sin` +""", +) + +add_docstr_all( + "sin_", + r""" +sin_() -> Tensor + +In-place version of :meth:`~Tensor.sin` +""", +) + +add_docstr_all( + "sinc", + r""" +sinc() -> Tensor + +See :func:`torch.sinc` +""", +) + +add_docstr_all( + "sinc_", + r""" +sinc_() -> Tensor + +In-place version of :meth:`~Tensor.sinc` +""", +) + +add_docstr_all( + "sinh", + r""" +sinh() -> Tensor + +See :func:`torch.sinh` +""", +) + +add_docstr_all( + "sinh_", + r""" +sinh_() -> Tensor + +In-place version of :meth:`~Tensor.sinh` +""", +) + +add_docstr_all( + "size", + r""" +size(dim=None) -> torch.Size or int + +Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, +the returned value is a :class:`torch.Size`, a subclass of :class:`tuple`. +If ``dim`` is specified, returns an int holding the size of that dimension. + +Args: + dim (int, optional): The dimension for which to retrieve the size. + +Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.size(dim=1) + 4 + +""", +) + +add_docstr_all( + "shape", + r""" +shape() -> torch.Size + +Returns the size of the :attr:`self` tensor. Alias for :attr:`size`. + +See also :meth:`Tensor.size`. + +Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.shape + torch.Size([3, 4, 5]) + +""", +) + +add_docstr_all( + "sort", + r""" +sort(dim=-1, descending=False) -> (Tensor, LongTensor) + +See :func:`torch.sort` +""", +) + +add_docstr_all( + "msort", + r""" +msort() -> Tensor + +See :func:`torch.msort` +""", +) + +add_docstr_all( + "argsort", + r""" +argsort(dim=-1, descending=False) -> LongTensor + +See :func:`torch.argsort` +""", +) + +add_docstr_all( + "sparse_dim", + r""" +sparse_dim() -> int + +Return the number of sparse dimensions in a :ref:`sparse tensor ` :attr:`self`. + +.. note:: + Returns ``0`` if :attr:`self` is not a sparse tensor. + +See also :meth:`Tensor.dense_dim` and :ref:`hybrid tensors `. +""", +) + +add_docstr_all( + "sparse_resize_", + r""" +sparse_resize_(size, sparse_dim, dense_dim) -> Tensor + +Resizes :attr:`self` :ref:`sparse tensor ` to the desired +size and the number of sparse and dense dimensions. + +.. note:: + If the number of specified elements in :attr:`self` is zero, then + :attr:`size`, :attr:`sparse_dim`, and :attr:`dense_dim` can be any + size and positive integers such that ``len(size) == sparse_dim + + dense_dim``. 
+ + If :attr:`self` specifies one or more elements, however, then each + dimension in :attr:`size` must not be smaller than the corresponding + dimension of :attr:`self`, :attr:`sparse_dim` must equal the number + of sparse dimensions in :attr:`self`, and :attr:`dense_dim` must + equal the number of dense dimensions in :attr:`self`. + +.. warning:: + Throws an error if :attr:`self` is not a sparse tensor. + +Args: + size (torch.Size): the desired size. If :attr:`self` is non-empty + sparse tensor, the desired size cannot be smaller than the + original size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions +""", +) + +add_docstr_all( + "sparse_resize_and_clear_", + r""" +sparse_resize_and_clear_(size, sparse_dim, dense_dim) -> Tensor + +Removes all specified elements from a :ref:`sparse tensor +` :attr:`self` and resizes :attr:`self` to the desired +size and the number of sparse and dense dimensions. + +.. warning: + Throws an error if :attr:`self` is not a sparse tensor. + +Args: + size (torch.Size): the desired size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions +""", +) + +add_docstr_all( + "sqrt", + r""" +sqrt() -> Tensor + +See :func:`torch.sqrt` +""", +) + +add_docstr_all( + "sqrt_", + r""" +sqrt_() -> Tensor + +In-place version of :meth:`~Tensor.sqrt` +""", +) + +add_docstr_all( + "square", + r""" +square() -> Tensor + +See :func:`torch.square` +""", +) + +add_docstr_all( + "square_", + r""" +square_() -> Tensor + +In-place version of :meth:`~Tensor.square` +""", +) + +add_docstr_all( + "squeeze", + r""" +squeeze(dim=None) -> Tensor + +See :func:`torch.squeeze` +""", +) + +add_docstr_all( + "squeeze_", + r""" +squeeze_(dim=None) -> Tensor + +In-place version of :meth:`~Tensor.squeeze` +""", +) + +add_docstr_all( + "std", + r""" +std(dim=None, *, correction=1, keepdim=False) -> Tensor + +See :func:`torch.std` +""", +) + +add_docstr_all( + "storage_offset", + r""" +storage_offset() -> int + +Returns :attr:`self` tensor's offset in the underlying storage in terms of +number of storage elements (not bytes). + +Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5]) + >>> x.storage_offset() + 0 + >>> x[3:].storage_offset() + 3 + +""", +) + +add_docstr_all( + "untyped_storage", + r""" +untyped_storage() -> torch.UntypedStorage + +Returns the underlying :class:`UntypedStorage`. +""", +) + +add_docstr_all( + "stride", + r""" +stride(dim) -> tuple or int + +Returns the stride of :attr:`self` tensor. + +Stride is the jump necessary to go from one element to the next one in the +specified dimension :attr:`dim`. A tuple of all strides is returned when no +argument is passed in. Otherwise, an integer value is returned as the stride in +the particular dimension :attr:`dim`. + +Args: + dim (int, optional): the desired dimension in which stride is required + +Example:: + + >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> x.stride() + (5, 1) + >>> x.stride(0) + 5 + >>> x.stride(-1) + 1 + +""", +) + +add_docstr_all( + "sub", + r""" +sub(other, *, alpha=1) -> Tensor + +See :func:`torch.sub`. +""", +) + +add_docstr_all( + "sub_", + r""" +sub_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.sub` +""", +) + +add_docstr_all( + "subtract", + r""" +subtract(other, *, alpha=1) -> Tensor + +See :func:`torch.subtract`. +""", +) + +add_docstr_all( + "subtract_", + r""" +subtract_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.subtract`. 
+""", +) + +add_docstr_all( + "sum", + r""" +sum(dim=None, keepdim=False, dtype=None) -> Tensor + +See :func:`torch.sum` +""", +) + +add_docstr_all( + "nansum", + r""" +nansum(dim=None, keepdim=False, dtype=None) -> Tensor + +See :func:`torch.nansum` +""", +) + +add_docstr_all( + "svd", + r""" +svd(some=True, compute_uv=True) -> (Tensor, Tensor, Tensor) + +See :func:`torch.svd` +""", +) + +add_docstr_all( + "swapdims", + r""" +swapdims(dim0, dim1) -> Tensor + +See :func:`torch.swapdims` +""", +) + +add_docstr_all( + "swapdims_", + r""" +swapdims_(dim0, dim1) -> Tensor + +In-place version of :meth:`~Tensor.swapdims` +""", +) + +add_docstr_all( + "swapaxes", + r""" +swapaxes(axis0, axis1) -> Tensor + +See :func:`torch.swapaxes` +""", +) + +add_docstr_all( + "swapaxes_", + r""" +swapaxes_(axis0, axis1) -> Tensor + +In-place version of :meth:`~Tensor.swapaxes` +""", +) + +add_docstr_all( + "t", + r""" +t() -> Tensor + +See :func:`torch.t` +""", +) + +add_docstr_all( + "t_", + r""" +t_() -> Tensor + +In-place version of :meth:`~Tensor.t` +""", +) + +add_docstr_all( + "tile", + r""" +tile(dims) -> Tensor + +See :func:`torch.tile` +""", +) + +add_docstr_all( + "to", + r""" +to(*args, **kwargs) -> Tensor + +Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are +inferred from the arguments of ``self.to(*args, **kwargs)``. + +.. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + +.. note:: + + If ``self`` requires gradients (``requires_grad=True``) but the target + ``dtype`` specified is an integer type, the returned tensor will implicitly + set ``requires_grad=False``. This is because only tensors with + floating-point or complex dtypes can require gradients. + +Here are the ways to call ``to``: + +.. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + {memory_format} + +.. note:: + + According to `C++ type conversion rules `_, + converting floating point value to integer type will truncate the fractional part. + If the truncated value cannot fit into the target type (e.g., casting ``torch.inf`` to ``torch.long``), + the behavior is undefined and the result may vary across platforms. + +.. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + {memory_format} + +.. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. 
+ When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + +Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') +""".format(**common_args), +) + +add_docstr_all( + "byte", + r""" +byte(memory_format=torch.preserve_format) -> Tensor + +``self.byte()`` is equivalent to ``self.to(torch.uint8)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "bool", + r""" +bool(memory_format=torch.preserve_format) -> Tensor + +``self.bool()`` is equivalent to ``self.to(torch.bool)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "char", + r""" +char(memory_format=torch.preserve_format) -> Tensor + +``self.char()`` is equivalent to ``self.to(torch.int8)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "bfloat16", + r""" +bfloat16(memory_format=torch.preserve_format) -> Tensor +``self.bfloat16()`` is equivalent to ``self.to(torch.bfloat16)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "double", + r""" +double(memory_format=torch.preserve_format) -> Tensor + +``self.double()`` is equivalent to ``self.to(torch.float64)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "float", + r""" +float(memory_format=torch.preserve_format) -> Tensor + +``self.float()`` is equivalent to ``self.to(torch.float32)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "cdouble", + r""" +cdouble(memory_format=torch.preserve_format) -> Tensor + +``self.cdouble()`` is equivalent to ``self.to(torch.complex128)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "cfloat", + r""" +cfloat(memory_format=torch.preserve_format) -> Tensor + +``self.cfloat()`` is equivalent to ``self.to(torch.complex64)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "chalf", + r""" +chalf(memory_format=torch.preserve_format) -> Tensor + +``self.chalf()`` is equivalent to ``self.to(torch.complex32)``. See :func:`to`. + +Args: + {memory_format} + """.format(**common_args), +) + +add_docstr_all( + "half", + r""" +half(memory_format=torch.preserve_format) -> Tensor + +``self.half()`` is equivalent to ``self.to(torch.float16)``. See :func:`to`. 
+ +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "int", + r""" +int(memory_format=torch.preserve_format) -> Tensor + +``self.int()`` is equivalent to ``self.to(torch.int32)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "int_repr", + r""" +int_repr() -> Tensor + +Given a quantized Tensor, +``self.int_repr()`` returns a CPU Tensor with uint8_t as data type that stores the +underlying uint8_t values of the given Tensor. +""", +) + + +add_docstr_all( + "long", + r""" +long(memory_format=torch.preserve_format) -> Tensor + +``self.long()`` is equivalent to ``self.to(torch.int64)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "short", + r""" +short(memory_format=torch.preserve_format) -> Tensor + +``self.short()`` is equivalent to ``self.to(torch.int16)``. See :func:`to`. + +Args: + {memory_format} +""".format(**common_args), +) + +add_docstr_all( + "take", + r""" +take(indices) -> Tensor + +See :func:`torch.take` +""", +) + +add_docstr_all( + "take_along_dim", + r""" +take_along_dim(indices, dim) -> Tensor + +See :func:`torch.take_along_dim` +""", +) + +add_docstr_all( + "tan", + r""" +tan() -> Tensor + +See :func:`torch.tan` +""", +) + +add_docstr_all( + "tan_", + r""" +tan_() -> Tensor + +In-place version of :meth:`~Tensor.tan` +""", +) + +add_docstr_all( + "tanh", + r""" +tanh() -> Tensor + +See :func:`torch.tanh` +""", +) + +add_docstr_all( + "softmax", + r""" +softmax(dim) -> Tensor + +Alias for :func:`torch.nn.functional.softmax`. +""", +) + +add_docstr_all( + "tanh_", + r""" +tanh_() -> Tensor + +In-place version of :meth:`~Tensor.tanh` +""", +) + +add_docstr_all( + "tolist", + r""" +tolist() -> list or number + +Returns the tensor as a (nested) list. For scalars, a standard +Python number is returned, just like with :meth:`~Tensor.item`. +Tensors are automatically moved to the CPU first if necessary. + +This operation is not differentiable. + +Examples:: + + >>> a = torch.randn(2, 2) + >>> a.tolist() + [[0.012766935862600803, 0.5415473580360413], + [-0.08909505605697632, 0.7729271650314331]] + >>> a[0,0].tolist() + 0.012766935862600803 +""", +) + +add_docstr_all( + "topk", + r""" +topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor) + +See :func:`torch.topk` +""", +) + +add_docstr_all( + "to_dense", + r""" +to_dense(dtype=None, *, masked_grad=True) -> Tensor + +Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`. + +Keyword args: + {dtype} + masked_grad (bool, optional): If set to ``True`` (default) and + :attr:`self` has a sparse layout then the backward of + :meth:`to_dense` returns ``grad.sparse_mask(self)``. + +Example:: + + >>> s = torch.sparse_coo_tensor( + ... torch.tensor([[1, 1], + ... [0, 2]]), + ... torch.tensor([9, 10]), + ... size=(3, 3)) + >>> s.to_dense() + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) +""", +) + +add_docstr_all( + "to_sparse", + r""" +to_sparse(sparseDims) -> Tensor + +Returns a sparse copy of the tensor. PyTorch supports sparse tensors in +:ref:`coordinate format `. 
+ +Args: + sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor + +Example:: + + >>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]]) + >>> d + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + >>> d.to_sparse() + tensor(indices=tensor([[1, 1], + [0, 2]]), + values=tensor([ 9, 10]), + size=(3, 3), nnz=2, layout=torch.sparse_coo) + >>> d.to_sparse(1) + tensor(indices=tensor([[1]]), + values=tensor([[ 9, 0, 10]]), + size=(3, 3), nnz=1, layout=torch.sparse_coo) + +.. method:: to_sparse(*, layout=None, blocksize=None, dense_dim=None) -> Tensor + :noindex: + +Returns a sparse tensor with the specified layout and blocksize. If +the :attr:`self` is strided, the number of dense dimensions could be +specified, and a hybrid sparse tensor will be created, with +`dense_dim` dense dimensions and `self.dim() - 2 - dense_dim` batch +dimension. + +.. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + +Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR, CSC, BSR or BSC tensor. This argument should be + used only if :attr:`self` is a strided tensor, and must be a + value between 0 and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize + + >>> x = torch.tensor([[[1], [0]], [[0], [0]], [[2], [3]]]) + >>> x.to_sparse(layout=torch.sparse_csr, dense_dim=1) + tensor(crow_indices=tensor([0, 1, 1, 3]), + col_indices=tensor([0, 0, 1]), + values=tensor([[1], + [2], + [3]]), size=(3, 2, 1), nnz=3, layout=torch.sparse_csr) + +""", +) + +add_docstr_all( + "to_sparse_csr", + r""" +to_sparse_csr(dense_dim=None) -> Tensor + +Convert a tensor to compressed row storage format (CSR). Except for +strided tensors, only works with 2D tensors. If the :attr:`self` is +strided, then the number of dense dimensions could be specified, and a +hybrid CSR tensor will be created, with `dense_dim` dense dimensions +and `self.dim() - 2 - dense_dim` batch dimension. + +Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR tensor. 
This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csr() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csr(dense_dim=2) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csr) + +""", +) + +add_docstr_all( + "to_sparse_csc", + r""" +to_sparse_csc() -> Tensor + +Convert a tensor to compressed column storage (CSC) format. Except +for strided tensors, only works with 2D tensors. If the :attr:`self` +is strided, then the number of dense dimensions could be specified, +and a hybrid CSC tensor will be created, with `dense_dim` dense +dimensions and `self.dim() - 2 - dense_dim` batch dimension. + +Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csc() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csc(dense_dim=2) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csc) + +""", +) + +add_docstr_all( + "to_sparse_bsr", + r""" +to_sparse_bsr(blocksize, dense_dim) -> Tensor + +Convert a tensor to a block sparse row (BSR) storage format of given +blocksize. If the :attr:`self` is strided, then the number of dense +dimensions could be specified, and a hybrid BSR tensor will be +created, with `dense_dim` dense dimensions and `self.dim() - 2 - +dense_dim` batch dimension. + +Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSR tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsr = sparse.to_sparse_bsr((5, 5)) + >>> sparse_bsr.col_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsr((2, 1), 1) + tensor(crow_indices=tensor([0, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsr) + +""", +) + +add_docstr_all( + "to_sparse_bsc", + r""" +to_sparse_bsc(blocksize, dense_dim) -> Tensor + +Convert a tensor to a block sparse column (BSC) storage format of +given blocksize. If the :attr:`self` is strided, then the number of +dense dimensions could be specified, and a hybrid BSC tensor will be +created, with `dense_dim` dense dimensions and `self.dim() - 2 - +dense_dim` batch dimension. 
+ +Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSC tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsc = sparse.to_sparse_bsc((5, 5)) + >>> sparse_bsc.row_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsc((2, 1), 1) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 1, 0]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsc) + +""", +) + +add_docstr_all( + "to_mkldnn", + r""" +to_mkldnn() -> Tensor +Returns a copy of the tensor in ``torch.mkldnn`` layout. + +""", +) + +add_docstr_all( + "trace", + r""" +trace() -> Tensor + +See :func:`torch.trace` +""", +) + +add_docstr_all( + "transpose", + r""" +transpose(dim0, dim1) -> Tensor + +See :func:`torch.transpose` +""", +) + +add_docstr_all( + "transpose_", + r""" +transpose_(dim0, dim1) -> Tensor + +In-place version of :meth:`~Tensor.transpose` +""", +) + +add_docstr_all( + "triangular_solve", + r""" +triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) + +See :func:`torch.triangular_solve` +""", +) + +add_docstr_all( + "tril", + r""" +tril(diagonal=0) -> Tensor + +See :func:`torch.tril` +""", +) + +add_docstr_all( + "tril_", + r""" +tril_(diagonal=0) -> Tensor + +In-place version of :meth:`~Tensor.tril` +""", +) + +add_docstr_all( + "triu", + r""" +triu(diagonal=0) -> Tensor + +See :func:`torch.triu` +""", +) + +add_docstr_all( + "triu_", + r""" +triu_(diagonal=0) -> Tensor + +In-place version of :meth:`~Tensor.triu` +""", +) + +add_docstr_all( + "true_divide", + r""" +true_divide(value) -> Tensor + +See :func:`torch.true_divide` +""", +) + +add_docstr_all( + "true_divide_", + r""" +true_divide_(value) -> Tensor + +In-place version of :meth:`~Tensor.true_divide_` +""", +) + +add_docstr_all( + "trunc", + r""" +trunc() -> Tensor + +See :func:`torch.trunc` +""", +) + +add_docstr_all( + "fix", + r""" +fix() -> Tensor + +See :func:`torch.fix`. +""", +) + +add_docstr_all( + "trunc_", + r""" +trunc_() -> Tensor + +In-place version of :meth:`~Tensor.trunc` +""", +) + +add_docstr_all( + "fix_", + r""" +fix_() -> Tensor + +In-place version of :meth:`~Tensor.fix` +""", +) + +add_docstr_all( + "type", + r""" +type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor +Returns the type if `dtype` is not provided, else casts this object to +the specified type. + +If this is already of the correct type, no copy is performed and the +original object is returned. + +Args: + dtype (dtype or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. 
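+
+Example (a minimal sketch of both call forms, shown for a CPU ``float32`` tensor)::
+
+    >>> x = torch.zeros(3)
+    >>> x.type()                    # no dtype given: return the type string
+    'torch.FloatTensor'
+    >>> x.type(torch.int64)         # cast to the given dtype (copies)
+    tensor([0, 0, 0])
+    >>> x.type(torch.float32) is x  # already the right type: no copy, self is returned
+    True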
+""", +) + +add_docstr_all( + "type_as", + r""" +type_as(tensor) -> Tensor + +Returns this tensor cast to the type of the given tensor. + +This is a no-op if the tensor is already of the correct type. This is +equivalent to ``self.type(tensor.type())`` + +Args: + tensor (Tensor): the tensor which has the desired type +""", +) + +add_docstr_all( + "unfold", + r""" +unfold(dimension, size, step) -> Tensor + +Returns a view of the original tensor which contains all slices of size :attr:`size` from +:attr:`self` tensor in the dimension :attr:`dimension`. + +Step between two slices is given by :attr:`step`. + +If `sizedim` is the size of dimension :attr:`dimension` for :attr:`self`, the size of +dimension :attr:`dimension` in the returned tensor will be +`(sizedim - size) / step + 1`. + +An additional dimension of size :attr:`size` is appended in the returned tensor. + +Args: + dimension (int): dimension in which unfolding happens + size (int): the size of each slice that is unfolded + step (int): the step between each slice + +Example:: + + >>> x = torch.arange(1., 8) + >>> x + tensor([ 1., 2., 3., 4., 5., 6., 7.]) + >>> x.unfold(0, 2, 1) + tensor([[ 1., 2.], + [ 2., 3.], + [ 3., 4.], + [ 4., 5.], + [ 5., 6.], + [ 6., 7.]]) + >>> x.unfold(0, 2, 2) + tensor([[ 1., 2.], + [ 3., 4.], + [ 5., 6.]]) +""", +) + +add_docstr_all( + "uniform_", + r""" +uniform_(from=0, to=1, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with numbers sampled from the continuous uniform +distribution: + +.. math:: + f(x) = \dfrac{1}{\text{to} - \text{from}} +""", +) + +add_docstr_all( + "unsqueeze", + r""" +unsqueeze(dim) -> Tensor + +See :func:`torch.unsqueeze` +""", +) + +add_docstr_all( + "unsqueeze_", + r""" +unsqueeze_(dim) -> Tensor + +In-place version of :meth:`~Tensor.unsqueeze` +""", +) + +add_docstr_all( + "var", + r""" +var(dim=None, *, correction=1, keepdim=False) -> Tensor + +See :func:`torch.var` +""", +) + +add_docstr_all( + "vdot", + r""" +vdot(other) -> Tensor + +See :func:`torch.vdot` +""", +) + +add_docstr_all( + "view", + r""" +view(*shape) -> Tensor + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`shape`. + +The returned tensor shares the same data and must have the same number +of elements, but may have a different size. For a tensor to be viewed, the new +view size must be compatible with its original size and stride, i.e., each new +view dimension must either be a subspace of an original dimension, or only span +across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following +contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + +.. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + +Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` +without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a +:meth:`view` can be performed, it is advisable to use :meth:`reshape`, which +returns a view if the shapes are compatible, and copies (equivalent to calling +:meth:`contiguous`) otherwise. 
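+
+For instance (a minimal sketch; the exact error message may differ between releases)::
+
+    >>> t = torch.arange(6).reshape(2, 3).t()  # a non-contiguous view
+    >>> t.view(6)
+    Traceback (most recent call last):
+        ...
+    RuntimeError: view size is not compatible with input tensor's size and stride ...
+    >>> t.reshape(6)  # copies instead
+    tensor([0, 3, 1, 4, 2, 5])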
+ +Args: + shape (torch.Size or int...): the desired size + +Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + +.. method:: view(dtype) -> Tensor + :noindex: + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`dtype`. + +If the element size of :attr:`dtype` is different than that of ``self.dtype``, +then the size of the last dimension of the output will be scaled +proportionally. For instance, if :attr:`dtype` element size is twice that of +``self.dtype``, then each pair of elements in the last dimension of +:attr:`self` will be combined, and the size of the last dimension of the output +will be half that of :attr:`self`. If :attr:`dtype` element size is half that +of ``self.dtype``, then each element in the last dimension of :attr:`self` will +be split in two, and the size of the last dimension of the output will be +double that of :attr:`self`. For this to be possible, the following conditions +must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + +Additionally, if the element size of :attr:`dtype` is greater than that of +``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + +If any of the above conditions are not met, an error is thrown. + +.. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. 
+ + +Args: + dtype (:class:`torch.dtype`): the desired dtype + +Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) +""", +) + +add_docstr_all( + "view_as", + r""" +view_as(other) -> Tensor + +View this tensor as the same size as :attr:`other`. +``self.view_as(other)`` is equivalent to ``self.view(other.size())``. + +Please see :meth:`~Tensor.view` for more information about ``view``. + +Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. +""", +) + +add_docstr_all( + "expand", + r""" +expand(*sizes) -> Tensor + +Returns a new view of the :attr:`self` tensor with singleton dimensions expanded +to a larger size. + +Passing -1 as the size for a dimension means not changing the size of +that dimension. + +Tensor can be also expanded to a larger number of dimensions, and the +new ones will be appended at the front. For the new dimensions, the +size cannot be set to -1. + +Expanding a tensor does not allocate new memory, but only creates a +new view on the existing tensor where a dimension of size one is +expanded to a larger size by setting the ``stride`` to 0. Any dimension +of size 1 can be expanded to an arbitrary value without allocating new +memory. + +Args: + *sizes (torch.Size or int...): the desired expanded size + +.. warning:: + + More than one element of an expanded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensors, please clone them first. + +Example:: + + >>> x = torch.tensor([[1], [2], [3]]) + >>> x.size() + torch.Size([3, 1]) + >>> x.expand(3, 4) + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + >>> x.expand(-1, 4) # -1 means not changing the size of that dimension + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) +""", +) + +add_docstr_all( + "expand_as", + r""" +expand_as(other) -> Tensor + +Expand this tensor to the same size as :attr:`other`. +``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``. + +Please see :meth:`~Tensor.expand` for more information about ``expand``. 
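+
+For example, expanding a ``(3, 1)`` tensor to the size of a ``(3, 4)`` tensor::
+
+    >>> x = torch.tensor([[1.], [2.], [3.]])
+    >>> y = torch.empty(3, 4)
+    >>> x.expand_as(y)
+    tensor([[1., 1., 1., 1.],
+            [2., 2., 2., 2.],
+            [3., 3., 3., 3.]])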
+ +Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. +""", +) + +add_docstr_all( + "sum_to_size", + r""" +sum_to_size(*size) -> Tensor + +Sum ``this`` tensor to :attr:`size`. +:attr:`size` must be broadcastable to ``this`` tensor size. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. +""", +) + + +add_docstr_all( + "zero_", + r""" +zero_() -> Tensor + +Fills :attr:`self` tensor with zeros. +""", +) + +add_docstr_all( + "matmul", + r""" +matmul(tensor2) -> Tensor + +See :func:`torch.matmul` +""", +) + +add_docstr_all( + "chunk", + r""" +chunk(chunks, dim=0) -> List of Tensors + +See :func:`torch.chunk` +""", +) + +add_docstr_all( + "unsafe_chunk", + r""" +unsafe_chunk(chunks, dim=0) -> List of Tensors + +See :func:`torch.unsafe_chunk` +""", +) + +add_docstr_all( + "unsafe_split", + r""" +unsafe_split(split_size, dim=0) -> List of Tensors + +See :func:`torch.unsafe_split` +""", +) + +add_docstr_all( + "tensor_split", + r""" +tensor_split(indices_or_sections, dim=0) -> List of Tensors + +See :func:`torch.tensor_split` +""", +) + +add_docstr_all( + "hsplit", + r""" +hsplit(split_size_or_sections) -> List of Tensors + +See :func:`torch.hsplit` +""", +) + +add_docstr_all( + "vsplit", + r""" +vsplit(split_size_or_sections) -> List of Tensors + +See :func:`torch.vsplit` +""", +) + +add_docstr_all( + "dsplit", + r""" +dsplit(split_size_or_sections) -> List of Tensors + +See :func:`torch.dsplit` +""", +) + +add_docstr_all( + "stft", + r""" +stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, + pad_end=0, align_to_window=None) -> Tensor + +See :func:`torch.stft` +""", +) + +add_docstr_all( + "istft", + r""" +istft(n_fft, hop_length=None, win_length=None, window=None, + center=True, normalized=False, onesided=True, length=None) -> Tensor + +See :func:`torch.istft` +""", +) + +add_docstr_all( + "det", + r""" +det() -> Tensor + +See :func:`torch.det` +""", +) + +add_docstr_all( + "where", + r""" +where(condition, y) -> Tensor + +``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. +See :func:`torch.where` +""", +) + +add_docstr_all( + "logdet", + r""" +logdet() -> Tensor + +See :func:`torch.logdet` +""", +) + +add_docstr_all( + "slogdet", + r""" +slogdet() -> (Tensor, Tensor) + +See :func:`torch.slogdet` +""", +) + +add_docstr_all( + "unbind", + r""" +unbind(dim=0) -> seq + +See :func:`torch.unbind` +""", +) + +add_docstr_all( + "pin_memory", + r""" +pin_memory() -> Tensor + +Copies the tensor to pinned memory, if it's not already pinned. +By default, the device pinned memory on will be the current :ref:`accelerator`. +""", +) + +add_docstr_all( + "pinverse", + r""" +pinverse() -> Tensor + +See :func:`torch.pinverse` +""", +) + +add_docstr_all( + "index_add", + r""" +index_add(dim, index, source, *, alpha=1) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_add_`. +""", +) + +add_docstr_all( + "index_copy", + r""" +index_copy(dim, index, tensor2) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_copy_`. +""", +) + +add_docstr_all( + "index_fill", + r""" +index_fill(dim, index, value) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_fill_`. 
+""", +) + +add_docstr_all( + "scatter", + r""" +scatter(dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_` +""", +) + +add_docstr_all( + "scatter_add", + r""" +scatter_add(dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_add_` +""", +) + +add_docstr_all( + "scatter_reduce", + r""" +scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` +""", +) + +add_docstr_all( + "masked_scatter", + r""" +masked_scatter(mask, tensor) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.masked_scatter_` + +.. note:: + + The inputs :attr:`self` and :attr:`mask` + :ref:`broadcast `. + +Example: + + >>> self = torch.tensor([0, 0, 0, 0, 0]) + >>> mask = torch.tensor( + ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], + ... dtype=torch.bool, + ... ) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + +""", +) + +add_docstr_all( + "xlogy", + r""" +xlogy(other) -> Tensor + +See :func:`torch.xlogy` +""", +) + +add_docstr_all( + "xlogy_", + r""" +xlogy_(other) -> Tensor + +In-place version of :meth:`~Tensor.xlogy` +""", +) + +add_docstr_all( + "masked_fill", + r""" +masked_fill(mask, value) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.masked_fill_` +""", +) + +add_docstr_all( + "grad", + r""" +This attribute is ``None`` by default and becomes a Tensor the first time a call to +:func:`backward` computes gradients for ``self``. +The attribute will then contain the gradients computed and future calls to +:func:`backward` will accumulate (add) gradients into it. +""", +) + +add_docstr_all( + "retain_grad", + r""" +retain_grad() -> None + +Enables this Tensor to have their :attr:`grad` populated during +:func:`backward`. This is a no-op for leaf tensors. +""", +) + +add_docstr_all( + "retains_grad", + r""" +Is ``True`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be +populated during :func:`backward`, ``False`` otherwise. +""", +) + +add_docstr_all( + "requires_grad", + r""" +Is ``True`` if gradients need to be computed for this Tensor, ``False`` otherwise. + +.. note:: + + The fact that gradients need to be computed for a Tensor do not mean that the :attr:`grad` + attribute will be populated, see :attr:`is_leaf` for more details. + +""", +) + +add_docstr_all( + "is_leaf", + r""" +All Tensors that have :attr:`requires_grad` which is ``False`` will be leaf Tensors by convention. + +For Tensors that have :attr:`requires_grad` which is ``True``, they will be leaf Tensors if they were +created by the user. This means that they are not the result of an operation and so +:attr:`grad_fn` is None. + +Only leaf Tensors will have their :attr:`grad` populated during a call to :func:`backward`. +To get :attr:`grad` populated for non-leaf Tensors, you can use :func:`retain_grad`. 
+ +Example:: + + >>> a = torch.rand(10, requires_grad=True) + >>> a.is_leaf + True + >>> b = torch.rand(10, requires_grad=True).cuda() + >>> b.is_leaf + False + # b was created by the operation that cast a cpu Tensor into a cuda Tensor + >>> c = torch.rand(10, requires_grad=True) + 2 + >>> c.is_leaf + False + # c was created by the addition operation + >>> d = torch.rand(10).cuda() + >>> d.is_leaf + True + # d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + >>> e = torch.rand(10).cuda().requires_grad_() + >>> e.is_leaf + True + # e requires gradients and has no operations creating it + >>> f = torch.rand(10, requires_grad=True, device="cuda") + >>> f.is_leaf + True + # f requires grad, has no operation creating it + + +""", +) + +add_docstr_all( + "names", + r""" +Stores names for each of this tensor's dimensions. + +``names[idx]`` corresponds to the name of tensor dimension ``idx``. +Names are either a string if the dimension is named or ``None`` if the +dimension is unnamed. + +Dimension names may contain characters or underscore. Furthermore, a dimension +name must be a valid Python variable name (i.e., does not start with underscore). + +Tensors may not have two named dimensions with the same name. + +.. warning:: + The named tensor API is experimental and subject to change. + +""", +) + +add_docstr_all( + "is_cuda", + r""" +Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_cpu", + r""" +Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_xla", + r""" +Is ``True`` if the Tensor is stored on an XLA device, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_ipu", + r""" +Is ``True`` if the Tensor is stored on the IPU, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_xpu", + r""" +Is ``True`` if the Tensor is stored on the XPU, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_quantized", + r""" +Is ``True`` if the Tensor is quantized, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_meta", + r""" +Is ``True`` if the Tensor is a meta tensor, ``False`` otherwise. Meta tensors +are like normal tensors, but they carry no data. +""", +) + +add_docstr_all( + "is_mps", + r""" +Is ``True`` if the Tensor is stored on the MPS device, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_sparse", + r""" +Is ``True`` if the Tensor uses sparse COO storage layout, ``False`` otherwise. +""", +) + +add_docstr_all( + "is_sparse_csr", + r""" +Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise. +""", +) + +add_docstr_all( + "device", + r""" +Is the :class:`torch.device` where this Tensor is. +""", +) + +add_docstr_all( + "ndim", + r""" +Alias for :meth:`~Tensor.dim()` +""", +) + +add_docstr_all( + "itemsize", + r""" +Alias for :meth:`~Tensor.element_size()` +""", +) + +add_docstr_all( + "nbytes", + r""" +Returns the number of bytes consumed by the "view" of elements of the Tensor +if the Tensor does not use sparse storage layout. +Defined to be :meth:`~Tensor.numel()` * :meth:`~Tensor.element_size()` +""", +) + +add_docstr_all( + "T", + r""" +Returns a view of this tensor with its dimensions reversed. + +If ``n`` is the number of dimensions in ``x``, +``x.T`` is equivalent to ``x.permute(n-1, n-2, ..., 0)``. + +.. warning:: + The use of :func:`Tensor.T` on tensors of dimension other than 2 to reverse their shape + is deprecated and it will throw an error in a future release. 
Consider :attr:`~.Tensor.mT` + to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse + the dimensions of a tensor. +""", +) + +add_docstr_all( + "H", + r""" +Returns a view of a matrix (2-D tensor) conjugated and transposed. + +``x.H`` is equivalent to ``x.transpose(0, 1).conj()`` for complex matrices and +``x.transpose(0, 1)`` for real matrices. + +.. seealso:: + + :attr:`~.Tensor.mH`: An attribute that also works on batches of matrices. +""", +) + +add_docstr_all( + "mT", + r""" +Returns a view of this tensor with the last two dimensions transposed. + +``x.mT`` is equivalent to ``x.transpose(-2, -1)``. +""", +) + +add_docstr_all( + "mH", + r""" +Accessing this property is equivalent to calling :func:`adjoint`. +""", +) + +add_docstr_all( + "adjoint", + r""" +adjoint() -> Tensor + +Alias for :func:`adjoint` +""", +) + +add_docstr_all( + "real", + r""" +Returns a new tensor containing real values of the :attr:`self` tensor for a complex-valued input tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +Returns :attr:`self` if :attr:`self` is a real-valued tensor tensor. + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + +""", +) + +add_docstr_all( + "imag", + r""" +Returns a new tensor containing imaginary values of the :attr:`self` tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +.. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + +""", +) + +add_docstr_all( + "as_subclass", + r""" +as_subclass(cls) -> Tensor + +Makes a ``cls`` instance with the same data pointer as ``self``. Changes +in the output mirror changes in ``self``, and the output stays attached +to the autograd graph. ``cls`` must be a subclass of ``Tensor``. +""", +) + +add_docstr_all( + "crow_indices", + r""" +crow_indices() -> IntTensor + +Returns the tensor containing the compressed row indices of the :attr:`self` +tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. +The ``crow_indices`` tensor is strictly of shape (:attr:`self`.size(0) + 1) +and of type ``int32`` or ``int64``. When using MKL routines such as sparse +matrix multiplication, it is necessary to use ``int32`` indexing in order +to avoid downcasting and potentially losing information. + +Example:: + + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.crow_indices() + tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32) + +""", +) + +add_docstr_all( + "col_indices", + r""" +col_indices() -> IntTensor + +Returns the tensor containing the column indices of the :attr:`self` +tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. +The ``col_indices`` tensor is strictly of shape (:attr:`self`.nnz()) +and of type ``int32`` or ``int64``. When using MKL routines such as sparse +matrix multiplication, it is necessary to use ``int32`` indexing in order +to avoid downcasting and potentially losing information. 
+ +Example:: + + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.col_indices() + tensor([0, 1, 2, 3, 4], dtype=torch.int32) + +""", +) + +add_docstr_all( + "to_padded_tensor", + r""" +to_padded_tensor(padding, output_size=None) -> Tensor +See :func:`to_padded_tensor` +""", +) diff --git a/phivenv/Lib/site-packages/torch/_tensor_str.py b/phivenv/Lib/site-packages/torch/_tensor_str.py new file mode 100644 index 0000000000000000000000000000000000000000..c75c8523d82908fefb7277d4d727cb95ace360e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_tensor_str.py @@ -0,0 +1,726 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import math +import textwrap +from typing import Any, Optional + +import torch +from torch import inf + + +@dataclasses.dataclass +class __PrinterOptions: + precision: int = 4 + threshold: float = 1000 + edgeitems: int = 3 + linewidth: int = 80 + sci_mode: Optional[bool] = None + + +PRINT_OPTS = __PrinterOptions() + + +# We could use **kwargs, but this will give better docs +def set_printoptions( + precision=None, + threshold=None, + edgeitems=None, + linewidth=None, + profile=None, + sci_mode=None, +): + r"""Set options for printing. Items shamelessly taken from NumPy + + Args: + precision: Number of digits of precision for floating point output + (default = 4). + threshold: Total number of array elements which trigger summarization + rather than full `repr` (default = 1000). + edgeitems: Number of array items in summary at beginning and end of + each dimension (default = 3). + linewidth: The number of characters per line for the purpose of + inserting line breaks (default = 80). Thresholded matrices will + ignore this parameter. + profile: Sane defaults for pretty printing. Can override with any of + the above options. (any one of `default`, `short`, `full`) + sci_mode: Enable (True) or disable (False) scientific notation. If + None (default) is specified, the value is defined by + `torch._tensor_str._Formatter`. This value is automatically chosen + by the framework. + + Example:: + + >>> # Limit the precision of elements + >>> torch.set_printoptions(precision=2) + >>> torch.tensor([1.12345]) + tensor([1.12]) + >>> # Limit the number of elements shown + >>> torch.set_printoptions(threshold=5) + >>> torch.arange(10) + tensor([0, 1, 2, ..., 7, 8, 9]) + >>> # Restore defaults + >>> torch.set_printoptions(profile='default') + >>> torch.tensor([1.12345]) + tensor([1.1235]) + >>> torch.arange(10) + tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + """ + if profile is not None: + if profile == "default": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + elif profile == "short": + PRINT_OPTS.precision = 2 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 2 + PRINT_OPTS.linewidth = 80 + elif profile == "full": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = inf + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + + if precision is not None: + PRINT_OPTS.precision = precision + if threshold is not None: + PRINT_OPTS.threshold = threshold + if edgeitems is not None: + PRINT_OPTS.edgeitems = edgeitems + if linewidth is not None: + PRINT_OPTS.linewidth = linewidth + PRINT_OPTS.sci_mode = sci_mode + + +def get_printoptions() -> dict[str, Any]: + r"""Gets the current options for printing, as a dictionary that + can be passed as ``**kwargs`` to set_printoptions(). 
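+
+    A small illustrative round-trip (the keys mirror the arguments of
+    :func:`set_printoptions`)::
+
+        >>> opts = torch.get_printoptions()
+        >>> sorted(opts)
+        ['edgeitems', 'linewidth', 'precision', 'sci_mode', 'threshold']
+        >>> torch.set_printoptions(**opts)  # restores the same options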
+ """ + return dataclasses.asdict(PRINT_OPTS) + + +@contextlib.contextmanager +def printoptions(**kwargs): + r"""Context manager that temporarily changes the print options. Accepted + arguments are same as :func:`set_printoptions`.""" + old_kwargs = get_printoptions() + set_printoptions(**kwargs) + try: + yield + finally: + set_printoptions(**old_kwargs) + + +def tensor_totype(t): + dtype = ( + torch.float + if ( + t.is_mps + or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64) + or t.is_maia + ) + else torch.double + ) + return t.to(dtype=dtype) + + +class _Formatter: + def __init__(self, tensor): + self.floating_dtype = tensor.dtype.is_floating_point + self.int_mode = True + self.sci_mode = False + self.max_width = 1 + + with torch.no_grad(): + tensor_view = tensor.reshape(-1) + + if not self.floating_dtype: + for value in tensor_view: + value_str = f"{value}" + self.max_width = max(self.max_width, len(value_str)) + + else: + if tensor.dtype == torch.float4_e2m1fn_x2: # type: ignore[attr-defined] + # torch.float4_e2m1fn_x2 is special and does not support the casts necessary + # to print it, we choose to display the uint8 representation here for + # convenience of being able to print a tensor. + # TODO(#146647): extend this to other dtypes without casts defined, such + # as the bits, uint1..7 and int1..7 dtypes. + tensor_view = tensor_view.view(torch.uint8) + + nonzero_finite_vals = torch.masked_select( + tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0) + ) + + if nonzero_finite_vals.numel() == 0: + # no valid number, do nothing + return + + if tensor.dtype == torch.float8_e8m0fnu: # type: ignore[attr-defined] + # float8_e8m0fnu is special and does not define arithmetic ops, + # and printing code further in this file assumes the existence + # of various arithmetic ops to figure out what to print. We hack + # and convert to float here to make printing work correctly. + # TODO(#113663): also add the other float8 dtypes here after arithmetic + # support for them is removed + nonzero_finite_vals = nonzero_finite_vals.float() + + # Convert to double (or float) for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU. + nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs()) + nonzero_finite_min = tensor_totype(nonzero_finite_abs.min()) + nonzero_finite_max = tensor_totype(nonzero_finite_abs.max()) + + for value in nonzero_finite_vals: + if value != torch.ceil(value): + self.int_mode = False + break + + if self.int_mode: + # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites + # to indicate that the tensor is of floating type. add 1 to the len to account for this. + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + ): + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = f"{value:.0f}" + self.max_width = max(self.max_width, len(value_str) + 1) + else: + # Check if scientific representation should be used. 
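+                    # Scientific notation is chosen when the spread of finite
+                    # nonzero magnitudes exceeds 1e3, the largest magnitude
+                    # exceeds 1e8, or the smallest is below 1e-4; otherwise
+                    # fixed-point formatting with PRINT_OPTS.precision digits
+                    # is used.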
+ if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + or nonzero_finite_min < 1.0e-4 + ): + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + + if PRINT_OPTS.sci_mode is not None: + self.sci_mode = PRINT_OPTS.sci_mode + + def width(self): + return self.max_width + + def format(self, value): + if self.floating_dtype: + if self.sci_mode: + ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value) + elif self.int_mode: + ret = f"{value:.0f}" + if not (math.isinf(value) or math.isnan(value)): + ret += "." + else: + ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value) + else: + ret = f"{value}" + return (self.max_width - len(ret)) * " " + ret + + +def _scalar_str(self, formatter1, formatter2=None): + if formatter2 is not None: + real_str = _scalar_str(self.real, formatter1) + imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip() + # handles negative numbers, +0.0, -0.0 + if imag_str[0] == "+" or imag_str[0] == "-": + return real_str + imag_str + else: + return real_str + "+" + imag_str + else: + return formatter1.format(self.item()) + + +def _vector_str(self, indent, summarize, formatter1, formatter2=None): + # length includes spaces and comma between elements + element_length = formatter1.width() + 2 + if formatter2 is not None: + # width for imag_formatter + an extra j for complex + element_length += formatter2.width() + 1 + + elements_per_line = max( + 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))) + ) + + def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): + if formatter2 is not None: + real_str = formatter1.format(val.real) + imag_str = (formatter2.format(val.imag) + "j").lstrip() + # handles negative numbers, +0.0, -0.0 + if imag_str[0] == "+" or imag_str[0] == "-": + return real_str + imag_str + else: + return real_str + "+" + imag_str + else: + return formatter1.format(val) + + if self.dtype == torch.float4_e2m1fn_x2: # type: ignore[attr-defined] + # torch.float4_e2m1fn_x2 is special and does not support the casts necessary + # to print it, we choose to display the uint8 representation here for + # convenience of being able to print a tensor. + # TODO(#146647): extend this to other dtypes without casts defined, such + # as the bits, uint1..7 and int1..7 dtypes. + self = self.view(torch.uint8) + + if summarize and not PRINT_OPTS.edgeitems: + # Deal with edge case that negative zero is zero + data = ["..."] + elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + data = ( + [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()] + + [" ..."] + + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()] + ) + else: + data = [_val_formatter(val) for val in self.tolist()] + + data_lines = [ + data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line) + ] + lines = [", ".join(line) for line in data_lines] + return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]" + + +# formatter2 is only used for printing complex tensors. 
+# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real +# and tensor.imag respesectively +def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None): + dim = self.dim() + + if dim == 0: + return _scalar_str(self, formatter1, formatter2) + + if dim == 1: + return _vector_str(self, indent, summarize, formatter1, formatter2) + + if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + slices = ( + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, PRINT_OPTS.edgeitems) + ] + + ["..."] + + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(len(self) - PRINT_OPTS.edgeitems, len(self)) + ] + ) + else: + slices = [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, self.size(0)) + ] + + tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices) + return "[" + tensor_str + "]" + + +def _tensor_str(self, indent): + if self.numel() == 0: + return "[]" + + if self.has_names(): + # There are two main codepaths (possibly more) that tensor printing goes through: + # - tensor data can fit comfortably on screen + # - tensor data needs to be summarized + # Some of the codepaths don't fully support named tensors, so we send in + # an unnamed tensor to the formatting code as a workaround. + self = self.rename(None) + + summarize = self.numel() > PRINT_OPTS.threshold + + if self._is_zerotensor(): + self = self.clone() + + # handle the negative bit + if self.is_neg(): + self = self.resolve_neg() + + # TODO: Remove me when `masked_select` is implemented for FP8 + if self.dtype in [ + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ]: + self = self.half() + + if self.dtype.is_complex: + # handle the conjugate bit + self = self.resolve_conj() + real_formatter = _Formatter( + get_summarized_data(self.real) if summarize else self.real + ) + imag_formatter = _Formatter( + get_summarized_data(self.imag) if summarize else self.imag + ) + return _tensor_str_with_formatter( + self, indent, summarize, real_formatter, imag_formatter + ) + else: + formatter = _Formatter(get_summarized_data(self) if summarize else self) + return _tensor_str_with_formatter(self, indent, summarize, formatter) + + +def _add_suffixes(tensor_str, suffixes, indent, force_newline): + tensor_strs = [tensor_str] + last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1 + for suffix in suffixes: + suffix_len = len(suffix) + if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth: + tensor_strs.append(",\n" + " " * indent + suffix) + last_line_len = indent + suffix_len + force_newline = False + else: + tensor_strs.append(", " + suffix) + last_line_len += suffix_len + 2 + tensor_strs.append(")") + return "".join(tensor_strs) + + +def get_summarized_data(self): + dim = self.dim() + if dim == 0: + return self + if dim == 1: + if self.size(0) > 2 * PRINT_OPTS.edgeitems: + return torch.cat( + (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :]) + ) + else: + return self + if not PRINT_OPTS.edgeitems: + return self.new_empty([0] * self.dim()) + elif self.size(0) > 2 * PRINT_OPTS.edgeitems: + start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] + end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))] + return torch.stack([get_summarized_data(x) for x in (start + end)]) + else: + return 
torch.stack([get_summarized_data(x) for x in self]) + + +def _str_intern(inp, *, tensor_contents=None): + if torch._C._functorch.is_functorch_wrapped_tensor(inp): + return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents) + is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter + if inp.is_nested: + prefix = "nested_tensor(" + elif is_plain_tensor: + prefix = "tensor(" + else: + prefix = f"{type(inp).__name__}(" + indent = len(prefix) + suffixes = [] + custom_contents_provided = tensor_contents is not None + if custom_contents_provided: + tensor_str = tensor_contents + + # This is used to extract the primal value and thus disable the forward AD + # within this function. + # TODO(albanD) This needs to be updated when more than one level is supported + self, tangent = torch.autograd.forward_ad.unpack_dual(inp) + + # Note [Print tensor device]: + # A general logic here is we only print device when it doesn't match + # the device specified in default tensor type. + # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus + # torch._C._get_default_device() only returns either cpu or cuda. + # In other cases, we don't have a way to set them as default yet, + # and we should always print out device for them. + if ( + self.device.type != torch._C._get_default_device() + or ( + self.device.type == "cuda" + and torch.cuda.current_device() != self.device.index + ) + or (self.device.type == "mps") + ): + suffixes.append("device='" + str(self.device) + "'") + + # Tensor printing performs tensor operations like slice, indexing, etc to make it in a + # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence, + # to avoid compilations, copying the tensor to cpu before printing. + if self.device.type in ["xla", "lazy", "ipu", "mtia"]: + self = self.to("cpu") + + # TODO: add an API to map real -> complex dtypes + _default_complex_dtype = ( + torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat + ) + has_default_dtype = self.dtype in ( + torch.get_default_dtype(), + _default_complex_dtype, + torch.int64, + torch.bool, + ) + if self.is_sparse: + suffixes.append("size=" + str(tuple(self.shape))) + from torch._subclasses.fake_tensor import FakeTensor + + is_meta = self.is_meta or isinstance(self, FakeTensor) + if not is_meta: + suffixes.append("nnz=" + str(self._nnz())) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + indices_prefix = "indices=tensor(" + indices = self._indices().detach() + if is_meta: + indices_str = "..." + else: + indices_str = _tensor_str(indices, indent + len(indices_prefix)) + if is_meta or indices.numel() == 0: + indices_str += ", size=" + str(tuple(indices.shape)) + values_prefix = "values=tensor(" + values = self._values().detach() + if is_meta: + values_str = "..." 
+ else: + values_str = _tensor_str(values, indent + len(values_prefix)) + if is_meta or values.numel() == 0: + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + indices_prefix + + indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + from torch._subclasses.fake_tensor import FakeTensor + + suffixes.append("size=" + str(tuple(self.shape))) + is_meta = self.is_meta or isinstance(self, FakeTensor) + if not is_meta: + suffixes.append("nnz=" + str(self._nnz())) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + compressed_indices_method, plain_indices_method = { + torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), + torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), + torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), + torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), + }[self.layout] + if self.layout in {torch.sparse_csr, torch.sparse_bsr}: + cdimname, pdimname = "row", "column" + else: + cdimname, pdimname = "column", "row" + compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor(" + compressed_indices = compressed_indices_method(self).detach() + if is_meta: + compressed_indices_str = "..." + else: + compressed_indices_str = _tensor_str( + compressed_indices, indent + len(compressed_indices_prefix) + ) + if compressed_indices.numel() == 0 or is_meta: + compressed_indices_str += ", size=" + str( + tuple(compressed_indices.shape) + ) + plain_indices_prefix = f"{pdimname[:3]}_indices=tensor(" + plain_indices = plain_indices_method(self).detach() + if is_meta: + plain_indices_str = "..." + else: + plain_indices_str = _tensor_str( + plain_indices, indent + len(plain_indices_prefix) + ) + if plain_indices.numel() == 0 or is_meta: + plain_indices_str += ", size=" + str(tuple(plain_indices.shape)) + values_prefix = "values=tensor(" + values = self.values().detach() + if is_meta: + values_str = "..." 
+ else: + values_str = _tensor_str(values, indent + len(values_prefix)) + if values.numel() == 0 or is_meta: + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + compressed_indices_prefix + + compressed_indices_str + + "),\n" + + " " * indent + + plain_indices_prefix + + plain_indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.is_quantized: + suffixes.append("size=" + str(tuple(self.shape))) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + suffixes.append("quantization_scheme=" + str(self.qscheme())) + if ( + self.qscheme() == torch.per_tensor_affine + or self.qscheme() == torch.per_tensor_symmetric + ): + suffixes.append("scale=" + str(self.q_scale())) + suffixes.append("zero_point=" + str(self.q_zero_point())) + elif ( + self.qscheme() == torch.per_channel_affine + or self.qscheme() == torch.per_channel_symmetric + or self.qscheme() == torch.per_channel_affine_float_qparams + ): + suffixes.append("scale=" + str(self.q_per_channel_scales())) + suffixes.append("zero_point=" + str(self.q_per_channel_zero_points())) + suffixes.append("axis=" + str(self.q_per_channel_axis())) + if not custom_contents_provided: + tensor_str = _tensor_str(self.dequantize(), indent) + elif self.is_nested: + if not custom_contents_provided: + + def indented_str(s, indent): + return "\n".join(f" {line}" for line in s.split("\n")) + + strs = ",\n".join( + indented_str(str(t), indent + 1) + for t in torch.ops.aten.unbind.int(self, 0) + ) + tensor_str = f"[\n{strs}\n]" + elif torch._is_functional_tensor(self): + prefix = "_to_functional_tensor(" + tensor_str = repr(torch._from_functional_tensor(self)) + else: + # Circular import problem, so we import it here + from torch._subclasses.fake_tensor import FakeTensor + + if self.is_meta or isinstance(self, FakeTensor): + suffixes.append("size=" + str(tuple(self.shape))) + if self.dtype != torch.get_default_dtype(): + suffixes.append("dtype=" + str(self.dtype)) + # TODO: This implies that ellipses is valid syntax for allocating + # a meta tensor or FakeTensor, which it could be, but it isn't right now + if not custom_contents_provided: + tensor_str = "..." + else: + if self.numel() == 0 and not self.is_sparse: + # Explicitly print the shape if it is not (0,), to match NumPy behavior + if self.dim() != 1: + suffixes.append("size=" + str(tuple(self.shape))) + + # In an empty tensor, there are no elements to infer if the dtype + # should be int64, so it must be shown explicitly. + if self.dtype != torch.get_default_dtype(): + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + tensor_str = "[]" + else: + if not PRINT_OPTS.edgeitems: + suffixes.append("size=" + str(tuple(self.shape))) + + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + + if not custom_contents_provided: + if self.layout != torch.strided: + tensor_str = _tensor_str(self.to_dense(), indent) + else: + tensor_str = _tensor_str(self, indent) + + if self.layout != torch.strided: + suffixes.append("layout=" + str(self.layout)) + + # Use inp here to get the original grad_fn and not the one generated by the forward grad + # unpacking. + grad_fn_name = None + try: + grad_fn = inp.grad_fn + except RuntimeError: + # Accessing the grad_fn calls rebasing logic which would cause an error + # if that tensor is a view created in no-grad mode modified in-place in + # no-grad mode. 
See: https://github.com/pytorch/pytorch/issues/99968 + grad_fn_name = "Invalid" + + if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined] + grad_fn_name = type(grad_fn).__name__ + if grad_fn_name == "CppFunction": + grad_fn_name = grad_fn.name().rsplit("::", 1)[-1] + + if grad_fn_name is not None: + suffixes.append(f"grad_fn=<{grad_fn_name}>") + elif inp.requires_grad: + suffixes.append("requires_grad=True") + + if self.has_names(): + suffixes.append(f"names={self.names}") + + if tangent is not None: + suffixes.append(f"tangent={tangent}") + + string_repr = _add_suffixes( + prefix + tensor_str, # type: ignore[possibly-undefined] + suffixes, + indent, + force_newline=self.is_sparse, + ) + + # Check if this instance is flagged as a parameter and change the repr accordingly. + # Unfortunately, this function has to be aware of this detail. + # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future, + # this should be done for those as well to produce a valid repr. + if isinstance(self, torch.nn.Parameter) and not is_plain_tensor: + string_repr = f"Parameter({string_repr})" + + return string_repr + + +def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None): + level = torch._C._functorch.maybe_get_level(tensor) + assert level != -1 + + if torch._C._functorch.is_functionaltensor(tensor): + # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure + # that it's up to date first + torch._sync(tensor) + + value = torch._C._functorch.get_unwrapped(tensor) + value_repr = repr(value) + + indented_value_repr = textwrap.indent(value_repr, " " * 4) + if torch._C._functorch.is_batchedtensor(tensor): + bdim = torch._C._functorch.maybe_get_bdim(tensor) + assert bdim != -1 + return ( + f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n{indented_value_repr}\n)" + ) + if torch._C._functorch.is_gradtrackingtensor(tensor): + return f"GradTrackingTensor(lvl={level}, value=\n{indented_value_repr}\n)" + if torch._C._functorch.is_functionaltensor(tensor): + return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})" + + raise ValueError("We don't know how to print this, please file us an issue") + + +def _str(self, *, tensor_contents=None): + with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes(): + guard = torch._C._DisableFuncTorch() # noqa: F841 + return _str_intern(self, tensor_contents=tensor_contents) diff --git a/phivenv/Lib/site-packages/torch/_thread_safe_fork.py b/phivenv/Lib/site-packages/torch/_thread_safe_fork.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_torch_docs.py b/phivenv/Lib/site-packages/torch/_torch_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..4a521c7fa3c6743d822cf00033d277e726760923 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_torch_docs.py @@ -0,0 +1,14140 @@ +# mypy: allow-untyped-defs +"""Adds docstrings to functions defined in the torch._C module.""" + +import re + +import torch._C +from torch._C import _add_docstr as add_docstr + + +def parse_kwargs(desc): + r"""Map a description of args to a dictionary of {argname: description}. 
+ + Input: + (' weight (Tensor): a weight tensor\n' + + ' Some optional description') + Output: { + 'weight': \ + 'weight (Tensor): a weight tensor\n Some optional description' + } + """ + # Split on exactly 4 spaces after a newline + regx = re.compile(r"\n\s{4}(?!\s)") + kwargs = [section.strip() for section in regx.split(desc)] + kwargs = [section for section in kwargs if len(section) > 0] + return {desc.split(" ")[0]: desc for desc in kwargs} + + +def merge_dicts(*dicts): + """Merge dictionaries into a single dictionary.""" + return {x: d[x] for d in dicts for x in d} + + +common_args = parse_kwargs( + """ + input (Tensor): the input tensor. + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned tensor. Default: ``torch.preserve_format``. +""" +) + +reduceops_common_args = merge_dicts( + common_args, + parse_kwargs( + """ + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. +""" + ), + { + "opt_keepdim": """ + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. +""" + }, +) + +multi_dim_common = merge_dicts( + reduceops_common_args, + parse_kwargs( + """ + dim (int or tuple of ints): the dimension or dimensions to reduce. +""" + ), + { + "keepdim_details": """ +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the +output tensor having 1 (or ``len(dim)``) fewer dimension(s). +""" + }, + { + "opt_dim": """ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. +""" + }, + { + "opt_dim_all_reduce": """ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. +""" + }, +) + +single_dim_common = merge_dicts( + reduceops_common_args, + parse_kwargs( + """ + dim (int): the dimension to reduce. +""" + ), + { + "opt_dim": """ + dim (int, optional): the dimension to reduce. +""" + }, + { + "opt_dim_all_reduce": """ + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. +""" + }, + { + "keepdim_details": """If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in +the output tensor having 1 fewer dimension than :attr:`input`.""" + }, +) + +factory_common_args = merge_dicts( + common_args, + parse_kwargs( + """ + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). 
:attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. +""" + ), + { + "sparse_factory_device_note": """\ +.. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor.""" + }, +) + +factory_like_common_args = parse_kwargs( + """ + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. +""" +) + +factory_data_common_args = parse_kwargs( + """ + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. +""" +) + +tf32_notes = { + "tf32_note": """This operator supports :ref:`TensorFloat32`.""" +} + +rocm_fp16_notes = { + "rocm_fp16_note": """On certain ROCm devices, when using float16 inputs this module will use \ +:ref:`different precision` for backward.""" +} + +reproducibility_notes: dict[str, str] = { + "forward_reproducibility_note": """This operation may behave nondeterministically when given tensors on \ +a CUDA device. 
See :doc:`/notes/randomness` for more information.""", + "backward_reproducibility_note": """This operation may produce nondeterministic gradients when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "cudnn_reproducibility_note": """In some circumstances when given tensors on a CUDA device \ +and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is \ +undesirable, you can try to make the operation deterministic (potentially at \ +a performance cost) by setting ``torch.backends.cudnn.deterministic = True``. \ +See :doc:`/notes/randomness` for more information.""", +} + +sparse_support_notes = { + "sparse_beta_warning": """ +.. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request.""", +} + +add_docstr( + torch.abs, + r""" +abs(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Computes the absolute value of each element in :attr:`input`. + +.. math:: + \text{out}_{i} = |\text{input}_{i}| +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.abs(torch.tensor([-1, -2, 3])) + tensor([ 1, 2, 3]) +""".format(**common_args), +) + +add_docstr( + torch.absolute, + r""" +absolute(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.abs` +""", +) + +add_docstr( + torch.acos, + r""" +acos(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Computes the inverse cosine of each element in :attr:`input`. + +.. math:: + \text{out}_{i} = \cos^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) + >>> torch.acos(a) + tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) +""".format(**common_args), +) + +add_docstr( + torch.arccos, + r""" +arccos(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.acos`. +""", +) + +add_docstr( + torch.acosh, + r""" +acosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) + +Note: + The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range + will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. +""" + + r""" +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.randn(4).uniform_(1, 2) + >>> a + tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) + >>> torch.acosh(a) + tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) +""".format(**common_args), +) + +add_docstr( + torch.arccosh, + r""" +arccosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.acosh`. +""", +) + +add_docstr( + torch.index_add, + r""" +index_add(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor]) -> Tensor # noqa: B950 + +See :meth:`~Tensor.index_add_` for function description. +""", +) + +add_docstr( + torch.index_copy, + r""" +index_copy(input: Tensor, dim: int, index: Tensor, source: Tensor, *, out: Optional[Tensor]) -> Tensor + +See :meth:`~Tensor.index_add_` for function description. 
+""", +) + +add_docstr( + torch.index_reduce, + r""" +index_reduce(input: Tensor, dim: int, index: Tensor, source: Tensor, reduce: str, *, include_self: bool = True, out: Optional[Tensor]) -> Tensor # noqa: B950 + +See :meth:`~Tensor.index_reduce_` for function description. +""", +) + +add_docstr( + torch.add, + r""" +add(input, other, *, alpha=1, out=None) -> Tensor + +Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + +.. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Number): the tensor or number to add to :attr:`input`. + +Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + {out} + +Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) +""".format(**common_args), +) + +add_docstr( + torch.addbmm, + r""" +addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a batch matrix-matrix product of matrices stored +in :attr:`batch1` and :attr:`batch2`, +with a reduced add step (all matrix multiplications get accumulated +along the first dimension). +:attr:`input` is added to the final result. + +:attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the +same number of matrices. + +If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a +:math:`(b \times m \times p)` tensor, :attr:`input` must be +:ref:`broadcastable ` with a :math:`(n \times p)` tensor +and :attr:`out` will be a :math:`(n \times p)` tensor. + +.. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + +If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` +must be real numbers, otherwise they should be integers. + +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) +""".format(**common_args, **tf32_notes, **rocm_fp16_notes), +) + +add_docstr( + torch.addcdiv, + r""" +addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + +Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, +multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + +.. 
warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + +.. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} +""" + + r""" + +The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be +:ref:`broadcastable `. + +For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be +a real number, otherwise an integer. + +Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + +Keyword args: + value (Number, optional): multiplier for :math:`\text{{tensor1}} / \text{{tensor2}}` + {out} + +Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) +""".format(**common_args), +) + +add_docstr( + torch.addcmul, + r""" +addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + +Performs the element-wise multiplication of :attr:`tensor1` +by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` +and adds it to :attr:`input`. + +.. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i +""" + + r""" +The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be +:ref:`broadcastable `. + +For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be +a real number, otherwise an integer. + +Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + +Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + {out} + +Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) +""".format(**common_args), +) + +add_docstr( + torch.addmm, + r""" +addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. +The matrix :attr:`input` is added to the final result. + +If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a +:math:`(m \times p)` tensor, then :attr:`input` must be +:ref:`broadcastable ` with a :math:`(n \times p)` tensor +and :attr:`out` will be a :math:`(n \times p)` tensor. + +:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between +:attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + +.. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + +If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and +:attr:`alpha` must be real numbers, otherwise they should be integers. 
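+
+For instance, with explicit scaling factors (a hand-picked sketch)::
+
+    >>> M = torch.ones(2, 2)
+    >>> A = torch.eye(2)
+    >>> B = torch.full((2, 2), 2.)
+    >>> torch.addmm(M, A, B, beta=0.5, alpha=3)
+    tensor([[6.5000, 6.5000],
+            [6.5000, 6.5000]])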
+ +This operation has support for arguments with :ref:`sparse layouts`. If +:attr:`input` is sparse the result will have the same layout and if :attr:`out` +is provided it must have the same layout as :attr:`input`. + +{sparse_beta_warning} + +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) +""".format(**common_args, **tf32_notes, **rocm_fp16_notes, **sparse_support_notes), +) + +add_docstr( + torch.adjoint, + r""" +adjoint(input: Tensor) -> Tensor +Returns a view of the tensor conjugated and with the last two dimensions transposed. + +``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and +to ``x.transpose(-2, -1)`` for real tensors. + +Args: + {input} + +Example:: + + >>> x = torch.arange(4, dtype=torch.float) + >>> A = torch.complex(x, x).reshape(2, 2) + >>> A + tensor([[0.+0.j, 1.+1.j], + [2.+2.j, 3.+3.j]]) + >>> A.adjoint() + tensor([[0.-0.j, 2.-2.j], + [1.-1.j, 3.-3.j]]) + >>> (A.adjoint() == A.mH).all() + tensor(True) +""", +) + +add_docstr( + torch.sspaddmm, + r""" +sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + +Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor +:attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + +Note: This function is equivalent to :func:`torch.addmm`, except +:attr:`input` and :attr:`mat1` are sparse. + +Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + {out} +""".format(**common_args), +) + +add_docstr( + torch.smm, + r""" +smm(input, mat) -> Tensor + +Performs a matrix multiplication of the sparse matrix :attr:`input` +with the dense matrix :attr:`mat`. + +Args: + input (Tensor): a sparse matrix to be matrix multiplied + mat (Tensor): a dense matrix to be matrix multiplied +""", +) + +add_docstr( + torch.addmv, + r""" +addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a matrix-vector product of the matrix :attr:`mat` and +the vector :attr:`vec`. +The vector :attr:`input` is added to the final result. + +If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of +size `m`, then :attr:`input` must be +:ref:`broadcastable ` with a 1-D tensor of size `n` and +:attr:`out` will be 1-D tensor of size `n`. + +:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between +:attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + +.. 
math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + +If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and +:attr:`alpha` must be real numbers, otherwise they should be integers. + +Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) +""".format(**common_args), +) + +add_docstr( + torch.addr, + r""" +addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + +Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` +and adds it to the matrix :attr:`input`. + +Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the +outer product between :attr:`vec1` and :attr:`vec2` and the added matrix +:attr:`input` respectively. + +.. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + +If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector +of size `m`, then :attr:`input` must be +:ref:`broadcastable ` with a matrix of size +:math:`(n \times m)` and :attr:`out` will be a matrix of size +:math:`(n \times m)`. + +Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{{vec1}} \otimes \text{{vec2}}` (:math:`\alpha`) + {out} + +Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) +""".format(**common_args), +) + +add_docstr( + torch.allclose, + r""" +allclose(input: Tensor, other: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool + +This function checks if :attr:`input` and :attr:`other` satisfy the condition: + +.. math:: + \lvert \text{input}_i - \text{other}_i \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other}_i \rvert +""" + + r""" +elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to +`numpy.allclose `_ + +Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + +Example:: + + >>> torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08])) + False + >>> torch.allclose(torch.tensor([10000., 1e-08]), torch.tensor([10000.1, 1e-09])) + True + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')])) + False + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) + True +""", +) + +add_docstr( + torch.all, + r""" +all(input: Tensor, *, out=None) -> Tensor + +Tests if all elements in :attr:`input` evaluate to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + +.. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) +""".format(**multi_dim_common), +) + +add_docstr( + torch.any, + r""" +any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Tests if any element in :attr:`input` evaluates to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + +.. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if any element in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) +""".format(**multi_dim_common), +) + +add_docstr( + torch.angle, + r""" +angle(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + +.. math:: + \text{out}_{i} = angle(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +.. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. 
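+
+For instance, on real inputs (an illustrative sketch)::
+
+    >>> torch.angle(torch.tensor([-1.0, 0.0, 2.0]))
+    tensor([3.1416, 0.0000, 0.0000])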
+ +Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, -45]) +""".format(**common_args), +) + +add_docstr( + torch.as_strided, + r""" +as_strided(input, size, stride, storage_offset=None) -> Tensor + +Create a view of an existing `torch.Tensor` :attr:`input` with specified +:attr:`size`, :attr:`stride` and :attr:`storage_offset`. + +.. warning:: + Prefer using other view functions, like :meth:`torch.Tensor.view` or + :meth:`torch.Tensor.expand`, to setting a view's strides manually with + `as_strided`, as this function will throw an error on non-standard Pytorch + backends (that do not have a concept of stride) and the result will depend + on the current layout in memory. The constructed view must only refer to + elements within the Tensor's storage or a runtime error will be thrown. + If the generated view is "overlapped" (with multiple indices referring to + the same element in memory), the behavior of inplace operations on this view + is undefined (and might not throw runtime errors). + +Args: + {input} + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. + +Example:: + + >>> x = torch.randn(3, 3) + >>> x + tensor([[ 0.9039, 0.6291, 1.0795], + [ 0.1586, 2.1939, -0.4900], + [-0.1909, -0.7503, 1.9355]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2)) + >>> t + tensor([[0.9039, 1.0795], + [0.6291, 0.1586]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) + tensor([[0.6291, 0.1586], + [1.0795, 2.1939]]) +""".format(**common_args), +) + +add_docstr( + torch.as_tensor, + r""" +as_tensor(data: Any, dtype: Optional[dtype] = None, device: Optional[DeviceLikeType]) -> Tensor + +Converts :attr:`data` into a tensor, sharing data and preserving autograd +history if possible. + +If :attr:`data` is already a tensor with the requested dtype and device +then :attr:`data` itself is returned, but if :attr:`data` is a +tensor with a different dtype or device then it's copied as if using +`data.to(dtype=dtype, device=device)`. + +If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device then a +tensor is constructed using :func:`torch.from_numpy`. + +If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless +specifically overwritten by :attr:`device` or a default device. + +.. seealso:: + + :func:`torch.tensor` never shares its data and creates a new "leaf tensor" (see :doc:`/notes/autograd`). + + +Args: + {data} + {dtype} + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + + +Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) +""".format(**factory_data_common_args), +) + +add_docstr( + torch.asin, + r""" +asin(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the arcsine of the elements of :attr:`input`. + +.. 
math:: + \text{out}_{i} = \sin^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5962, 1.4985, -0.4396, 1.4525]) + >>> torch.asin(a) + tensor([-0.6387, nan, -0.4552, nan]) +""".format(**common_args), +) + +add_docstr( + torch.arcsin, + r""" +arcsin(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.asin`. +""", +) + +add_docstr( + torch.asinh, + r""" +asinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) + >>> torch.asinh(a) + tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) +""".format(**common_args), +) + +add_docstr( + torch.arcsinh, + r""" +arcsinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.asinh`. +""", +) + +add_docstr( + torch.atan, + r""" +atan(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the arctangent of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \tan^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) + >>> torch.atan(a) + tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) +""".format(**common_args), +) + +add_docstr( + torch.arctan, + r""" +arctan(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.atan`. +""", +) + +add_docstr( + torch.atan2, + r""" +atan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + +Element-wise arctangent of :math:`\text{{input}}_{{i}} / \text{{other}}_{{i}}` +with consideration of the quadrant. Returns a new tensor with the signed angles +in radians between vector :math:`(\text{{other}}_{{i}}, \text{{input}}_{{i}})` +and vector :math:`(1, 0)`. (Note that :math:`\text{{other}}_{{i}}`, the second +parameter, is the x-coordinate, while :math:`\text{{input}}_{{i}}`, the first +parameter, is the y-coordinate.) + +The shapes of ``input`` and ``other`` must be +:ref:`broadcastable `. + +Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) + >>> torch.atan2(a, torch.randn(4)) + tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) +""".format(**common_args), +) + +add_docstr( + torch.arctan2, + r""" +arctan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor +Alias for :func:`torch.atan2`. +""", +) + +add_docstr( + torch.atanh, + r""" +atanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + +Note: + The domain of the inverse hyperbolic tangent is `(-1, 1)` and values outside this range + will be mapped to ``NaN``, except for the values `1` and `-1` for which the output is + mapped to `+/-INF` respectively. + +.. 
math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.randn(4).uniform_(-1, 1) + >>> a + tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) + >>> torch.atanh(a) + tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) +""".format(**common_args), +) + +add_docstr( + torch.arctanh, + r""" +arctanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Alias for :func:`torch.atanh`. +""", +) + +add_docstr( + torch.asarray, + r""" +asarray(obj: Any, *, dtype: Optional[dtype], device: Optional[DeviceLikeType], copy: Optional[bool] = None, requires_grad: bool = False) -> Tensor # noqa: B950 + +Converts :attr:`obj` to a tensor. + +:attr:`obj` can be one of: + +1. a tensor +2. a NumPy array or a NumPy scalar +3. a DLPack capsule +4. an object that implements Python's buffer protocol +5. a scalar +6. a sequence of scalars + +When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, +by default, not require a gradient, have the same datatype as :attr:`obj`, be on the +same device, and share memory with it. These properties can be controlled with the +:attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. +If the returned tensor is of a different datatype, on a different device, or a copy is +requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` +is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is +also a tensor with an autograd history then the returned tensor will have the same history. + +When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's +buffer protocol then the buffer is interpreted as an array of bytes grouped according to +the size of the datatype passed to the :attr:`dtype` keyword argument. (If no datatype is +passed then the default floating point datatype is used, instead.) The returned tensor +will have the specified datatype (or default floating point datatype if none is specified) +and, by default, be on the CPU device and share memory with the buffer. + +When :attr:`obj` is a NumPy scalar, the returned tensor will be a 0-dimensional tensor on +the CPU and that doesn't share its memory (i.e. ``copy=True``). By default datatype will +be the PyTorch datatype corresponding to the NumPy's scalar's datatype. + +When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the +returned tensor will, by default, infer its datatype from the scalar values, be on the +current default device, and not share its memory. + +.. seealso:: + + :func:`torch.tensor` creates a tensor that always copies the data from the input object. + :func:`torch.from_numpy` creates a tensor that always shares memory from NumPy arrays. + :func:`torch.frombuffer` creates a tensor that always shares memory from objects that + implement the buffer protocol. + :func:`torch.from_dlpack` creates a tensor that always shares memory from + DLPack capsules. + +Args: + obj (object): a tensor, NumPy array, DLPack Capsule, object that implements Python's + buffer protocol, scalar, or sequence of scalars. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the datatype of the returned tensor. + Default: ``None``, which causes the datatype of the returned tensor to be + inferred from :attr:`obj`. + copy (bool, optional): controls whether the returned tensor shares memory with :attr:`obj`. 
+ Default: ``None``, which causes the returned tensor to share memory with :attr:`obj` + whenever possible. If ``True`` then the returned tensor does not share its memory. + If ``False`` then the returned tensor shares its memory with :attr:`obj` and an + error is thrown if it cannot. + device (:class:`torch.device`, optional): the device of the returned tensor. + Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if + :attr:`obj` is a Python sequence, the current default device will be used. + requires_grad (bool, optional): whether the returned tensor requires grad. + Default: ``False``, which causes the returned tensor not to require a gradient. + If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` + is also a tensor with an autograd history then the returned tensor will have + the same history. + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> # Shares memory with tensor 'a' + >>> b = torch.asarray(a) + >>> a.data_ptr() == b.data_ptr() + True + >>> # Forces memory copy + >>> c = torch.asarray(a, copy=True) + >>> a.data_ptr() == c.data_ptr() + False + + >>> a = torch.tensor([1., 2., 3.], requires_grad=True) + >>> b = a + 2 + >>> b + tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', with no grad + >>> c = torch.asarray(b) + >>> c + tensor([3., 4., 5.]) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> d = torch.asarray(b, requires_grad=True) + >>> d + tensor([3., 4., 5.], grad_fn=) + + >>> array = numpy.array([1, 2, 3]) + >>> # Shares memory with array 'array' + >>> t1 = torch.asarray(array) + >>> array.__array_interface__['data'][0] == t1.data_ptr() + True + >>> # Copies memory due to dtype mismatch + >>> t2 = torch.asarray(array, dtype=torch.float32) + >>> array.__array_interface__['data'][0] == t2.data_ptr() + False + + >>> scalar = numpy.float64(0.5) + >>> torch.asarray(scalar) + tensor(0.5000, dtype=torch.float64) +""", +) + +add_docstr( + torch.baddbmm, + r""" +baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a batch matrix-matrix product of matrices in :attr:`batch1` +and :attr:`batch2`. +:attr:`input` is added to the final result. + +:attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same +number of matrices. + +If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a +:math:`(b \times m \times p)` tensor, then :attr:`input` must be +:ref:`broadcastable ` with a +:math:`(b \times n \times p)` tensor and :attr:`out` will be a +:math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the +same as the scaling factors used in :meth:`torch.addbmm`. + +.. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + +If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and +:attr:`alpha` must be real numbers, otherwise they should be integers. 
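As a minimal sanity check of the formula above (an illustrative sketch, not
part of the upstream docstring; it assumes default CPU float32 inputs, where
the comparison falls within :func:`torch.allclose` tolerances)::

    >>> M = torch.randn(10, 3, 5)
    >>> b1 = torch.randn(10, 3, 4)
    >>> b2 = torch.randn(10, 4, 5)
    >>> torch.allclose(torch.baddbmm(M, b1, b2, beta=2.0, alpha=3.0),
    ...                2.0 * M + 3.0 * torch.bmm(b1, b2))
    True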
+ +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{{batch1}} \mathbin{{@}} \text{{batch2}}` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) +""".format(**common_args, **tf32_notes, **rocm_fp16_notes), +) + +add_docstr( + torch.bernoulli, + r""" +bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) -> Tensor + +Draws binary random numbers (0 or 1) from a Bernoulli distribution. + +The :attr:`input` tensor should be a tensor containing probabilities +to be used for drawing the binary random number. +Hence, all values in :attr:`input` have to be in the range: +:math:`0 \leq \text{input}_i \leq 1`. + +The :math:`\text{i}^{th}` element of the output tensor will draw a +value :math:`1` according to the :math:`\text{i}^{th}` probability value given +in :attr:`input`. + +.. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) +""" + + r""" +The returned :attr:`out` tensor only has values 0 or 1 and is of the same +shape as :attr:`input`. + +:attr:`out` can have integral ``dtype``, but :attr:`input` must have floating +point ``dtype``. + +Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + +Keyword args: + {generator} + {out} + +Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) +""".format(**common_args), +) + +add_docstr( + torch.bincount, + r""" +bincount(input, weights=None, minlength=0) -> Tensor + +Count the frequency of each value in an array of non-negative ints. + +The number of bins (size 1) is one larger than the largest value in +:attr:`input` unless :attr:`input` is empty, in which case the result is a +tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least +:attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size +:attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``, +``out[n] += weights[i]`` if :attr:`weights` is specified else +``out[n] += 1``. + +Note: + {backward_reproducibility_note} + +Arguments: + input (Tensor): 1-d int tensor + weights (Tensor): optional, weight for each value in the input tensor. + Should be of same size as input tensor. + minlength (int): optional, minimum number of bins. Should be non-negative. 
+ +Returns: + output (Tensor): a tensor of shape ``Size([max(input) + 1])`` if + :attr:`input` is non-empty, else ``Size(0)`` + +Example:: + + >>> input = torch.randint(0, 8, (5,), dtype=torch.int64) + >>> weights = torch.linspace(0, 1, steps=5) + >>> input, weights + (tensor([4, 3, 6, 3, 4]), + tensor([ 0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + >>> torch.bincount(input) + tensor([0, 0, 0, 2, 2, 0, 1]) + + >>> input.bincount(weights) + tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) +""".format(**reproducibility_notes), +) + +add_docstr( + torch.bitwise_not, + r""" +bitwise_not(input, *, out=None) -> Tensor + +Computes the bitwise NOT of the given input tensor. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical NOT. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) + tensor([ 0, 1, -4], dtype=torch.int8) +""".format(**common_args), +) + +add_docstr( + torch.bmm, + r""" +bmm(input, mat2, out_dtype=None, *, out=None) -> Tensor + +Performs a batch matrix-matrix product of matrices stored in :attr:`input` +and :attr:`mat2`. + +:attr:`input` and :attr:`mat2` must be 3-D tensors each containing +the same number of matrices. + +If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a +:math:`(b \times m \times p)` tensor, :attr:`out` will be a +:math:`(b \times n \times p)` tensor. + +.. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i +""" + + r""" +{tf32_note} + +{rocm_fp16_note} + +.. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + +Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + +Keyword Args: + {out} + +Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) +""".format(**common_args, **tf32_notes, **rocm_fp16_notes), +) + +add_docstr( + torch.bitwise_and, + r""" +bitwise_and(input, other, *, out=None) -> Tensor + +Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical AND. + +Args: + input: the first input tensor + other: the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.bitwise_or, + r""" +bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + +Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical OR. 
+ +Args: + input: the first input tensor + other: the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.bitwise_xor, + r""" +bitwise_xor(input, other, *, out=None) -> Tensor + +Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical XOR. + +Args: + input: the first input tensor + other: the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) +""".format(**common_args), +) + +add_docstr( + torch.bitwise_left_shift, + r""" +bitwise_left_shift(input, other, *, out=None) -> Tensor + +Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. +The input tensor must be of integral type. This operator supports +:ref:`broadcasting to a common shape ` and +:ref:`type promotion `. + +The operation applied is: + +.. math:: + \text{{out}}_i = \text{{input}}_i << \text{{other}}_i + +Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) +""".format(**common_args), +) + +add_docstr( + torch.bitwise_right_shift, + r""" +bitwise_right_shift(input, other, *, out=None) -> Tensor + +Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. +The input tensor must be of integral type. This operator supports +:ref:`broadcasting to a common shape ` and +:ref:`type promotion `. +In any case, if the value of the right operand is negative or is greater +or equal to the number of bits in the promoted left operand, the behavior is undefined. + +The operation applied is: + +.. math:: + \text{{out}}_i = \text{{input}}_i >> \text{{other}}_i + +Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) +""".format(**common_args), +) + +add_docstr( + torch.broadcast_to, + r""" +broadcast_to(input, shape) -> Tensor + +Broadcasts :attr:`input` to the shape :attr:`\shape`. +Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + +Args: + {input} + shape (list, tuple, or :class:`torch.Size`): the new shape. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) +""".format(**common_args), +) + +add_docstr( + torch.stack, + r""" +stack(tensors, dim=0, *, out=None) -> Tensor + +Concatenates a sequence of tensors along a new dimension. + +All tensors need to be of the same size. + +.. seealso:: + + :func:`torch.cat` concatenates the given sequence along an existing dimension. 
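The difference from :func:`torch.cat` is easiest to see through shapes (an
illustrative sketch, not part of the upstream example below)::

    >>> x = torch.zeros(2, 3)
    >>> torch.stack((x, x)).shape
    torch.Size([2, 2, 3])
    >>> torch.cat((x, x)).shape
    torch.Size([4, 3])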
+ +Arguments: + tensors (sequence of Tensors): sequence of tensors to concatenate + dim (int, optional): dimension to insert. Has to be between 0 and the number + of dimensions of concatenated tensors (inclusive). Default: 0 + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]) + >>> torch.stack((x, x)) # same as torch.stack((x, x), dim=0) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]], + + [[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]]) + >>> torch.stack((x, x)).size() + torch.Size([2, 2, 3]) + >>> torch.stack((x, x), dim=1) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.3367, 0.1288, 0.2345]], + + [[ 0.2303, -1.1229, -0.1863], + [ 0.2303, -1.1229, -0.1863]]]) + >>> torch.stack((x, x), dim=2) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + >>> torch.stack((x, x), dim=-1) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) +""".format(**common_args), +) + +add_docstr( + torch.hstack, + r""" +hstack(tensors, *, out=None) -> Tensor + +Stack tensors in sequence horizontally (column wise). + +This is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.hstack((a,b)) + tensor([1, 2, 3, 4, 5, 6]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.hstack((a,b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + +""".format(**common_args), +) + +add_docstr( + torch.vstack, + r""" +vstack(tensors, *, out=None) -> Tensor + +Stack tensors in sequence vertically (row wise). + +This is equivalent to concatenation along the first axis after all 1-D tensors have been reshaped by :func:`torch.atleast_2d`. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.vstack((a,b)) + tensor([[1, 2, 3], + [4, 5, 6]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.vstack((a,b)) + tensor([[1], + [2], + [3], + [4], + [5], + [6]]) + + +""".format(**common_args), +) + +add_docstr( + torch.dstack, + r""" +dstack(tensors, *, out=None) -> Tensor + +Stack tensors in sequence depthwise (along third axis). + +This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by :func:`torch.atleast_3d`. 
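The equivalence stated above can be checked directly (an illustrative sketch,
not part of the upstream example below)::

    >>> a = torch.tensor([1, 2, 3])
    >>> b = torch.tensor([4, 5, 6])
    >>> torch.equal(torch.dstack((a, b)),
    ...             torch.cat(torch.atleast_3d((a, b)), dim=2))
    True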
+ +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.dstack((a,b)) + tensor([[[1, 4], + [2, 5], + [3, 6]]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.dstack((a,b)) + tensor([[[1, 4]], + [[2, 5]], + [[3, 6]]]) + + +""".format(**common_args), +) + +add_docstr( + torch.tensor_split, + r""" +tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + +Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, +along dimension :attr:`dim` according to the indices or number of sections specified +by :attr:`indices_or_sections`. This function is based on NumPy's +:func:`numpy.array_split`. + +Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + +Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) +""", +) + +add_docstr( + torch.chunk, + r""" +chunk(input: Tensor, chunks: int, dim: int = 0) -> Tuple[Tensor, ...] + +Attempts to split a tensor into the specified number of chunks. Each chunk is a view of +the input tensor. + + +.. note:: + + This function may return fewer than the specified number of chunks! + +.. seealso:: + + :func:`torch.tensor_split` a function that always returns exactly the specified number of chunks + +If the tensor size along the given dimension :attr:`dim` is divisible by :attr:`chunks`, +all returned chunks will be the same size. +If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`, +all returned chunks will be the same size, except the last one. 
+If such division is not possible, this function may return fewer +than the specified number of chunks. + +Arguments: + input (Tensor): the tensor to split + chunks (int): number of chunks to return + dim (int): dimension along which to split the tensor + +Example: + >>> torch.arange(11).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10])) + >>> torch.arange(12).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10, 11])) + >>> torch.arange(13).chunk(6) + (tensor([0, 1, 2]), + tensor([3, 4, 5]), + tensor([6, 7, 8]), + tensor([ 9, 10, 11]), + tensor([12])) +""", +) + +add_docstr( + torch.unsafe_chunk, + r""" +unsafe_chunk(input, chunks, dim=0) -> List of Tensors + +Works like :func:`torch.chunk` but without enforcing the autograd restrictions +on inplace modification of the outputs. + +.. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. +""", +) + +add_docstr( + torch.unsafe_split, + r""" +unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors + +Works like :func:`torch.split` but without enforcing the autograd restrictions +on inplace modification of the outputs. + +.. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. +""", +) + +add_docstr( + torch.hsplit, + r""" +hsplit(input, indices_or_sections) -> List of Tensors + +Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors +horizontally according to :attr:`indices_or_sections`. Each split is a view of +:attr:`input`. + +If :attr:`input` is one dimensional this is equivalent to calling +torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is +zero), and if :attr:`input` has two or more dimensions it's equivalent to calling +torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), +except that if :attr:`indices_or_sections` is an integer it must evenly divide +the split dimension or a runtime error will be thrown. + +This function is based on NumPy's :func:`numpy.hsplit`. + +Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + +Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + +""", +) + +add_docstr( + torch.vsplit, + r""" +vsplit(input, indices_or_sections) -> List of Tensors + +Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors +vertically according to :attr:`indices_or_sections`. 
Each split is a view of +:attr:`input`. + +This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) +(the split dimension is 0), except that if :attr:`indices_or_sections` is an integer +it must evenly divide the split dimension or a runtime error will be thrown. + +This function is based on NumPy's :func:`numpy.vsplit`. + +Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + +Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + +""", +) + +add_docstr( + torch.dsplit, + r""" +dsplit(input, indices_or_sections) -> List of Tensors + +Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors +depthwise according to :attr:`indices_or_sections`. Each split is a view of +:attr:`input`. + +This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) +(the split dimension is 2), except that if :attr:`indices_or_sections` is an integer +it must evenly divide the split dimension or a runtime error will be thrown. + +This function is based on NumPy's :func:`numpy.dsplit`. + +Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + +Example:: + + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + +""", +) + +add_docstr( + torch.can_cast, + r""" +can_cast(from_, to) -> bool + +Determines if a type conversion is allowed under PyTorch casting rules +described in the type promotion :ref:`documentation `. + +Args: + from\_ (dtype): The original :class:`torch.dtype`. + to (dtype): The target :class:`torch.dtype`. + +Example:: + + >>> torch.can_cast(torch.double, torch.float) + True + >>> torch.can_cast(torch.float, torch.int) + False +""", +) + +add_docstr( + torch.corrcoef, + r""" +corrcoef(input) -> Tensor + +Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, +where rows are the variables and columns are the observations. + +.. note:: + + The correlation coefficient matrix R is computed using the covariance matrix C as given by + :math:`R_{ij} = \frac{ C_{ij} } { \sqrt{ C_{ii} * C_{jj} } }` + +.. note:: + + Due to floating point rounding, the resulting array may not be Hermitian and its diagonal elements may not be 1. + The real and imaginary values are clipped to the interval [-1, 1] in an attempt to improve this situation. + +Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. 
+ +Returns: + (Tensor) The correlation coefficient matrix of the variables. + +.. seealso:: + + :func:`torch.cov` covariance matrix. + +Example:: + + >>> x = torch.tensor([[0, 1, 2], [2, 1, 0]]) + >>> torch.corrcoef(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> x = torch.randn(2, 4) + >>> x + tensor([[-0.2678, -0.0908, -0.3766, 0.2780], + [-0.5812, 0.1535, 0.2387, 0.2350]]) + >>> torch.corrcoef(x) + tensor([[1.0000, 0.3582], + [0.3582, 1.0000]]) + >>> torch.corrcoef(x[0]) + tensor(1.) +""", +) + +add_docstr( + torch.cov, + r""" +cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor + +Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are +the variables and columns are the observations. + +A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains +the variance of each variable (covariance of a variable with itself). By definition, if :attr:`input` represents +a single variable (Scalar or 1D) then its variance is returned. + +The sample covariance of the variables :math:`x` and :math:`y` is given by: + +.. math:: + \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{\max(0,~N~-~\delta N)} + +where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of the :math:`x` and :math:`y` respectively, and +:math:`\delta N` is the :attr:`correction`. + +If :attr:`fweights` and/or :attr:`aweights` are provided, the weighted covariance +is calculated, which is given by: + +.. math:: + \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)} + {\max(0,~\sum^{N}_{i = 1}w_i~-~\frac{\sum^{N}_{i = 1}w_ia_i}{\sum^{N}_{i = 1}w_i}~\delta N)} + +where :math:`w` denotes :attr:`fweights` or :attr:`aweights` (``f`` and ``a`` for brevity) based on whichever is +provided, or :math:`w = f \times a` if both are provided, and +:math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable. If not +provided, ``f`` and/or ``a`` can be seen as a :math:`\mathbb{1}` vector of appropriate size. + +Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + +Keyword Args: + correction (int, optional): difference between the sample size and sample degrees of freedom. + Defaults to Bessel's correction, ``correction = 1`` which returns the unbiased estimate, + even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0`` + will return the simple average. Defaults to ``1``. + fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the number of + times each observation should be repeated. Its numel must equal the number of columns of :attr:`input`. + Must have integral dtype. Ignored if ``None``. Defaults to ``None``. + aweights (tensor, optional): A Scalar or 1D array of observation vector weights. + These relative weights are typically large for observations considered "important" and smaller for + observations considered less "important". Its numel must equal the number of columns of :attr:`input`. + Must have floating point dtype. Ignored if ``None``. Defaults to ``None``. + +Returns: + (Tensor) The covariance matrix of the variables. + +.. seealso:: + + :func:`torch.corrcoef` normalized covariance matrix. 
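Since a 1D :attr:`input` is treated as a single variable, :func:`torch.cov`
then reduces to the (corrected) sample variance. A tiny deterministic check
(illustrative, not part of the upstream example below)::

    >>> torch.cov(torch.tensor([0., 1., 2.]))
    tensor(1.)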
+ +Example:: + + >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T + >>> x + tensor([[0, 1, 2], + [2, 1, 0]]) + >>> torch.cov(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> torch.cov(x, correction=0) + tensor([[ 0.6667, -0.6667], + [-0.6667, 0.6667]]) + >>> fw = torch.randint(1, 10, (3,)) + >>> fw + tensor([1, 6, 9]) + >>> aw = torch.rand(3) + >>> aw + tensor([0.4282, 0.0255, 0.4144]) + >>> torch.cov(x, fweights=fw, aweights=aw) + tensor([[ 0.4169, -0.4169], + [-0.4169, 0.4169]]) +""", +) + +add_docstr( + torch.cat, + r""" +cat(tensors, dim=0, *, out=None) -> Tensor + +Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. +All tensors must either have the same shape (except in the concatenating +dimension) or be a 1-D empty tensor with size ``(0,)``. + +:func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` +and :func:`torch.chunk`. + +:func:`torch.cat` can be best understood via examples. + +.. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + +Args: + tensors (sequence of Tensors): Non-empty tensors provided must have the same shape, + except in the cat dimension. + + dim (int, optional): the dimension over which the tensors are concatenated + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) +""".format(**common_args), +) + +add_docstr( + torch.concat, + r""" +concat(tensors, dim=0, *, out=None) -> Tensor + +Alias of :func:`torch.cat`. +""", +) + +add_docstr( + torch.concatenate, + r""" +concatenate(tensors, axis=0, out=None) -> Tensor + +Alias of :func:`torch.cat`. +""", +) + +add_docstr( + torch.ceil, + r""" +ceil(input, *, out=None) -> Tensor + +Returns a new tensor with the ceil of the elements of :attr:`input`, +the smallest integer greater than or equal to each element. + +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. + +.. math:: + \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + >>> torch.ceil(a) + tensor([-0., -1., -1., 1.]) +""".format(**common_args), +) + +add_docstr( + torch.real, + r""" +real(input) -> Tensor + +Returns a new tensor containing real values of the :attr:`self` tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + +""".format(**common_args), +) + +add_docstr( + torch.imag, + r""" +imag(input) -> Tensor + +Returns a new tensor containing imaginary values of the :attr:`self` tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +.. warning:: + :func:`imag` is only supported for tensors with complex dtypes. 
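As a small deterministic illustration (added here, not part of the upstream
example below)::

    >>> torch.tensor([1 + 2j]).imag
    tensor([2.])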
+ +Args: + {input} + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + +""".format(**common_args), +) + +add_docstr( + torch.view_as_real, + r""" +view_as_real(input) -> Tensor + +Returns a view of :attr:`input` as a real tensor. For an input complex tensor of +:attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new +real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2 +represents the real and imaginary components of complex numbers. + +.. warning:: + :func:`view_as_real` is only supported for tensors with ``complex dtypes``. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)]) + >>> torch.view_as_real(x) + tensor([[ 0.4737, -0.3839], + [-0.2098, -0.6699], + [ 0.3470, -0.9451], + [-0.5174, -1.3136]]) +""".format(**common_args), +) + +add_docstr( + torch.view_as_complex, + r""" +view_as_complex(input) -> Tensor + +Returns a view of :attr:`input` as a complex tensor. For an input complex +tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a +new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last +dimension of the input tensor is expected to represent the real and imaginary +components of complex numbers. + +.. warning:: + :func:`view_as_complex` is only supported for tensors with + :class:`torch.dtype` ``torch.float64`` and ``torch.float32``. The input is + expected to have the last dimension of :attr:`size` 2. In addition, the + tensor must have a `stride` of 1 for its last dimension. The strides of all + other dimensions must be even numbers. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, 2) + >>> x + tensor([[ 1.6116, -0.5772], + [-1.4606, -0.9120], + [ 0.0786, -1.7497], + [-0.6561, -1.6623]]) + >>> torch.view_as_complex(x) + tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) +""".format(**common_args), +) + +add_docstr( + torch.reciprocal, + r""" +reciprocal(input, *, out=None) -> Tensor + +Returns a new tensor with the reciprocal of the elements of :attr:`input` + +.. math:: + \text{out}_{i} = \frac{1}{\text{input}_{i}} + +.. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.4595, -2.1219, -1.4314, 0.7298]) + >>> torch.reciprocal(a) + tensor([-2.1763, -0.4713, -0.6986, 1.3702]) +""".format(**common_args), +) + +add_docstr( + torch.cholesky, + r""" +cholesky(input, upper=False, *, out=None) -> Tensor + +Computes the Cholesky decomposition of a symmetric positive-definite +matrix :math:`A` or for batches of symmetric positive-definite matrices. + +If :attr:`upper` is ``True``, the returned matrix ``U`` is upper-triangular, and +the decomposition has the form: + +.. math:: + + A = U^TU + +If :attr:`upper` is ``False``, the returned matrix ``L`` is lower-triangular, and +the decomposition has the form: + +.. math:: + + A = LL^T + +If :attr:`upper` is ``True``, and :math:`A` is a batch of symmetric positive-definite +matrices, then the returned tensor will be composed of upper-triangular Cholesky factors +of each of the individual matrices. 
Similarly, when :attr:`upper` is ``False``, the returned +tensor will be composed of lower-triangular Cholesky factors of each of the individual +matrices. + +.. warning:: + + :func:`torch.cholesky` is deprecated in favor of :func:`torch.linalg.cholesky` + and will be removed in a future PyTorch release. + + ``L = torch.cholesky(A)`` should be replaced with + + .. code:: python + + L = torch.linalg.cholesky(A) + + ``U = torch.cholesky(A, upper=True)`` should be replaced with + + .. code:: python + + U = torch.linalg.cholesky(A).mH + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. + +Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. + upper (bool, optional): flag that indicates whether to return a + upper or lower triangular matrix. Default: ``False`` + +Keyword args: + out (Tensor, optional): the output matrix + +Example:: + + >>> a = torch.randn(3, 3) + >>> a = a @ a.mT + 1e-3 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> l @ l.mT + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> a = torch.randn(3, 2, 2) # Example for batched input + >>> a = a @ a.mT + 1e-03 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> z = l @ l.mT + >>> torch.dist(z, a) + tensor(2.3842e-07) +""", +) + +add_docstr( + torch.cholesky_solve, + r""" +cholesky_solve(B, L, upper=False, *, out=None) -> Tensor + +Computes the solution of a system of linear equations with complex Hermitian +or real symmetric positive-definite lhs given its Cholesky decomposition. + +Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, +and :math:`L` its Cholesky decomposition such that: + +.. math:: + + A = LL^{\text{H}} + +where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, +and the transpose when :math:`L` is real-valued. + +Returns the solution :math:`X` of the following linear system: + +.. math:: + + AX = B + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices +then the output has the same batch dimensions. + +Args: + B (Tensor): right-hand side tensor of shape `(*, n, k)` + where :math:`*` is zero or more batch dimensions + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False``. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ +Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> B = torch.randn(3, 2) + >>> torch.cholesky_solve(B, L) + tensor([[ -8.1625, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + >>> A.inverse() @ B + tensor([[ -8.1626, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> B = torch.randn(2, 1, dtype=torch.complex64) + >>> X = torch.cholesky_solve(B, L) + >>> torch.dist(X, A.inverse() @ B) + tensor(1.6881e-5) +""", +) + +add_docstr( + torch.cholesky_inverse, + r""" +cholesky_inverse(L, upper=False, *, out=None) -> Tensor + +Computes the inverse of a complex Hermitian or real symmetric +positive-definite matrix given its Cholesky decomposition. + +Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, +and :math:`L` its Cholesky decomposition such that: + +.. math:: + + A = LL^{\text{H}} + +where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, +and the transpose when :math:`L` is real-valued. + +Computes the inverse matrix :math:`A^{-1}`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :math:`A` is a batch of matrices +then the output has the same batch dimensions. + +Args: + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False`` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> torch.cholesky_inverse(L) + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + >>> A.inverse() + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(torch.inverse(A), torch.cholesky_inverse(L)) + tensor(5.6358e-7) +""", +) + +add_docstr( + torch.clone, + r""" +clone(input, *, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of :attr:`input`. + +.. note:: + + This function is differentiable, so gradients will flow back from the + result of this operation to :attr:`input`. To create a tensor without an + autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + +Args: + {input} + +Keyword args: + {memory_format} +""".format(**common_args), +) + +add_docstr( + torch.clamp, + r""" +clamp(input, min=None, max=None, *, out=None) -> Tensor + +Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. +Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + +.. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + +If :attr:`min` is ``None``, there is no lower bound. 
+Or, if :attr:`max` is ``None`` there is no upper bound. +""" + + r""" + +.. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + +Args: + {input} + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + +""".format(**common_args), +) + +add_docstr( + torch.clip, + r""" +clip(input, min=None, max=None, *, out=None) -> Tensor + +Alias for :func:`torch.clamp`. +""", +) + +add_docstr( + torch.column_stack, + r""" +column_stack(tensors, *, out=None) -> Tensor + +Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + +Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` +in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + +""".format(**common_args), +) + +add_docstr( + torch.complex, + r""" +complex(real, imag, *, out=None) -> Tensor + +Constructs a complex tensor with its real part equal to :attr:`real` and its +imaginary part equal to :attr:`imag`. + +Args: + real (Tensor): The real part of the complex tensor. Must be half, float or double. + imag (Tensor): The imaginary part of the complex tensor. Must be same dtype + as :attr:`real`. + +Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + +Example:: + + >>> real = torch.tensor([1, 2], dtype=torch.float32) + >>> imag = torch.tensor([3, 4], dtype=torch.float32) + >>> z = torch.complex(real, imag) + >>> z + tensor([(1.+3.j), (2.+4.j)]) + >>> z.dtype + torch.complex64 + +""", +) + +add_docstr( + torch.polar, + r""" +polar(abs, angle, *, out=None) -> Tensor + +Constructs a complex tensor whose elements are Cartesian coordinates +corresponding to the polar coordinates with absolute value :attr:`abs` and angle +:attr:`angle`. + +.. math:: + \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j + +.. note:: + `torch.polar` is similar to + `std::polar `_ + and does not compute the polar decomposition + of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do. + The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is + infinite. + +""" + + r""" +Args: + abs (Tensor): The absolute value the complex tensor. Must be float or double. + angle (Tensor): The angle of the complex tensor. Must be same dtype as + :attr:`abs`. + +Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. 
If the inputs are ``torch.float64``, must be + ``torch.complex128``. + +Example:: + + >>> import numpy as np + >>> abs = torch.tensor([1, 2], dtype=torch.float64) + >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64) + >>> z = torch.polar(abs, angle) + >>> z + tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) +""", +) + +add_docstr( + torch.conj_physical, + r""" +conj_physical(input, *, out=None) -> Tensor + +Computes the element-wise conjugate of the given :attr:`input` tensor. +If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`. + +.. note:: + This performs the conjugate operation regardless of the fact conjugate bit is set or not. + +.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + +.. math:: + \text{out}_{i} = conj(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) +""".format(**common_args), +) + +add_docstr( + torch.conj, + r""" +conj(input) -> Tensor + +Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, +this function just returns :attr:`input`. + +.. note:: + :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized + at any time using :func:`torch.resolve_conj`. + +.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + +Args: + {input} + +Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> x.is_conj() + False + >>> y = torch.conj(x) + >>> y.is_conj() + True +""".format(**common_args), +) + +add_docstr( + torch.resolve_conj, + r""" +resolve_conj(input) -> Tensor + +Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, +else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`. + +Args: + {input} + +Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> y.is_conj() + True + >>> z = y.resolve_conj() + >>> z + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + >>> z.is_conj() + False +""".format(**common_args), +) + +add_docstr( + torch.resolve_neg, + r""" +resolve_neg(input) -> Tensor + +Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, +else returns :attr:`input`. The output tensor will always have its negative bit set to `False`. + +Args: + {input} + +Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> z = y.imag + >>> z.is_neg() + True + >>> out = z.resolve_neg() + >>> out + tensor([-1., -2., 3.]) + >>> out.is_neg() + False +""".format(**common_args), +) + +add_docstr( + torch.copysign, + r""" +copysign(input, other, *, out=None) -> Tensor + +Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + +.. 
math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +and integer and float inputs. + +Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + +.. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + +""".format(**common_args), +) + +add_docstr( + torch.cos, + r""" +cos(input, *, out=None) -> Tensor + +Returns a new tensor with the cosine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \cos(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) + >>> torch.cos(a) + tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) +""".format(**common_args), +) + +add_docstr( + torch.cosh, + r""" +cosh(input, *, out=None) -> Tensor + +Returns a new tensor with the hyperbolic cosine of the elements of +:attr:`input`. + +.. math:: + \text{out}_{i} = \cosh(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) + >>> torch.cosh(a) + tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + +.. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. +""".format(**common_args), +) + +add_docstr( + torch.cross, + r""" +cross(input, other, dim=None, *, out=None) -> Tensor + + +Returns the cross product of vectors in dimension :attr:`dim` of :attr:`input` +and :attr:`other`. + +Supports input of float, double, cfloat and cdouble dtypes. Also supports batches +of vectors, for which it computes the product along the dimension :attr:`dim`. +In this case, the output has the same batch dimensions as the inputs. + +.. warning:: + If :attr:`dim` is not given, it defaults to the first dimension found + with the size 3. Note that this might be unexpected. + + This behavior is deprecated and will be changed to match that of :func:`torch.linalg.cross` + in a future release. + +.. seealso:: + :func:`torch.linalg.cross` which has dim=-1 as default. + + +Args: + {input} + other (Tensor): the second input tensor + dim (int, optional): the dimension to take the cross-product in. 
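A brief illustration of the ``dim`` argument with hand-picked unit vectors (the
:func:`torch.linalg.cross` call is shown only for comparison, since it defaults
to ``dim=-1``)::

    >>> a = torch.tensor([[1., 0., 0.]])
    >>> b = torch.tensor([[0., 1., 0.]])
    >>> torch.cross(a, b, dim=1)
    tensor([[0., 0., 1.]])
    >>> torch.linalg.cross(a, b)
    tensor([[0., 0., 1.]])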
+ +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.cross(a, b, dim=1) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> torch.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) +""".format(**common_args), +) + +add_docstr( + torch.logcumsumexp, + r""" +logcumsumexp(input, dim, *, out=None) -> Tensor +Returns the logarithm of the cumulative summation of the exponentiation of +elements of :attr:`input` in the dimension :attr:`dim`. + +For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{{logcumsumexp}}(x)_{{ij}} = \log \sum\limits_{{k=0}}^{{j}} \exp(x_{{ik}}) + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) +""".format(**reduceops_common_args), +) + +add_docstr( + torch.cummax, + r""" +cummax(input, dim, *, out=None) -> (Tensor, LongTensor) +Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of +elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index +location of each maximum value found in the dimension :attr:`dim`. + +.. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + +Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) +""".format(**reduceops_common_args), +) + +add_docstr( + torch.cummin, + r""" +cummin(input, dim, *, out=None) -> (Tensor, LongTensor) +Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of +elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index +location of each maximum value found in the dimension :attr:`dim`. + +.. 
math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + +Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) +""".format(**reduceops_common_args), +) + +add_docstr( + torch.cumprod, + r""" +cumprod(input, dim, *, dtype=None, out=None) -> Tensor + +Returns the cumulative product of elements of :attr:`input` in the dimension +:attr:`dim`. + +For example, if :attr:`input` is a vector of size N, the result will also be +a vector of size N, with elements. + +.. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + {dtype} + {out} + +Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) +""".format(**reduceops_common_args), +) + +add_docstr( + torch.cumsum, + r""" +cumsum(input, dim, *, dtype=None, out=None) -> Tensor + +Returns the cumulative sum of elements of :attr:`input` in the dimension +:attr:`dim`. + +For example, if :attr:`input` is a vector of size N, the result will also be +a vector of size N, with elements. + +.. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + {dtype} + {out} + +Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) +""".format(**reduceops_common_args), +) + +add_docstr( + torch.count_nonzero, + r""" +count_nonzero(input, dim=None) -> Tensor + +Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. +If no dim is specified then all non-zeros in the tensor are counted. + +Args: + {input} + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + +Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) +""".format(**reduceops_common_args), +) + +add_docstr( + torch.dequantize, + r""" +dequantize(tensor) -> Tensor + +Returns an fp32 Tensor by dequantizing a quantized Tensor + +Args: + tensor (Tensor): A quantized Tensor + +.. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + +Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + +Args: + tensors (sequence of Tensors): A list of quantized Tensors +""", +) + +add_docstr( + torch.diag, + r""" +diag(input, diagonal=0, *, out=None) -> Tensor + +- If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. 
+- If :attr:`input` is a matrix (2-D tensor), then returns a 1-D tensor with + the diagonal elements of :attr:`input`. + +The argument :attr:`diagonal` controls which diagonal to consider: + +- If :attr:`diagonal` = 0, it is the main diagonal. +- If :attr:`diagonal` > 0, it is above the main diagonal. +- If :attr:`diagonal` < 0, it is below the main diagonal. + +Args: + {input} + diagonal (int, optional): the diagonal to consider + +Keyword args: + {out} + +.. seealso:: + + :func:`torch.diagonal` always returns the diagonal of its input. + + :func:`torch.diagflat` always constructs a tensor with diagonal elements + specified by the input. + +Examples: + +Get the square matrix where the input vector is the diagonal:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.5950,-0.0872, 2.3298]) + >>> torch.diag(a) + tensor([[ 0.5950, 0.0000, 0.0000], + [ 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 2.3298]]) + >>> torch.diag(a, 1) + tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], + [ 0.0000, 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 0.0000, 2.3298], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + +Get the k-th diagonal of a given matrix:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-0.4264, 0.0255,-0.1064], + [ 0.8795,-0.2429, 0.1374], + [ 0.1029,-0.6482,-1.6300]]) + >>> torch.diag(a, 0) + tensor([-0.4264,-0.2429,-1.6300]) + >>> torch.diag(a, 1) + tensor([ 0.0255, 0.1374]) +""".format(**common_args), +) + +add_docstr( + torch.diag_embed, + r""" +diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor + +Creates a tensor whose diagonals of certain 2D planes (specified by +:attr:`dim1` and :attr:`dim2`) are filled by :attr:`input`. +To facilitate creating batched diagonal matrices, the 2D planes formed by +the last two dimensions of the returned tensor are chosen by default. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +The size of the new matrix will be calculated to make the specified diagonal +of the size of the last input dimension. +Note that for :attr:`offset` other than :math:`0`, the order of :attr:`dim1` +and :attr:`dim2` matters. Exchanging them is equivalent to changing the +sign of :attr:`offset`. + +Applying :meth:`torch.diagonal` to the output of this function with +the same arguments yields a matrix identical to input. However, +:meth:`torch.diagonal` has different default dimensions, so those +need to be explicitly specified. + +Args: + {input} Must be at least 1-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: -2. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: -1. 
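As a quick check of the round trip described above (hand-picked values; the
diagonal dimensions are passed explicitly because the two functions use
different defaults)::

    >>> x = torch.tensor([[1., 2., 3.]])
    >>> torch.diagonal(torch.diag_embed(x), dim1=-2, dim2=-1)
    tensor([[1., 2., 3.]])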
+ +Example:: + + >>> a = torch.randn(2, 3) + >>> torch.diag_embed(a) + tensor([[[ 1.5410, 0.0000, 0.0000], + [ 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -2.1788]], + + [[ 0.5684, 0.0000, 0.0000], + [ 0.0000, -1.0845, 0.0000], + [ 0.0000, 0.0000, -1.3986]]]) + + >>> torch.diag_embed(a, offset=1, dim1=0, dim2=2) + tensor([[[ 0.0000, 1.5410, 0.0000, 0.0000], + [ 0.0000, 0.5684, 0.0000, 0.0000]], + + [[ 0.0000, 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -1.0845, 0.0000]], + + [[ 0.0000, 0.0000, 0.0000, -2.1788], + [ 0.0000, 0.0000, 0.0000, -1.3986]], + + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]) +""".format(**common_args), +) + + +add_docstr( + torch.diagflat, + r""" +diagflat(input, offset=0) -> Tensor + +- If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. +- If :attr:`input` is a tensor with more than one dimension, then returns a + 2-D tensor with diagonal elements equal to a flattened :attr:`input`. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +Args: + {input} + offset (int, optional): the diagonal to consider. Default: 0 (main + diagonal). + +Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([-0.2956, -0.9068, 0.1695]) + >>> torch.diagflat(a) + tensor([[-0.2956, 0.0000, 0.0000], + [ 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.1695]]) + >>> torch.diagflat(a, 1) + tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.1695], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + >>> a = torch.randn(2, 2) + >>> a + tensor([[ 0.2094, -0.3018], + [-0.1516, 1.9342]]) + >>> torch.diagflat(a) + tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.3018, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1516, 0.0000], + [ 0.0000, 0.0000, 0.0000, 1.9342]]) +""".format(**common_args), +) + +add_docstr( + torch.diagonal, + r""" +diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + +Returns a partial view of :attr:`input` with the its diagonal elements +with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension +at the end of the shape. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +Applying :meth:`torch.diag_embed` to the output of this function with +the same arguments yields a diagonal matrix with the diagonal entries +of the input. However, :meth:`torch.diag_embed` has different default +dimensions, so those need to be explicitly specified. + +Args: + {input} Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + +.. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. 
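A minimal illustration of that batch case (values chosen for clarity)::

    >>> m = torch.arange(8.).reshape(2, 2, 2)
    >>> torch.diagonal(m, dim1=-2, dim2=-1)
    tensor([[0., 3.],
            [4., 7.]])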
+ +Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + >>> b = torch.randn(2, 5) + >>> b + tensor([[-1.7948, -1.2731, -0.3181, 2.0200, -1.6745], + [ 1.8262, -1.5049, 0.4114, 1.0704, -1.2607]]) + + >>> torch.diagonal(b, 1, 1, 0) + tensor([1.8262]) + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) +""".format(**common_args), +) + +add_docstr( + torch.diagonal_scatter, + r""" +diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor + +Embeds the values of the :attr:`src` tensor into :attr:`input` along +the diagonal elements of :attr:`input`, with respect to :attr:`dim1` +and :attr:`dim2`. + +This function returns a tensor with fresh storage; it does not +return a view. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +Args: + {input} Must be at least 2-dimensional. + src (Tensor): the tensor to embed into :attr:`input`. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + +.. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.diagonal(input, offset, dim1, dim2)`` + +Examples:: + + >>> a = torch.zeros(3, 3) + >>> a + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + + >>> torch.diagonal_scatter(a, torch.ones(3), 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + >>> torch.diagonal_scatter(a, torch.ones(2), 1) + tensor([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) +""".format(**common_args), +) + +add_docstr( + torch.as_strided_scatter, + r""" +as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor + +Embeds the values of the :attr:`src` tensor into :attr:`input` along +the elements corresponding to the result of calling +input.as_strided(size, stride, storage_offset). + +This function returns a tensor with fresh storage; it does not +return a view. + +Args: + {input} + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor + +.. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. 
Specifically, it should have the same shape as + `torch.as_strided(input, size, stride, storage_offset)` + +Example:: + + >>> a = torch.arange(4).reshape(2, 2) + 1 + >>> a + tensor([[1, 2], + [3, 4]]) + >>> b = torch.zeros(3, 3) + >>> b + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2)) + tensor([[1., 3., 2.], + [4., 0., 0.], + [0., 0., 0.]]) + +""".format(**common_args), +) + +add_docstr( + torch.diff, + r""" +diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor + +Computes the n-th forward difference along the given dimension. + +The first-order differences are given by `out[i] = input[i + 1] - input[i]`. Higher-order +differences are calculated by using :func:`torch.diff` recursively. + +Args: + input (Tensor): the tensor to compute the differences on + n (int, optional): the number of times to recursively compute the difference + dim (int, optional): the dimension to compute the difference along. + Default is the last dimension. + prepend, append (Tensor, optional): values to prepend or append to + :attr:`input` along :attr:`dim` before computing the difference. + Their dimensions must be equivalent to that of input, and their shapes + must match input's shape except on :attr:`dim`. + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 3, 2]) + >>> torch.diff(a) + tensor([ 2, -1]) + >>> b = torch.tensor([4, 5]) + >>> torch.diff(a, append=b) + tensor([ 2, -1, 2, 1]) + >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]]) + >>> torch.diff(c, dim=0) + tensor([[2, 2, 2]]) + >>> torch.diff(c, dim=1) + tensor([[1, 1], + [1, 1]]) +""".format(**common_args), +) + +add_docstr( + torch.digamma, + r""" +digamma(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.digamma`. +""", +) + +add_docstr( + torch.dist, + r""" +dist(input, other, p=2) -> Tensor + +Returns the p-norm of (:attr:`input` - :attr:`other`) + +The shapes of :attr:`input` and :attr:`other` must be +:ref:`broadcastable `. + +Args: + {input} + other (Tensor): the Right-hand-side input tensor + p (float, optional): the norm to be computed + +Example:: + + >>> x = torch.randn(4) + >>> x + tensor([-1.5393, -0.8675, 0.5916, 1.6321]) + >>> y = torch.randn(4) + >>> y + tensor([ 0.0967, -1.0511, 0.6295, 0.8360]) + >>> torch.dist(x, y, 3.5) + tensor(1.6727) + >>> torch.dist(x, y, 3) + tensor(1.6973) + >>> torch.dist(x, y, 0) + tensor(4.) + >>> torch.dist(x, y, 1) + tensor(2.6537) +""".format(**common_args), +) + +add_docstr( + torch.div, + r""" +div(input, other, *, rounding_mode=None, out=None) -> Tensor + +Divides each element of the input ``input`` by the corresponding element of +:attr:`other`. + +.. math:: + \text{{out}}_i = \frac{{\text{{input}}_i}}{{\text{{other}}_i}} + +.. note:: + By default, this performs a "true" division like Python 3. + See the :attr:`rounding_mode` argument for floor division. + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. +Always promotes integer types to the default scalar type. + +Args: + input (Tensor): the dividend + other (Tensor or Number): the divisor + +Keyword args: + rounding_mode (str, optional): Type of rounding applied to the result: + + * None - default behavior. Performs no rounding and, if both :attr:`input` and + :attr:`other` are integer types, promotes the inputs to the default scalar type. + Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``. 
+ * ``"trunc"`` - rounds the results of the division towards zero. + Equivalent to C-style integer division. + * ``"floor"`` - rounds the results of the division down. + Equivalent to floor division in Python (the ``//`` operator) and NumPy's ``np.floor_divide``. + + {out} + +Examples:: + + >>> x = torch.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637]) + >>> torch.div(x, 0.5) + tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274]) + + >>> a = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917], + ... [ 0.1815, -1.0111, 0.9805, -1.5923], + ... [ 0.1062, 1.4581, 0.7759, -1.2344], + ... [-0.1830, -0.0313, 1.1908, -1.4757]]) + >>> b = torch.tensor([ 0.8032, 0.2930, -0.8113, -0.2308]) + >>> torch.div(a, b) + tensor([[-0.4620, -6.6051, 0.5676, 1.2639], + [ 0.2260, -3.4509, -1.2086, 6.8990], + [ 0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + + >>> torch.div(a, b, rounding_mode='trunc') + tensor([[-0., -6., 0., 1.], + [ 0., -3., -1., 6.], + [ 0., 4., -0., 5.], + [-0., -0., -1., 6.]]) + + >>> torch.div(a, b, rounding_mode='floor') + tensor([[-1., -7., 0., 1.], + [ 0., -4., -2., 6.], + [ 0., 4., -1., 5.], + [-1., -1., -2., 6.]]) + +""".format(**common_args), +) + +add_docstr( + torch.divide, + r""" +divide(input, other, *, rounding_mode=None, out=None) -> Tensor + +Alias for :func:`torch.div`. +""", +) + +add_docstr( + torch.dot, + r""" +dot(input, tensor, *, out=None) -> Tensor + +Computes the dot product of two 1D tensors. + +.. note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + +Args: + input (Tensor): first tensor in the dot product, must be 1D. + tensor (Tensor): second tensor in the dot product, must be 1D. + +Keyword args: + {out} + +Example:: + + >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + + >>> t1, t2 = torch.tensor([0, 1]), torch.tensor([2, 3]) + >>> torch.dot(t1, t2) + tensor(3) +""".format(**common_args), +) + +add_docstr( + torch.vdot, + r""" +vdot(input, other, *, out=None) -> Tensor + +Computes the dot product of two 1D vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. + +.. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + +.. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + +Args: + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. + +Keyword args: +""" + + rf""" +.. note:: {common_args["out"]} +""" + + r""" + +Example:: + + >>> torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + >>> a = torch.tensor((1 +2j, 3 - 1j)) + >>> b = torch.tensor((2 +1j, 4 - 0j)) + >>> torch.vdot(a, b) + tensor([16.+1.j]) + >>> torch.vdot(b, a) + tensor([16.-1.j]) +""", +) + +add_docstr( + torch.eq, + r""" +eq(input, other, *, out=None) -> Tensor + +Computes element-wise equality + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. 
+ +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + +Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) +""".format(**common_args), +) + +add_docstr( + torch.equal, + r""" +equal(input, other) -> bool + +``True`` if two tensors have the same size and elements, ``False`` otherwise. + +.. note:: + + Tensors containing NaNs are never equal to each other. Additionally, this function does not + differentiate between the data types of the tensors during comparison. For more thorough tensor checks, + use :meth:`torch.testing.assert_close`. + +Example:: + + >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) + True + >>> torch.equal(torch.tensor([3, torch.nan]), torch.tensor([3, torch.nan])) + False + >>> torch.equal(torch.tensor([1, 2, 3], dtype=torch.int32), torch.tensor([1, 2, 3], dtype=torch.float32)) + True +""", +) + +add_docstr( + torch.erf, + r""" +erf(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.erf`. +""", +) + +add_docstr( + torch.erfc, + r""" +erfc(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.erfc`. +""", +) + +add_docstr( + torch.erfinv, + r""" +erfinv(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.erfinv`. +""", +) + +add_docstr( + torch.exp, + r""" +exp(input, *, out=None) -> Tensor + +Returns a new tensor with the exponential of the elements +of the input tensor :attr:`input`. + +.. math:: + y_{i} = e^{x_{i}} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.exp(torch.tensor([0, math.log(2.)])) + tensor([ 1., 2.]) +""".format(**common_args), +) + +add_docstr( + torch.exp2, + r""" +exp2(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.exp2`. +""", +) + +add_docstr( + torch.expm1, + r""" +expm1(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.expm1`. +""", +) + +add_docstr( + torch.eye, + r""" +eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + +Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + +Keyword arguments: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + +Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) +""".format(**factory_common_args), +) + +add_docstr( + torch.floor, + r""" +floor(input, *, out=None) -> Tensor + +Returns a new tensor with the floor of the elements of :attr:`input`, +the largest integer less than or equal to each element. + +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. + +.. math:: + \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.8166, 1.5308, -0.2530, -0.2091]) + >>> torch.floor(a) + tensor([-1., 1., -1., -1.]) +""".format(**common_args), +) + +add_docstr( + torch.floor_divide, + r""" +floor_divide(input, other, *, out=None) -> Tensor + +.. note:: + + Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly performed + truncation division. 
To restore the previous behavior use + :func:`torch.div` with ``rounding_mode='trunc'``. + +Computes :attr:`input` divided by :attr:`other`, elementwise, and floors +the result. + +.. math:: + \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) + +""" + + r""" + +Supports broadcasting to a common shape, type promotion, and integer and float inputs. + +Args: + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([4.0, 3.0]) + >>> b = torch.tensor([2.0, 2.0]) + >>> torch.floor_divide(a, b) + tensor([2.0, 1.0]) + >>> torch.floor_divide(a, 1.4) + tensor([2.0, 2.0]) +""".format(**common_args), +) + +add_docstr( + torch.fmod, + r""" +fmod(input, other, *, out=None) -> Tensor + +Applies C++'s `std::fmod `_ entrywise. +The result has the same sign as the dividend :attr:`input` and its absolute value +is less than that of :attr:`other`. + +This function may be defined in terms of :func:`torch.div` as + +.. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and float inputs. + +.. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + +.. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + +.. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + +Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + +Keyword args: + {out} + +Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + +""".format(**common_args), +) + +add_docstr( + torch.frac, + r""" +frac(input, *, out=None) -> Tensor + +Computes the fractional portion of each element in :attr:`input`. + +.. math:: + \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) + +Example:: + + >>> torch.frac(torch.tensor([1, 2.5, -3.2])) + tensor([ 0.0000, 0.5000, -0.2000]) +""", +) + +add_docstr( + torch.frexp, + r""" +frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) + +Decomposes :attr:`input` into mantissa and exponent tensors +such that :math:`\text{input} = \text{mantissa} \times 2^{\text{exponent}}`. + +The range of mantissa is the open interval (-1, 1). + +Supports float inputs. + +Args: + input (Tensor): the input tensor + + +Keyword args: + out (tuple, optional): the output tensors + +Example:: + + >>> x = torch.arange(9.) + >>> mantissa, exponent = torch.frexp(x) + >>> mantissa + tensor([0.0000, 0.5000, 0.5000, 0.7500, 0.5000, 0.6250, 0.7500, 0.8750, 0.5000]) + >>> exponent + tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) + >>> torch.ldexp(mantissa, exponent) + tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) +""", +) + +add_docstr( + torch.from_numpy, + r""" +from_numpy(ndarray) -> Tensor + +Creates a :class:`Tensor` from a :class:`numpy.ndarray`. + +The returned tensor and :attr:`ndarray` share the same memory. 
Modifications to +the tensor will be reflected in the :attr:`ndarray` and vice versa. The returned +tensor is not resizable. + +It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, +``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``, +``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, +and ``bool``. + +.. warning:: + Writing to a tensor created from a read-only NumPy array is not supported and will result in undefined behavior. + +Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.from_numpy(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) +""", +) + +add_docstr( + torch.frombuffer, + r""" +frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor + +Creates a 1-dimensional :class:`Tensor` from an object that implements +the Python buffer protocol. + +Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of +the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count` +elements. + +Note that either of the following must be true: + +1. :attr:`count` is a positive non-zero number, and the total number of bytes +in the buffer is more than :attr:`offset` plus :attr:`count` times the size +(in bytes) of :attr:`dtype`. + +2. :attr:`count` is negative, and the length (number of bytes) of the buffer +subtracted by the :attr:`offset` is a multiple of the size (in bytes) of +:attr:`dtype`. + +The returned tensor and buffer share the same memory. Modifications to +the tensor will be reflected in the buffer and vice versa. The returned +tensor is not resizable. + +.. note:: + This function increments the reference count for the object that + owns the shared memory. Therefore, such memory will not be deallocated + before the returned tensor goes out of scope. + +.. warning:: + This function's behavior is undefined when passed an object implementing + the buffer protocol whose data is not on the CPU. Doing so is likely to + cause a segmentation fault. + +.. warning:: + This function does not try to infer the :attr:`dtype` (hence, it is not + optional). Passing a different :attr:`dtype` than its source may result + in unexpected behavior. + +Args: + buffer (object): a Python object that exposes the buffer interface. + +Keyword args: + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + count (int, optional): the number of desired elements to be read. + If negative, all the elements (until the end of the buffer) will be + read. Default: -1. + offset (int, optional): the number of bytes to skip at the start of + the buffer. Default: 0. + {requires_grad} + +Example:: + + >>> import array + >>> a = array.array('i', [1, 2, 3]) + >>> t = torch.frombuffer(a, dtype=torch.int32) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> # Interprets the signed char bytes as 32-bit integers. + >>> # Each 4 signed char elements will be interpreted as + >>> # 1 signed 32-bit integer. + >>> import array + >>> a = array.array('b', [-1, 0, 0, 0]) + >>> torch.frombuffer(a, dtype=torch.int32) + tensor([255], dtype=torch.int32) +""".format(**factory_common_args), +) + +add_docstr( + torch.from_file, + r""" +from_file(filename, shared=None, size=0, *, dtype=None, layout=None, device=None, pin_memory=False) + +Creates a CPU tensor with a storage backed by a memory-mapped file. + +If ``shared`` is True, then memory is shared between processes. All changes are written to the file. 
+If ``shared`` is False, then changes to the tensor do not affect the file. + +``size`` is the number of elements in the Tensor. If ``shared`` is ``False``, then the file must contain +at least ``size * sizeof(dtype)`` bytes. If ``shared`` is ``True`` the file will be created if needed. + +.. note:: + Only CPU tensors can be mapped to files. + +.. note:: + For now, tensors with storages backed by a memory-mapped file cannot be created in pinned memory. + + +Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the tensor + +Keyword args: + {dtype} + {layout} + {device} + {pin_memory} + +Example:: + + >>> t = torch.randn(2, 5, dtype=torch.float64) + >>> t.numpy().tofile('storage.pt') + >>> t_mapped = torch.from_file('storage.pt', shared=False, size=10, dtype=torch.float64) + """.format(**factory_common_args), +) + +add_docstr( + torch.flatten, + r""" +flatten(input, start_dim=0, end_dim=-1) -> Tensor + +Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` +are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. +The order of elements in :attr:`input` is unchanged. + +Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, +or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can +be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the +flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + +.. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + +Args: + {input} + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + +Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) +""".format(**common_args), +) + +add_docstr( + torch.unflatten, + r""" +unflatten(input, dim, sizes) -> Tensor + +Expands a dimension of the input tensor over multiple dimensions. + +.. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + +Args: + {input} + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + +Returns: + A View of input with the specified dimension unflattened. + +Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) +""".format(**common_args), +) + +add_docstr( + torch.gather, + r""" +gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + +Gathers values along an axis specified by `dim`. 
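A tiny preview with ``dim=0`` (hand-picked values; the general rule follows below)::

    >>> t = torch.tensor([[1, 2], [3, 4]])
    >>> torch.gather(t, 0, torch.tensor([[0, 1], [1, 0]]))
    tensor([[1, 4],
            [3, 2]])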
+ +For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + +:attr:`input` and :attr:`index` must have the same number of dimensions. +It is also required that ``index.size(d) <= input.size(d)`` for all +dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. +Note that ``input`` and ``index`` do not broadcast against each other. + +Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + +Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + +Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) +""", +) + + +add_docstr( + torch.gcd, + r""" +gcd(input, other, *, out=None) -> Tensor + +Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. + +Both :attr:`input` and :attr:`other` must have integer types. + +.. note:: + This defines :math:`gcd(0, 0) = 0`. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.gcd(a, b) + tensor([1, 2, 5]) + >>> c = torch.tensor([3]) + >>> torch.gcd(a, c) + tensor([1, 1, 3]) +""".format(**common_args), +) + +add_docstr( + torch.ge, + r""" +ge(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} \geq \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + +Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) +""".format(**common_args), +) + +add_docstr( + torch.greater_equal, + r""" +greater_equal(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.ge`. +""", +) + +add_docstr( + torch.gradient, + r""" +gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + +Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in +one or more dimensions using the `second-order accurate central differences method +`_ and +either first or second order estimates at the boundaries. + +The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not +specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates +to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional +:attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and +:math:`g(1, 2, 3)\ == input[1, 2, 3]`. + +When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. +This is detailed in the "Keyword Arguments" section below. + +The gradient is estimated by estimating each partial derivative of :math:`g` independently. 
This estimation is +accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be +improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative +is estimated using `Taylor's theorem with remainder `_. +Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring +it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + +.. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + +Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + +.. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + +.. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + +The value of each partial derivative at the boundary points is computed differently. See edge_order below. + +Args: + input (``Tensor``): the tensor that represents the values of the function + +Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + +Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. 
+ >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + +""", +) + +add_docstr( + torch.geqrf, + r""" +geqrf(input, *, out=None) -> (Tensor, Tensor) + +This is a low-level function for calling LAPACK's geqrf directly. This function +returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . + +Computes a QR decomposition of :attr:`input`. +Both `Q` and `R` matrices are stored in the same output tensor `a`. +The elements of `R` are stored on and above the diagonal. +Elementary reflectors (or Householder vectors) implicitly defining matrix `Q` +are stored below the diagonal. +The results of this function can be used together with :func:`torch.linalg.householder_product` +to obtain the `Q` matrix or +with :func:`torch.ormqr`, which uses an implicit representation of the `Q` matrix, +for an efficient matrix-matrix multiplication. + +See `LAPACK documentation for geqrf`_ for further details. + +.. note:: + See also :func:`torch.linalg.qr`, which computes Q and R matrices, and :func:`torch.linalg.lstsq` + with the ``driver="gels"`` option for a function that can solve matrix equations using a QR decomposition. + +Args: + input (Tensor): the input matrix + +Keyword args: + out (tuple, optional): the output tuple of (Tensor, Tensor). Ignored if `None`. Default: `None`. + +.. 
_LAPACK documentation for geqrf: + http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html + +""", +) + +add_docstr( + torch.inner, + r""" +inner(input, other, *, out=None) -> Tensor + +Computes the dot product for 1D tensors. For higher dimensions, sums the product +of elements from :attr:`input` and :attr:`other` along their last dimension. + +.. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. + + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + +Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + +Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. + +Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) +""", +) + +add_docstr( + torch.outer, + r""" +outer(input, vec2, *, out=None) -> Tensor + +Outer product of :attr:`input` and :attr:`vec2`. +If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of +size :math:`m`, then :attr:`out` must be a matrix of size :math:`(n \times m)`. + +.. note:: This function does not :ref:`broadcast `. + +Args: + input (Tensor): 1-D input vector + vec2 (Tensor): 1-D input vector + +Keyword args: + out (Tensor, optional): optional output matrix + +Example:: + + >>> v1 = torch.arange(1., 5.) + >>> v2 = torch.arange(1., 4.) + >>> torch.outer(v1, v2) + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.]]) +""", +) + +add_docstr( + torch.ger, + r""" +ger(input, vec2, *, out=None) -> Tensor + +Alias of :func:`torch.outer`. + +.. warning:: + This function is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.outer` instead. +""", +) + +add_docstr( + torch.get_default_dtype, + r""" +get_default_dtype() -> torch.dtype + +Get the current default floating point :class:`torch.dtype`. + +Example:: + + >>> torch.get_default_dtype() # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_dtype(torch.float64) + >>> torch.get_default_dtype() # default is now changed to torch.float64 + torch.float64 + +""", +) + +add_docstr( + torch.get_num_threads, + r""" +get_num_threads() -> int + +Returns the number of threads used for parallelizing CPU operations +""", +) + +add_docstr( + torch.get_num_interop_threads, + r""" +get_num_interop_threads() -> int + +Returns the number of threads used for inter-op parallelism on CPU +(e.g. 
in JIT interpreter) +""", +) + +add_docstr( + torch.gt, + r""" +gt(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} > \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + +Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) +""".format(**common_args), +) + +add_docstr( + torch.greater, + r""" +greater(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.gt`. +""", +) + +add_docstr( + torch.histc, + r""" +histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor + +Computes the histogram of a tensor. + +The elements are sorted into equal width bins between :attr:`min` and +:attr:`max`. If :attr:`min` and :attr:`max` are both zero, the minimum and +maximum values of the data are used. + +Elements lower than min and higher than max and ``NaN`` elements are ignored. + +Args: + {input} + bins (int): number of histogram bins + min (Scalar): lower end of the range (inclusive) + max (Scalar): upper end of the range (inclusive) + +Keyword args: + {out} + +Returns: + Tensor: Histogram represented as a tensor + +Example:: + + >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) + tensor([ 0., 2., 1., 0.]) +""".format(**common_args), +) + +add_docstr( + torch.histogram, + r""" +histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + +Computes a histogram of the values in a tensor. + +:attr:`bins` can be an integer or a 1D tensor. + +If :attr:`bins` is an int, it specifies the number of equal-width bins. +By default, the lower and upper range of the bins is determined by the +minimum and maximum elements of the input tensor. The :attr:`range` +argument can be provided to specify a range for the bins. + +If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges +including the rightmost edge. It should contain at least 2 elements +and its elements should be increasing. + +Args: + {input} + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + +Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + {out} (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + +Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. 
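When :attr:`bins` is a tensor of edges, those edges are used directly
(illustrative values, shown here as the ``(hist, bin_edges)`` pair)::

    >>> torch.histogram(torch.tensor([0.5, 1.5, 1.7]), bins=torch.tensor([0., 1., 2., 3.]))
    (tensor([1., 2., 0.]), tensor([0., 1., 2., 3.]))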
+ +Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) +""".format(**common_args), +) + +add_docstr( + torch.histogramdd, + r""" +histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + +Computes a multi-dimensional histogram of the values in a tensor. + +Interprets the elements of an input tensor whose innermost dimension has size N +as a collection of N-dimensional points. Maps each of the points into a set of +N-dimensional bins and returns the number of points (or total weight) in each bin. + +:attr:`input` must be a tensor with at least 2 dimensions. +If input has shape (M, N), each of its M rows defines a point in N-dimensional space. +If input has three or more dimensions, all but the last dimension are flattened. + +Each dimension is independently associated with its own strictly increasing sequence +of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D +tensors. Alternatively, bin edges may be constructed automatically by passing a +sequence of integers specifying the number of equal-width bins in each dimension. + +For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + +:attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + +If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences +of bin edges. Each 1D tensor should contain a strictly increasing sequence with at +least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying +the left and right edges of all bins. Every bin is exclusive of its left edge. Only +the rightmost bin is inclusive of its right edge. + +If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins +in each dimension. By default, the leftmost and rightmost bin edges in each dimension +are determined by the minimum and maximum elements of the input tensor in the +corresponding dimension. The :attr:`range` argument can be provided to manually +specify the leftmost and rightmost bin edges in each dimension. + +If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + +.. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + +Args: + {input} + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. +Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. 
If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. +Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + +Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + +""".format(**common_args), +) +# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 +torch.histogramdd.__module__ = "torch" + +add_docstr( + torch.hypot, + r""" +hypot(input, other, *, out=None) -> Tensor + +Given the legs of a right triangle, return its hypotenuse. + +.. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}^{2} + \text{other}_{i}^{2}} + +The shapes of ``input`` and ``other`` must be +:ref:`broadcastable `. +""" + + r""" +Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([5.0000, 5.6569, 6.4031]) + +""".format(**common_args), +) + +add_docstr( + torch.i0, + r""" +i0(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.i0`. +""", +) + +add_docstr( + torch.igamma, + r""" +igamma(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.special.gammainc`. +""", +) + +add_docstr( + torch.igammac, + r""" +igammac(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.special.gammaincc`. +""", +) + +add_docstr( + torch.index_select, + r""" +index_select(input, dim, index, *, out=None) -> Tensor + +Returns a new tensor which indexes the :attr:`input` tensor along dimension +:attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + +The returned tensor has the same number of dimensions as the original tensor +(:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length +of :attr:`index`; other dimensions have the same size as in the original tensor. + +.. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. 
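+
+A small sketch of the storage point above (illustrative only): writes to the
+result do not affect the original tensor::
+
+    import torch
+
+    x = torch.zeros(3, 4)
+    idx = torch.tensor([0, 2])
+    y = torch.index_select(x, 0, idx)
+    y[0, 0] = 1.0          # modifies the copy returned by index_select
+    # x[0, 0] is still 0., because y does not share x's storage.
+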
+ +Args: + {input} + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) +""".format(**common_args), +) + +add_docstr( + torch.inverse, + r""" +inverse(input, *, out=None) -> Tensor + +Alias for :func:`torch.linalg.inv` +""", +) + +add_docstr( + torch.isin, + r""" +isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + +Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns +a boolean tensor of the same shape as :attr:`elements` that is True for elements +in :attr:`test_elements` and False otherwise. + +.. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + +Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + +Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + +Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) +""", +) + +add_docstr( + torch.isinf, + r""" +isinf(input) -> Tensor + +Tests if each element of :attr:`input` is infinite +(positive or negative infinity) or not. + +.. note:: + Complex values are infinite when their real or imaginary part is + infinite. + +Args: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is infinite and False elsewhere + +Example:: + + >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.isposinf, + r""" +isposinf(input, *, out=None) -> Tensor +Tests if each element of :attr:`input` is positive infinity or not. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isposinf(a) + tensor([False, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.isneginf, + r""" +isneginf(input, *, out=None) -> Tensor +Tests if each element of :attr:`input` is negative infinity or not. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isneginf(a) + tensor([ True, False, False]) +""".format(**common_args), +) + +add_docstr( + torch.isclose, + r""" +isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + +Returns a new tensor with boolean elements representing if each element of +:attr:`input` is "close" to the corresponding element of :attr:`other`. +Closeness is defined as: + +.. 
math:: + \lvert \text{input}_i - \text{other}_i \rvert \leq \texttt{rtol} \times \lvert \text{other}_i \rvert + \texttt{atol} +""" + + r""" + +where :attr:`input` and :attr:`other` are finite. Where :attr:`input` +and/or :attr:`other` are nonfinite they are close if and only if +they are equal, with NaNs being considered equal to each other when +:attr:`equal_nan` is True. + +Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + rtol (float, optional): relative tolerance. Default: 1e-05 + atol (float, optional): absolute tolerance. Default: 1e-08 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` + +Examples:: + + >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4))) + tensor([ True, False, False]) + >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) + tensor([True, True]) +""", +) + +add_docstr( + torch.isfinite, + r""" +isfinite(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element is `finite` or not. + +Real values are finite when they are not NaN, negative infinity, or infinity. +Complex values are finite when both their real and imaginary parts are finite. + +Args: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is finite and False elsewhere + +Example:: + + >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([True, False, True, False, False]) +""".format(**common_args), +) + +add_docstr( + torch.isnan, + r""" +isnan(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element of :attr:`input` +is NaN or not. Complex values are considered NaN when either their real +and/or imaginary part is NaN. + +Arguments: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is NaN and False elsewhere + +Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([False, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.isreal, + r""" +isreal(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. +All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + +Arguments: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is real and False elsewhere + +Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) +""".format(**common_args), +) + +add_docstr( + torch.is_floating_point, + r""" +is_floating_point(input) -> (bool) + +Returns True if the data type of :attr:`input` is a floating point data type i.e., +one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. + +Args: + {input} +""".format(**common_args), +) + +add_docstr( + torch.is_complex, + r""" +is_complex(input) -> (bool) + +Returns True if the data type of :attr:`input` is a complex data type i.e., +one of ``torch.complex64``, and ``torch.complex128``. + +Args: + {input} +""".format(**common_args), +) + +add_docstr( + torch.is_grad_enabled, + r""" +is_grad_enabled() -> (bool) + +Returns True if grad mode is currently enabled. +""".format(**common_args), +) + +add_docstr( + torch.is_inference_mode_enabled, + r""" +is_inference_mode_enabled() -> (bool) + +Returns True if inference mode is currently enabled. 
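+
+A minimal sketch of querying the flag (illustrative only)::
+
+    import torch
+
+    torch.is_inference_mode_enabled()      # False in normal execution
+    with torch.inference_mode():
+        torch.is_inference_mode_enabled()  # True inside the context manager
+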
+""".format(**common_args), +) + +add_docstr( + torch.is_inference, + r""" +is_inference(input) -> (bool) + +Returns True if :attr:`input` is an inference tensor. + +A non-view tensor is an inference tensor if and only if it was +allocated during inference mode. A view tensor is an inference +tensor if and only if the tensor it is a view of is an inference tensor. + +For details on inference mode please see +`Inference Mode `_. + +Args: + {input} +""".format(**common_args), +) + +add_docstr( + torch.is_conj, + r""" +is_conj(input) -> (bool) + +Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. + +Args: + {input} +""".format(**common_args), +) + +add_docstr( + torch.is_nonzero, + r""" +is_nonzero(input) -> (bool) + +Returns True if the :attr:`input` is a single element tensor which is not equal to zero +after type conversions. +i.e. not equal to ``torch.tensor([0.])`` or ``torch.tensor([0])`` or +``torch.tensor([False])``. +Throws a ``RuntimeError`` if ``torch.numel() != 1`` (even in case +of sparse tensors). + +Args: + {input} + +Examples:: + + >>> torch.is_nonzero(torch.tensor([0.])) + False + >>> torch.is_nonzero(torch.tensor([1.5])) + True + >>> torch.is_nonzero(torch.tensor([False])) + False + >>> torch.is_nonzero(torch.tensor([3])) + True + >>> torch.is_nonzero(torch.tensor([1, 3, 5])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous +""".format(**common_args), +) + +add_docstr( + torch.kron, + r""" +kron(input, other, *, out=None) -> Tensor + +Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + +If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a +:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a +:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + +.. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + +where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. +If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + +Supports real-valued and complex-valued inputs. + +.. note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + +Arguments: + input (Tensor) + other (Tensor) + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. 
Default: ``None`` + +Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) +""", +) + +add_docstr( + torch.kthvalue, + r""" +kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + +Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th +smallest element of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each element found. + +If :attr:`dim` is not given, the last dimension of the `input` is chosen. + +If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors +are the same size as :attr:`input`, except in the dimension :attr:`dim` where +they are of size 1. Otherwise, :attr:`dim` is squeezed +(see :func:`torch.squeeze`), resulting in both the :attr:`values` and +:attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + +.. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + +Args: + {input} + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + {opt_keepdim} + +Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + +Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) +""".format(**single_dim_common), +) + +add_docstr( + torch.lcm, + r""" +lcm(input, other, *, out=None) -> Tensor + +Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. + +Both :attr:`input` and :attr:`other` must have integer types. + +.. note:: + This defines :math:`lcm(0, 0) = 0` and :math:`lcm(0, a) = 0`. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.lcm(a, b) + tensor([15, 20, 15]) + >>> c = torch.tensor([3]) + >>> torch.lcm(a, c) + tensor([15, 30, 15]) +""".format(**common_args), +) + +add_docstr( + torch.ldexp, + r""" +ldexp(input, other, *, out=None) -> Tensor + +Multiplies :attr:`input` by 2 ** :attr:`other`. + +.. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i +""" + + r""" + +Typically this function is used to construct floating point numbers by multiplying +mantissas in :attr:`input` with integral powers of two created from the exponents +in :attr:`other`. + +Args: + {input} + other (Tensor): a tensor of exponents, typically integers. 
+ +Keyword args: + {out} + +Example:: + + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + + +""".format(**common_args), +) + +add_docstr( + torch.le, + r""" +le(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} \leq \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + +Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) +""".format(**common_args), +) + +add_docstr( + torch.less_equal, + r""" +less_equal(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.le`. +""", +) + +add_docstr( + torch.lerp, + r""" +lerp(input, end, weight, *, out=None) + +Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based +on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + +.. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) +""" + + r""" +The shapes of :attr:`start` and :attr:`end` must be +:ref:`broadcastable `. If :attr:`weight` is a tensor, then +the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + +Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + +Keyword args: + {out} + +Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) +""".format(**common_args), +) + +add_docstr( + torch.lgamma, + r""" +lgamma(input, *, out=None) -> Tensor + +Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + +.. math:: + \text{out}_{i} = \ln |\Gamma(\text{input}_{i})| +""" + + """ +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.lgamma(a) + tensor([ 0.5724, 0.0000, -0.1208]) +""".format(**common_args), +) + +add_docstr( + torch.linspace, + r""" +linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly +spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + +.. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) +""" + + """ + +From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + +Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. 
If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + +Keyword arguments: + {out} + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + {layout} + {device} + {requires_grad} + + +Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) +""".format(**factory_common_args), +) + +add_docstr( + torch.log, + r""" +log(input, *, out=None) -> Tensor + +Returns a new tensor with the natural logarithm of the elements +of :attr:`input`. + +.. math:: + y_{i} = \log_{e} (x_{i}) +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) * 5 + >>> a + tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + >>> torch.log(a) + tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) +""".format(**common_args), +) + +add_docstr( + torch.log10, + r""" +log10(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the logarithm to the base 10 of the elements +of :attr:`input`. + +.. math:: + y_{i} = \log_{10} (x_{i}) +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251]) + + + >>> torch.log10(a) + tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) + +""".format(**common_args), +) + +add_docstr( + torch.log1p, + r""" +log1p(input, *, out=None) -> Tensor + +Returns a new tensor with the natural logarithm of (1 + :attr:`input`). + +.. math:: + y_i = \log_{e} (x_i + 1) +""" + + r""" +.. note:: This function is more accurate than :func:`torch.log` for small + values of :attr:`input` + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) + >>> torch.log1p(a) + tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) +""".format(**common_args), +) + +add_docstr( + torch.log2, + r""" +log2(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the logarithm to the base 2 of the elements +of :attr:`input`. + +.. math:: + y_{i} = \log_{2} (x_{i}) +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490]) + + + >>> torch.log2(a) + tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) + +""".format(**common_args), +) + +add_docstr( + torch.logaddexp, + r""" +logaddexp(input, other, *, out=None) -> Tensor + +Logarithm of the sum of exponentiations of the inputs. + +Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful +in statistics where the calculated probabilities of events may be so small as to +exceed the range of normal floating point numbers. In such cases the logarithm +of the calculated probability is stored. This function allows adding +probabilities stored in such a fashion. + +This op should be disambiguated with :func:`torch.logsumexp` which performs a +reduction on a single tensor. 
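+
+An illustrative sketch of the underflow problem described above (values shown are
+approximate)::
+
+    import torch
+
+    x = torch.tensor([-1000.0])
+    y = torch.tensor([-1000.0])
+    torch.log(torch.exp(x) + torch.exp(y))   # tensor([-inf]): exp underflows to 0
+    torch.logaddexp(x, y)                    # roughly tensor([-999.3069])
+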
+ +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} + +Example:: + + >>> torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1.0, -2, -3])) + tensor([-0.3069, -0.6867, -0.8731]) + >>> torch.logaddexp(torch.tensor([-100.0, -200, -300]), torch.tensor([-1.0, -2, -3])) + tensor([-1., -2., -3.]) + >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) + tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) +""".format(**common_args), +) + +add_docstr( + torch.logaddexp2, + r""" +logaddexp2(input, other, *, out=None) -> Tensor + +Logarithm of the sum of exponentiations of the inputs in base-2. + +Calculates pointwise :math:`\log_2\left(2^x + 2^y\right)`. See +:func:`torch.logaddexp` for more details. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} +""".format(**common_args), +) + +add_docstr( + torch.xlogy, + r""" +xlogy(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.special.xlogy`. +""", +) + +add_docstr( + torch.logical_and, + r""" +logical_and(input, other, *, out=None) -> Tensor + +Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are +treated as ``True``. + +Args: + {input} + other (Tensor): the tensor to compute AND with + +Keyword args: + {out} + +Example:: + + >>> torch.logical_and(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, False]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_and(a, b) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b.double()) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b) + tensor([False, False, True, False]) + >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([False, False, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.logical_not, + r""" +logical_not(input, *, out=None) -> Tensor + +Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool +dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.logical_not(torch.tensor([True, False])) + tensor([False, True]) + >>> torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1.5, -10.], dtype=torch.double)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) + tensor([1, 0, 0], dtype=torch.int16) +""".format(**common_args), +) + +add_docstr( + torch.logical_or, + r""" +logical_or(input, other, *, out=None) -> Tensor + +Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are +treated as ``True``. 
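+
+A short usage sketch (illustrative only): combining comparison results, which are
+already boolean, with :func:`torch.logical_or`::
+
+    import torch
+
+    x = torch.tensor([-0.5, 0.25, 1.5])
+    mask = torch.logical_or(x < 0, x > 1)    # tensor([ True, False,  True])
+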
+ +Args: + {input} + other (Tensor): the tensor to compute OR with + +Keyword args: + {out} + +Example:: + + >>> torch.logical_or(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_or(a, b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b.double()) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, True, False]) +""".format(**common_args), +) + +add_docstr( + torch.logical_xor, + r""" +logical_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + +Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are +treated as ``True``. + +Args: + {input} + other (Tensor): the tensor to compute XOR with + +Keyword args: + {out} + +Example:: + + >>> torch.logical_xor(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([False, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_xor(a, b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b.double()) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, False, False]) +""".format(**common_args), +) + +add_docstr( + torch.logspace, + """ +logspace(start, end, steps, base=10.0, *, \ + out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" + +Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly +spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to +:math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale +with base :attr:`base`. That is, the values are: + +.. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) +""" + + """ + + +From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + +Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + +Keyword arguments: + {out} + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. 
+ {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) +""".format(**factory_common_args), +) + +add_docstr( + torch.logsumexp, + r""" +logsumexp(input, dim, keepdim=False, *, out=None) + +Returns the log of summed exponentials of each row of the :attr:`input` +tensor in the given dimension :attr:`dim`. The computation is numerically +stabilized. + +For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{{logsumexp}}(x)_{{i}} = \log \sum_j \exp(x_{{ij}}) + +{keepdim_details} + +Args: + {input} + {dim} + {opt_keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) +""".format(**multi_dim_common), +) + +add_docstr( + torch.lt, + r""" +lt(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} < \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + +Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) +""".format(**common_args), +) + +add_docstr( + torch.lu_unpack, + r""" +lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. + +.. seealso:: + + :func:`~linalg.lu` returns the matrices from the LU decomposition. Its gradient formula is more efficient + than that of doing :func:`~linalg.lu_factor` followed by :func:`~linalg.lu_unpack`. + +Args: + LU_data (Tensor): the packed LU factorization data + LU_pivots (Tensor): the packed LU factorization pivots + unpack_data (bool): flag indicating if the data should be unpacked. + If ``False``, then the returned ``L`` and ``U`` are empty tensors. + Default: ``True`` + unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``. + If ``False``, then the returned ``P`` is an empty tensor. + Default: ``True`` + +Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. 
+ +Returns: + A namedtuple ``(P, L, U)`` + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # We can recover A from the factorization + >>> A_ = P @ L @ U + >>> torch.allclose(A, A_) + True + + >>> # LU factorization of a rectangular matrix: + >>> A = torch.randn(2, 3, 2) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # P, L, U are the same as returned by linalg.lu + >>> P_, L_, U_ = torch.linalg.lu(A) + >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) + True + +""".format(**common_args), +) + +add_docstr( + torch.less, + r""" +less(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.lt`. +""", +) + +add_docstr( + torch.lu_solve, + r""" +lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor + +Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted +LU factorization of A from :func:`~linalg.lu_factor`. + +This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + +.. warning:: + + :func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`. + :func:`torch.lu_solve` will be removed in a future PyTorch release. + ``X = torch.lu_solve(B, LU, pivots)`` should be replaced with + + .. code:: python + + X = linalg.lu_solve(LU, pivots, B) + +Arguments: + b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` + is zero or more batch dimensions. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`, + where :math:`*` is zero or more batch dimensions. + LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`, + where :math:`*` is zero or more batch dimensions. + The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of + :attr:`LU_data`. + +Keyword args: + {out} + +Example:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(2, 3, 1) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> x = torch.lu_solve(b, LU, pivots) + >>> torch.dist(A @ x, b) + tensor(1.00000e-07 * + 2.8312) +""".format(**common_args), +) + +add_docstr( + torch.masked_select, + r""" +masked_select(input, mask, *, out=None) -> Tensor + +Returns a new 1-D tensor which indexes the :attr:`input` tensor according to +the boolean mask :attr:`mask` which is a `BoolTensor`. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need +to match, but they must be :ref:`broadcastable `. + +.. note:: The returned tensor does **not** use the same storage + as the original tensor + +Args: + {input} + mask (BoolTensor): the tensor containing the binary mask to index with + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], + [-1.2035, 1.2252, 0.5002, 0.6248], + [ 0.1307, -2.0608, 0.1244, 2.0139]]) + >>> mask = x.ge(0.5) + >>> mask + tensor([[False, False, False, False], + [False, True, True, True], + [False, False, False, True]]) + >>> torch.masked_select(x, mask) + tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) +""".format(**common_args), +) + +add_docstr( + torch.matrix_power, + r""" +matrix_power(input, n, *, out=None) -> Tensor + +Alias for :func:`torch.linalg.matrix_power` +""", +) + +add_docstr( + torch.matrix_exp, + r""" +matrix_exp(A) -> Tensor + +Alias for :func:`torch.linalg.matrix_exp`. 
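+
+A one-line sanity check (illustrative only): the matrix exponential of the zero
+matrix is the identity::
+
+    import torch
+
+    torch.matrix_exp(torch.zeros(2, 2))
+    # tensor([[1., 0.],
+    #         [0., 1.]])
+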
+""", +) + +add_docstr( + torch.max, + r""" +max(input, *, out=None) -> Tensor + +Returns the maximum value of all elements in the ``input`` tensor. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + +.. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each maximum value found +(argmax). + +If ``keepdim`` is ``True``, the output tensors are of the same size +as ``input`` except in the dimension ``dim`` where they are of size 1. +Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting +in the output tensors having 1 fewer dimension than ``input``. + +.. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + +.. function:: max(input, other, *, out=None) -> Tensor + :noindex: + +See :func:`torch.maximum`. + +""".format(**multi_dim_common), +) + +add_docstr( + torch.maximum, + r""" +maximum(input, other, *, out=None) -> Tensor + +Computes the element-wise maximum of :attr:`input` and :attr:`other`. + +.. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`maximum` is not supported for tensors with complex dtypes. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.maximum(a, b) + tensor([3, 2, 4]) +""".format(**common_args), +) + +add_docstr( + torch.fmax, + r""" +fmax(input, other, *, out=None) -> Tensor + +Computes the element-wise maximum of :attr:`input` and :attr:`other`. + +This is like :func:`torch.maximum` except it handles NaNs differently: +if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum. +Only if both elements are NaN is NaN propagated. + +This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function. + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and floating-point inputs. 
+ +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')]) + >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) + >>> torch.fmax(a, b) + tensor([9.7000, 0.5000, 3.1000, nan]) +""".format(**common_args), +) + +add_docstr( + torch.amax, + r""" +amax(input, dim, keepdim=False, *, out=None) -> Tensor + +Returns the maximum value of each slice of the :attr:`input` tensor in the given +dimension(s) :attr:`dim`. + +.. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.8177, 1.4878, -0.2491, 0.9130], + [-0.7158, 1.1775, 2.0992, 0.4817], + [-0.0053, 0.0164, -1.3738, -0.0507], + [ 1.9700, 1.1106, -1.0318, -1.0816]]) + >>> torch.amax(a, 1) + tensor([1.4878, 2.0992, 0.0164, 1.9700]) +""".format(**multi_dim_common), +) + +add_docstr( + torch.argmax, + r""" +argmax(input) -> LongTensor + +Returns the indices of the maximum value of all elements in the :attr:`input` tensor. + +This is the second value returned by :meth:`torch.max`. See its +documentation for the exact semantics of this method. + +.. note:: If there are multiple maximal values then the indices of the first maximal value are returned. + +Args: + {input} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a) + tensor(0) + +.. function:: argmax(input, dim, keepdim=False) -> LongTensor + :noindex: + +Returns the indices of the maximum values of a tensor across a dimension. + +This is the second value returned by :meth:`torch.max`. See its +documentation for the exact semantics of this method. + +Args: + {input} + {opt_dim} If ``None``, the argmax of the flattened input is returned. + {opt_keepdim} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a, dim=1) + tensor([ 0, 2, 0, 1]) +""".format(**single_dim_common), +) + +add_docstr( + torch.argwhere, + r""" +argwhere(input) -> Tensor + +Returns a tensor containing the indices of all non-zero elements of +:attr:`input`. Each row in the result contains the indices of a non-zero +element in :attr:`input`. The result is sorted lexicographically, with +the last index changing the fastest (C-style). + +If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor +:attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of +non-zero elements in the :attr:`input` tensor. + +.. note:: + This function is similar to NumPy's `argwhere`. + + When :attr:`input` is on CUDA, this function causes host-device synchronization. 
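+
+A brief sketch (illustrative only; on current PyTorch this matches
+:func:`torch.nonzero` with the default ``as_tuple=False``)::
+
+    import torch
+
+    t = torch.tensor([[1, 0], [0, 3]])
+    torch.argwhere(t)   # tensor([[0, 0], [1, 1]])
+    torch.nonzero(t)    # same result for this input
+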
+ +Args: + {input} + +Example:: + + >>> t = torch.tensor([1, 0, 1]) + >>> torch.argwhere(t) + tensor([[0], + [2]]) + >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) + >>> torch.argwhere(t) + tensor([[0, 0], + [0, 2], + [1, 1], + [1, 2]]) +""", +) + +add_docstr( + torch.mean, + r""" +mean(input, *, dtype=None) -> Tensor + +.. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + +Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + +Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + +.. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + +Returns the mean value of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, +reduce over all of them. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {dtype} + {out} + +.. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) +""".format(**multi_dim_common), +) + +add_docstr( + torch.nanmean, + r""" +nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes the mean of all `non-NaN` elements along the specified dimensions. +Input must be floating point or complex. + +This function is identical to :func:`torch.mean` when there are no `NaN` values +in the :attr:`input` tensor. In the presence of `NaN`, :func:`torch.mean` will +propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the +`NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`). + +{keepdim_details} + +Args: + input (Tensor): the input tensor, either of floating point or complex dtype + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {dtype} + {out} + +.. seealso:: + + :func:`torch.mean` computes the mean value, propagating `NaN`. + +Example:: + + >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) + >>> x.mean() + tensor(nan) + >>> x.nanmean() + tensor(1.8000) + >>> x.mean(dim=0) + tensor([ nan, 1.5000, 2.5000]) + >>> x.nanmean(dim=0) + tensor([1.0000, 1.5000, 2.5000]) + + # If all elements in the reduced dimensions are NaN then the result is NaN + >>> torch.tensor([torch.nan]).nanmean() + tensor(nan) +""".format(**multi_dim_common), +) + +add_docstr( + torch.median, + r""" +median(input) -> Tensor + +Returns the median of the values in :attr:`input`. + +.. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + +.. 
warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + +Args: + {input} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + +.. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + +By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + +If :attr:`keepdim` is ``True``, the output tensors are of the same size +as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in +the outputs tensor having 1 fewer dimension than :attr:`input`. + +.. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + +.. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + +Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) +""".format(**single_dim_common), +) + +add_docstr( + torch.nanmedian, + r""" +nanmedian(input) -> Tensor + +Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. +When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, +while this function will return the median of the non-``NaN`` elements in :attr:`input`. +If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + +Args: + {input} + +Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + +.. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values +found in the dimension :attr:`dim`. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. 
When a reduced row has +one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the +median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + +Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) +""".format(**single_dim_common), +) + +add_docstr( + torch.quantile, + r""" +quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + +Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + +To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location +of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with +indices ``i`` and ``j`` in the sorted order, result is computed according to the given +:attr:`interpolation` method as follows: + +- ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. +- ``lower``: ``a``. +- ``higher``: ``b``. +- ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). +- ``midpoint``: ``(a + b) / 2``. + +If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size +equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + +.. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + +Args: + {input} + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + {opt_dim} + {opt_keepdim} + +Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + {out} + +Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) 
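+
+As a sketch of the index mapping described above (illustrative only), the
+``q=0.6`` result from the example can be reproduced by hand: for this sorted
+input ``[0., 1., 2., 3.]`` the quantile sits at fractional index
+``0.6 * (4 - 1) = 1.8``, i.e. between ``a[1]`` and ``a[2]`` with fraction ``0.8``::
+
+    import torch
+
+    a = torch.arange(4.)
+    manual = a[1] + (a[2] - a[1]) * 0.8   # tensor(1.8000)
+    torch.quantile(a, 0.6)                # tensor(1.8000), 'linear' is the default
+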
+""".format(**single_dim_common), +) + +add_docstr( + torch.nanquantile, + r""" +nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + +This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, +computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did +not exist. If all values in a reduced row are ``NaN`` then the quantiles for +that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + +Args: + {input} + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + {out} + +Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) +""".format(**single_dim_common), +) + +add_docstr( + torch.min, + r""" +min(input, *, out=None) -> Tensor + +Returns the minimum value of all elements in the :attr:`input` tensor. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + +.. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each minimum value found +(argmin). + +If :attr:`keepdim` is ``True``, the output tensors are of the same size as +:attr:`input` except in the dimension :attr:`dim` where they are of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in +the output tensors having 1 fewer dimension than :attr:`input`. + +.. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + +.. function:: min(input, other, *, out=None) -> Tensor + :noindex: + +See :func:`torch.minimum`. +""".format(**single_dim_common), +) + +add_docstr( + torch.minimum, + r""" +minimum(input, other, *, out=None) -> Tensor + +Computes the element-wise minimum of :attr:`input` and :attr:`other`. + +.. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`minimum` is not supported for tensors with complex dtypes. 
+ +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.minimum(a, b) + tensor([1, 0, -1]) +""".format(**common_args), +) + +add_docstr( + torch.fmin, + r""" +fmin(input, other, *, out=None) -> Tensor + +Computes the element-wise minimum of :attr:`input` and :attr:`other`. + +This is like :func:`torch.minimum` except it handles NaNs differently: +if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum. +Only if both elements are NaN is NaN propagated. + +This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function. + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and floating-point inputs. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) +""".format(**common_args), +) + +add_docstr( + torch.amin, + r""" +amin(input, dim, keepdim=False, *, out=None) -> Tensor + +Returns the minimum value of each slice of the :attr:`input` tensor in the given +dimension(s) :attr:`dim`. + +.. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.6451, -0.4866, 0.2987, -1.3312], + [-0.5744, 1.2980, 1.8397, -0.2713], + [ 0.9128, 0.9214, -1.7268, -0.2995], + [ 0.9023, 0.4853, 0.9075, -1.6165]]) + >>> torch.amin(a, 1) + tensor([-1.3312, -0.5744, -1.7268, -1.6165]) +""".format(**multi_dim_common), +) + +add_docstr( + torch.aminmax, + r""" +aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) + +Computes the minimum and maximum values of the :attr:`input` tensor. + +Args: + input (Tensor): + The input tensor + +Keyword Args: + dim (Optional[int]): + The dimension along which to compute the values. If `None`, + computes the values over the entire :attr:`input` tensor. + Default is `None`. + keepdim (bool): + If `True`, the reduced dimensions will be kept in the output + tensor as dimensions with size 1 for broadcasting, otherwise + they will be removed, as if calling (:func:`torch.squeeze`). + Default is `False`. + out (Optional[Tuple[Tensor, Tensor]]): + Optional tensors on which to write the result. Must have the same + shape and dtype as the expected output. + Default is `None`. + +Returns: + A named tuple `(min, max)` containing the minimum and maximum values. + +Raises: + RuntimeError + If any of the dimensions to compute the values over has size 0. + +.. note:: + NaN values are propagated to the output if at least one value is NaN. + +.. 
seealso:: + :func:`torch.amin` computes just the minimum value + :func:`torch.amax` computes just the maximum value + +Example:: + + >>> torch.aminmax(torch.tensor([1, -3, 5])) + torch.return_types.aminmax( + min=tensor(-3), + max=tensor(5)) + + >>> # aminmax propagates NaNs + >>> torch.aminmax(torch.tensor([1, -3, 5, torch.nan])) + torch.return_types.aminmax( + min=tensor(nan), + max=tensor(nan)) + + >>> t = torch.arange(10).view(2, 5) + >>> t + tensor([[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9]]) + >>> t.aminmax(dim=0, keepdim=True) + torch.return_types.aminmax( + min=tensor([[0, 1, 2, 3, 4]]), + max=tensor([[5, 6, 7, 8, 9]])) +""", +) + +add_docstr( + torch.argmin, + r""" +argmin(input, dim=None, keepdim=False) -> LongTensor + +Returns the indices of the minimum value(s) of the flattened tensor or along a dimension + +This is the second value returned by :meth:`torch.min`. See its +documentation for the exact semantics of this method. + +.. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + +Args: + {input} + {opt_dim} If ``None``, the argmin of the flattened input is returned. + {opt_keepdim} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], + [ 1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) + >>> torch.argmin(a, dim=1) + tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) +""".format(**single_dim_common), +) + +add_docstr( + torch.mm, + r""" +mm(input, mat2, out_dtype=None, *, out=None) -> Tensor + +Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + +If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a +:math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + +.. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + +Supports strided and sparse 2-D tensors as inputs, autograd with +respect to strided inputs. + +This operation has support for arguments with :ref:`sparse layouts`. +If :attr:`out` is provided its layout will be used. Otherwise, the result +layout will be deduced from that of :attr:`input`. + +{sparse_beta_warning} + +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + +Keyword args: + {out} + +Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) +""".format(**common_args, **tf32_notes, **rocm_fp16_notes, **sparse_support_notes), +) + +add_docstr( + torch.hspmm, + r""" +hspmm(mat1, mat2, *, out=None) -> Tensor + +Performs a matrix multiplication of a :ref:`sparse COO matrix +` :attr:`mat1` and a strided matrix :attr:`mat2`. The +result is a (1 + 1)-dimensional :ref:`hybrid COO matrix +`. + +Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + +Keyword args: + {out} +""".format(**common_args), +) + +add_docstr( + torch.matmul, + r""" +matmul(input, other, *, out=None) -> Tensor + +Matrix product of two tensors. 
+ +The behavior depends on the dimensionality of the tensors as follows: + +- If both tensors are 1-dimensional, the dot product (scalar) is returned. +- If both arguments are 2-dimensional, the matrix-matrix product is returned. +- If the first argument is 1-dimensional and the second argument is 2-dimensional, + a 1 is prepended to its dimension for the purpose of the matrix multiply. + After the matrix multiply, the prepended dimension is removed. +- If the first argument is 2-dimensional and the second argument is 1-dimensional, + the matrix-vector product is returned. +- If both arguments are at least 1-dimensional and at least one argument is + N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first + argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the + batched matrix multiply and removed after. If the second argument is 1-dimensional, a + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. + +This operation has support for arguments with :ref:`sparse layouts`. In particular the +matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions +as :func:`torch.mm` + +{sparse_beta_warning} + +{tf32_note} + +{rocm_fp16_note} + +.. note:: + + The 1-dimensional dot product version of this function does not support an :attr:`out` parameter. 
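+
+As a quick, illustrative check of the batch-broadcasting rule above, the result
+shape can be verified without inspecting any values::
+
+    >>> torch.matmul(torch.randn(2, 1, 3, 4), torch.randn(5, 4, 6)).shape
+    torch.Size([2, 5, 3, 6])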
+ +Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + +Keyword args: + {out} + +Example:: + + >>> # vector x vector + >>> tensor1 = torch.randn(3) + >>> tensor2 = torch.randn(3) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([]) + >>> # matrix x vector + >>> tensor1 = torch.randn(3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([3]) + >>> # batched matrix x broadcasted vector + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3]) + >>> # batched matrix x batched matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(10, 4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + >>> # batched matrix x broadcasted matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + +""".format(**common_args, **tf32_notes, **rocm_fp16_notes, **sparse_support_notes), +) + +add_docstr( + torch.mode, + r""" +mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + +Returns a namedtuple ``(values, indices)`` where ``values`` is the mode +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`, i.e. a value which appears most often +in that row, and ``indices`` is the index location of each mode value found. + +By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + +If :attr:`keepdim` is ``True``, the output tensors are of the same size as +:attr:`input` except in the dimension :attr:`dim` where they are of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting +in the output tensors having 1 fewer dimension than :attr:`input`. + +.. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + +Args: + {input} + {opt_dim} + {opt_keepdim} + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + +Example:: + + >>> b = torch.tensor([[0, 0, 0, 2, 0, 0, 2], + ... [0, 3, 0, 0, 2, 0, 1], + ... [2, 2, 2, 0, 0, 0, 3], + ... [2, 2, 3, 0, 1, 1, 0], + ... [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) +""".format(**single_dim_common), +) + +add_docstr( + torch.mul, + r""" +mul(input, other, *, out=None) -> Tensor + +Multiplies :attr:`input` by :attr:`other`. + + +.. math:: + \text{out}_i = \text{input}_i \times \text{other}_i +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Number) - the tensor or number to multiply input by. + +Keyword args: + {out} + +Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.2015, -0.4255, 2.6087]) + >>> torch.mul(a, 100) + tensor([ 20.1494, -42.5491, 260.8663]) + + >>> b = torch.randn(4, 1) + >>> b + tensor([[ 1.1207], + [-0.3137], + [ 0.0700], + [ 0.8378]]) + >>> c = torch.randn(1, 4) + >>> c + tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) + >>> torch.mul(b, c) + tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], + [-0.1614, -0.0382, 0.1645, -0.7021], + [ 0.0360, 0.0085, -0.0367, 0.1567], + [ 0.4312, 0.1019, -0.4394, 1.8753]]) +""".format(**common_args), +) + +add_docstr( + torch.multiply, + r""" +multiply(input, other, *, out=None) + +Alias for :func:`torch.mul`. 
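+
+A minimal usage sketch; it behaves exactly like :func:`torch.mul`::
+
+    >>> torch.multiply(torch.tensor([2, 3]), 4)
+    tensor([ 8, 12])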
+""", +) + +add_docstr( + torch.multinomial, + r""" +multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor + +Returns a tensor where each row contains :attr:`num_samples` indices sampled +from the multinomial (a stricter definition would be multivariate, +refer to :class:`torch.distributions.multinomial.Multinomial` for more details) +probability distribution located in the corresponding row +of tensor :attr:`input`. + +.. note:: + The rows of :attr:`input` do not need to sum to one (in which case we use + the values as weights), but must be non-negative, finite and have + a non-zero sum. + +Indices are ordered from left to right according to when each was sampled +(first samples are placed in first column). + +If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. + +If :attr:`input` is a matrix with `m` rows, :attr:`out` is an matrix of shape +:math:`(m \times \text{{num\_samples}})`. + +If replacement is ``True``, samples are drawn with replacement. + +If not, they are drawn without replacement, which means that when a +sample index is drawn for a row, it cannot be drawn again for that row. + +.. note:: + When drawn without replacement, :attr:`num_samples` must be lower than + number of non-zero elements in :attr:`input` (or the min number of non-zero + elements in each row of :attr:`input` if it is a matrix). + +Args: + input (Tensor): the input tensor containing probabilities + num_samples (int): number of samples to draw + replacement (bool, optional): whether to draw with replacement or not + +Keyword args: + {generator} + {out} + +Example:: + + >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights + >>> torch.multinomial(weights, 2) + tensor([1, 2]) + >>> torch.multinomial(weights, 5) # ERROR! + RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement + >>> torch.multinomial(weights, 4, replacement=True) + tensor([ 2, 1, 1, 1]) +""".format(**common_args), +) + +add_docstr( + torch.mv, + r""" +mv(input, vec, *, out=None) -> Tensor + +Performs a matrix-vector product of the matrix :attr:`input` and the vector +:attr:`vec`. + +If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of +size :math:`m`, :attr:`out` will be 1-D of size :math:`n`. + +.. note:: This function does not :ref:`broadcast `. + +Args: + input (Tensor): matrix to be multiplied + vec (Tensor): vector to be multiplied + +Keyword args: + {out} + +Example:: + + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.mv(mat, vec) + tensor([ 1.0404, -0.6361]) +""".format(**common_args), +) + +add_docstr( + torch.mvlgamma, + r""" +mvlgamma(input, p, *, out=None) -> Tensor + +Alias for :func:`torch.special.multigammaln`. +""", +) + +add_docstr( + torch.movedim, + r""" +movedim(input, source, destination) -> Tensor + +Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` +to the position(s) in :attr:`destination`. + +Other dimensions of :attr:`input` that are not explicitly moved remain in +their original order and appear at the positions not specified in :attr:`destination`. + +Args: + {input} + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. 
+ +Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) +""".format(**common_args), +) + +add_docstr( + torch.moveaxis, + r""" +moveaxis(input, source, destination) -> Tensor + +Alias for :func:`torch.movedim`. + +This function is equivalent to NumPy's moveaxis function. + +Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) +""".format(**common_args), +) + +add_docstr( + torch.swapdims, + r""" +swapdims(input, dim0, dim1) -> Tensor + +Alias for :func:`torch.transpose`. + +This function is equivalent to NumPy's swapaxes function. + +Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) +""".format(**common_args), +) + +add_docstr( + torch.swapaxes, + r""" +swapaxes(input, axis0, axis1) -> Tensor + +Alias for :func:`torch.transpose`. + +This function is equivalent to NumPy's swapaxes function. + +Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) +""".format(**common_args), +) + +add_docstr( + torch.narrow, + r""" +narrow(input, dim, start, length) -> Tensor + +Returns a new tensor that is a narrowed version of :attr:`input` tensor. The +dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The +returned tensor and :attr:`input` tensor share the same underlying storage. + +Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + +Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) +""", +) + +add_docstr( + torch.narrow_copy, + r""" +narrow_copy(input, dim, start, length, *, out=None) -> Tensor + +Same as :meth:`Tensor.narrow` except this returns a copy rather +than shared storage. 
This is primarily for sparse tensors, which +do not have a shared-storage narrow method. + +Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive + +Keyword args: + {out} + +Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow_copy(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow_copy(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) + >>> torch.narrow_copy(s, 0, 0, 1) + tensor(indices=tensor([[0, 0], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + +.. seealso:: + + :func:`torch.narrow` for a non copy variant + +""".format(**common_args), +) + +add_docstr( + torch.nan_to_num, + r""" +nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + +Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` +with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. +By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the +greatest finite value representable by :attr:`input`'s dtype, and negative infinity +is replaced with the least finite value representable by :attr:`input`'s dtype. + +Args: + {input} + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. + +Keyword args: + {out} + +Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + +""".format(**common_args), +) + +add_docstr( + torch.ne, + r""" +ne(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} \neq \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + +Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) +""".format(**common_args), +) + +add_docstr( + torch.not_equal, + r""" +not_equal(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.ne`. +""", +) + +add_docstr( + torch.neg, + r""" +neg(input, *, out=None) -> Tensor + +Returns a new tensor with the negative of the elements of :attr:`input`. + +.. 
math:: + \text{out} = -1 \times \text{input} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.neg(a) + tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) +""".format(**common_args), +) + +add_docstr( + torch.negative, + r""" +negative(input, *, out=None) -> Tensor + +Alias for :func:`torch.neg` +""", +) + +add_docstr( + torch.nextafter, + r""" +nextafter(input, other, *, out=None) -> Tensor + +Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. + +The shapes of ``input`` and ``other`` must be +:ref:`broadcastable `. + +Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> eps = torch.finfo(torch.float32).eps + >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) + tensor([True, True]) + +""".format(**common_args), +) + +add_docstr( + torch.nonzero, + r""" +nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + +.. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + +**When** :attr:`as_tuple` **is** ``False`` **(default)**: + +Returns a tensor containing the indices of all non-zero elements of +:attr:`input`. Each row in the result contains the indices of a non-zero +element in :attr:`input`. The result is sorted lexicographically, with +the last index changing the fastest (C-style). + +If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor +:attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of +non-zero elements in the :attr:`input` tensor. + +**When** :attr:`as_tuple` **is** ``True``: + +Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, +each containing the indices (in that dimension) of all non-zero elements of +:attr:`input` . + +If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` +tensors of size :math:`z`, where :math:`z` is the total number of +non-zero elements in the :attr:`input` tensor. + +As a special case, when :attr:`input` has zero dimensions and a nonzero scalar +value, it is treated as a one-dimensional tensor with one element. + +Args: + {input} + +Keyword args: + out (LongTensor, optional): the output tensor containing indices + +Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. + +Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... 
[0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) +""".format(**common_args), +) + +add_docstr( + torch.normal, + r""" +normal(mean, std, *, generator=None, out=None) -> Tensor + +Returns a tensor of random numbers drawn from separate normal distributions +whose mean and standard deviation are given. + +The :attr:`mean` is a tensor with the mean of +each output element's normal distribution + +The :attr:`std` is a tensor with the standard deviation of +each output element's normal distribution + +The shapes of :attr:`mean` and :attr:`std` don't need to match, but the +total number of elements in each tensor need to be the same. + +.. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + +.. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + +Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + +Keyword args: + {generator} + {out} + +Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + +.. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + +Similar to the function above, but the means are shared among all drawn +elements. + +Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + +Keyword args: + {out} + +Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + +.. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + +Similar to the function above, but the standard deviations are shared among +all drawn elements. + +Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + +Keyword args: + out (Tensor, optional): the output tensor + +Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + +.. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + +Similar to the function above, but the means and standard deviations are shared +among all drawn elements. The resulting tensor has size given by :attr:`size`. + +Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + +Keyword args: + {out} + +Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) +""".format(**common_args), +) + +add_docstr( + torch.numel, + r""" +numel(input: Tensor) -> int + +Returns the total number of elements in the :attr:`input` tensor. 
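+
+The result is simply the product of the sizes of all dimensions, for example::
+
+    >>> torch.numel(torch.zeros(2, 3, 4))
+    24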
+ +Args: + {input} + +Example:: + + >>> a = torch.randn(1, 2, 3, 4, 5) + >>> torch.numel(a) + 120 + >>> a = torch.zeros(4,4) + >>> torch.numel(a) + 16 + +""".format(**common_args), +) + +add_docstr( + torch.ones, + r""" +ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a tensor filled with the scalar value `1`, with the shape defined +by the variable argument :attr:`size`. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword arguments: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + +""".format(**factory_common_args), +) + +add_docstr( + torch.ones_like, + r""" +ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor filled with the scalar value `1`, with the same size as +:attr:`input`. ``torch.ones_like(input)`` is equivalent to +``torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +.. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.ones_like(input, out=output)`` is equivalent to + ``torch.ones(input.size(), out=output)``. + +Args: + {input} + +Keyword arguments: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +Example:: + + >>> input = torch.empty(2, 3) + >>> torch.ones_like(input) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) +""".format(**factory_like_common_args), +) + +add_docstr( + torch.orgqr, + r""" +orgqr(input, tau) -> Tensor + +Alias for :func:`torch.linalg.householder_product`. +""", +) + +add_docstr( + torch.ormqr, + r""" +ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor + +Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + +Multiplies a :math:`m \times n` matrix `C` (given by :attr:`other`) with a matrix `Q`, +where `Q` is represented using Householder reflectors `(input, tau)`. +See `Representation of Orthogonal or Unitary Matrices`_ for further details. + +If :attr:`left` is `True` then `op(Q)` times `C` is computed, otherwise the result is `C` times `op(Q)`. +When :attr:`left` is `True`, the implicit matrix `Q` has size :math:`m \times m`. +It has size :math:`n \times n` otherwise. +If :attr:`transpose` is `True` then `op` is the conjugate transpose operation, otherwise it's a no-op. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + +.. seealso:: + :func:`torch.geqrf` can be used to form the Householder representation `(input, tau)` of matrix `Q` + from the QR decomposition. + +.. note:: + This function supports backward but it is only fast when ``(input, tau)`` do not require gradients + and/or ``tau.size(-1)`` is very small. + `` + +Args: + input (Tensor): tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions + and `mn` equals to `m` or `n` depending on the :attr:`left`. + tau (Tensor): tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions. + other (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + left (bool): controls the order of multiplication. 
+ transpose (bool): controls whether the matrix `Q` is conjugate transposed or not. + +Keyword args: + out (Tensor, optional): the output Tensor. Ignored if `None`. Default: `None`. + +.. _Representation of Orthogonal or Unitary Matrices: + https://www.netlib.org/lapack/lug/node128.html +""", +) + +add_docstr( + torch.permute, + r""" +permute(input, dims) -> Tensor + +Returns a view of the original tensor :attr:`input` with its dimensions permuted. + +Args: + {input} + dims (tuple of int): The desired ordering of dimensions + +Example: + >>> x = torch.randn(2, 3, 5) + >>> x.size() + torch.Size([2, 3, 5]) + >>> torch.permute(x, (2, 0, 1)).size() + torch.Size([5, 2, 3]) +""".format(**common_args), +) + +add_docstr( + torch.poisson, + r""" +poisson(input, generator=None) -> Tensor + +Returns a tensor of the same size as :attr:`input` with each element +sampled from a Poisson distribution with rate parameter given by the corresponding +element in :attr:`input` i.e., + +.. math:: + \text{{out}}_i \sim \text{{Poisson}}(\text{{input}}_i) + +:attr:`input` must be non-negative. + +Args: + input (Tensor): the input tensor containing the rates of the Poisson distribution + +Keyword args: + {generator} + +Example:: + + >>> rates = torch.rand(4, 4) * 5 # rate parameter between 0 and 5 + >>> torch.poisson(rates) + tensor([[9., 1., 3., 5.], + [8., 6., 6., 0.], + [0., 4., 5., 3.], + [2., 1., 4., 2.]]) +""".format(**common_args), +) + +add_docstr( + torch.polygamma, + r""" +polygamma(n, input, *, out=None) -> Tensor + +Alias for :func:`torch.special.polygamma`. +""", +) + +add_docstr( + torch.positive, + r""" +positive(input) -> Tensor + +Returns :attr:`input`. +Throws a runtime error if :attr:`input` is a bool tensor. +""" + + r""" +Args: + {input} + +Example:: + + >>> t = torch.randn(5) + >>> t + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.positive(t) + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) +""".format(**common_args), +) + +add_docstr( + torch.pow, + r""" +pow(input, exponent, *, out=None) -> Tensor + +Takes the power of each element in :attr:`input` with :attr:`exponent` and +returns a tensor with the result. + +:attr:`exponent` can be either a single ``float`` number or a `Tensor` +with the same number of elements as :attr:`input`. + +When :attr:`exponent` is a scalar value, the operation applied is: + +.. math:: + \text{out}_i = x_i ^ \text{exponent} + +When :attr:`exponent` is a tensor, the operation applied is: + +.. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} +""" + + r""" +When :attr:`exponent` is a tensor, the shapes of :attr:`input` +and :attr:`exponent` must be :ref:`broadcastable `. + +Args: + {input} + exponent (float or tensor): the exponent value + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + +.. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + +:attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. +The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + +The operation applied is: + +.. 
math:: + \text{{out}}_i = \text{{self}} ^ {{\text{{exponent}}_i}} + +Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + +Keyword args: + {out} + +Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) +""".format(**common_args), +) + +add_docstr( + torch.float_power, + r""" +float_power(input, exponent, *, out=None) -> Tensor + +Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. +If neither input is complex returns a ``torch.float64`` tensor, +and if one or more inputs is complex returns a ``torch.complex128`` tensor. + +.. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + +Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + +Keyword args: + {out} + +Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) +""".format(**common_args), +) + +add_docstr( + torch.prod, + r""" +prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + +Returns the product of all elements in the :attr:`input` tensor. + +Args: + {input} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + +.. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + +Returns the product of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) +""".format(**single_dim_common), +) + +add_docstr( + torch.promote_types, + r""" +promote_types(type1, type2) -> dtype + +Returns the :class:`torch.dtype` with the smallest size and scalar kind that is +not smaller nor of lower kind than either `type1` or `type2`. See type promotion +:ref:`documentation ` for more information on the type +promotion logic. 
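+
+For instance, combining an integral dtype with a floating-point dtype promotes
+to the floating-point dtype, regardless of size (illustrative)::
+
+    >>> torch.promote_types(torch.int64, torch.float16)
+    torch.float16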
+ +Args: + type1 (:class:`torch.dtype`) + type2 (:class:`torch.dtype`) + +Example:: + + >>> torch.promote_types(torch.int32, torch.float32) + torch.float32 + >>> torch.promote_types(torch.uint8, torch.long) + torch.long +""", +) + +add_docstr( + torch.qr, + r""" +qr(input: Tensor, some: bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None]) -> (Tensor, Tensor) + +Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, +and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` +with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and +:math:`R` being an upper triangular matrix or batch of upper triangular matrices. + +If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. +Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. + +.. warning:: + + :func:`torch.qr` is deprecated in favor of :func:`torch.linalg.qr` + and will be removed in a future PyTorch release. The boolean parameter :attr:`some` has been + replaced with a string parameter :attr:`mode`. + + ``Q, R = torch.qr(A)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A) + + ``Q, R = torch.qr(A, some=False)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A, mode="complete") + +.. warning:: + If you plan to backpropagate through QR, note that the current backward implementation + is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` + columns of :attr:`input` are linearly independent. + This behavior will probably change once QR supports pivoting. + +.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + +Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. + +Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.qr(a, some=False) + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) + True +""", +) + +add_docstr( + torch.rad2deg, + r""" +rad2deg(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with each of the elements of :attr:`input` +converted from angles in radians to degrees. 
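+
+The conversion scales by 180 / pi, and :func:`torch.deg2rad` is its inverse; a
+quick (approximate) round-trip check::
+
+    >>> x = torch.tensor([0.0, 1.5708])
+    >>> torch.allclose(torch.deg2rad(torch.rad2deg(x)), x)
+    True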
+ +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]) + >>> torch.rad2deg(a) + tensor([[ 180.0233, -180.0233], + [ 359.9894, -359.9894], + [ 89.9544, -89.9544]]) + +""".format(**common_args), +) + +add_docstr( + torch.deg2rad, + r""" +deg2rad(input, *, out=None) -> Tensor + +Returns a new tensor with each of the elements of :attr:`input` +converted from angles in degrees to radians. + +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]) + >>> torch.deg2rad(a) + tensor([[ 3.1416, -3.1416], + [ 6.2832, -6.2832], + [ 1.5708, -1.5708]]) + +""".format(**common_args), +) + +add_docstr( + torch.heaviside, + r""" +heaviside(input, values, *, out=None) -> Tensor + +Computes the Heaviside step function for each element in :attr:`input`. +The Heaviside step function is defined as: + +.. math:: + \text{{heaviside}}(input, values) = \begin{cases} + 0, & \text{if input < 0}\\ + values, & \text{if input == 0}\\ + 1, & \text{if input > 0} + \end{cases} +""" + + r""" + +Args: + {input} + values (Tensor): The values to use where :attr:`input` is zero. + +Keyword arguments: + {out} + +Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + +""".format(**common_args), +) + +add_docstr( + torch.rand, + """ +rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, \ +requires_grad=False, pin_memory=False) -> Tensor +""" + + r""" +Returns a tensor filled with random numbers from a uniform distribution +on the interval :math:`[0, 1)` + +The shape of the tensor is defined by the variable argument :attr:`size`. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {generator} + {out} + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) +""".format(**factory_common_args), +) + +add_docstr( + torch.rand_like, + r""" +rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same size as :attr:`input` that is filled with +random numbers from a uniform distribution on the interval :math:`[0, 1)`. +``torch.rand_like(input)`` is equivalent to +``torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +""".format(**factory_like_common_args), +) + +add_docstr( + torch.randint, + """ +randint(low=0, high, size, \\*, generator=None, out=None, \ +dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a tensor filled with random integers generated uniformly +between :attr:`low` (inclusive) and :attr:`high` (exclusive). + +The shape of the tensor is defined by the variable argument :attr:`size`. + +.. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. 
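+
+The default dtype mentioned in the note can be checked directly (the sampled
+values themselves are random)::
+
+    >>> torch.randint(10, (2, 2)).dtype
+    torch.int64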
+ +Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + +Keyword args: + {generator} + {out} + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + + +""".format(**factory_common_args), +) + +add_docstr( + torch.randint_like, + """ +randint_like(input, low=0, high, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same shape as Tensor :attr:`input` filled with +random integers generated uniformly between :attr:`low` (inclusive) and +:attr:`high` (exclusive). + +.. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + +Args: + {input} + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +""".format(**factory_like_common_args), +) + +add_docstr( + torch.randn, + """ +randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a tensor filled with random numbers from a normal distribution +with mean `0` and variance `1` (also called the standard normal +distribution). + +.. math:: + \text{{out}}_{{i}} \sim \mathcal{{N}}(0, 1) + +For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and +unit variance as + +.. math:: + \text{{out}}_{{i}} \sim \mathcal{{CN}}(0, 1) + +This is equivalent to separately sampling the real :math:`(\operatorname{{Re}})` and imaginary +:math:`(\operatorname{{Im}})` part of :math:`\text{{out}}_i` as + +.. math:: + \operatorname{{Re}}(\text{{out}}_{{i}}) \sim \mathcal{{N}}(0, \frac{{1}}{{2}}),\quad + \operatorname{{Im}}(\text{{out}}_{{i}}) \sim \mathcal{{N}}(0, \frac{{1}}{{2}}) + +The shape of the tensor is defined by the variable argument :attr:`size`. + + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {generator} + {out} + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + +.. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution +""".format(**factory_common_args), +) + +add_docstr( + torch.randn_like, + r""" +randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same size as :attr:`input` that is filled with +random numbers from a normal distribution with mean 0 and variance 1. Please refer to :func:`torch.randn` for the +sampling process of complex dtypes. 
``torch.randn_like(input)`` is equivalent to +``torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +""".format(**factory_like_common_args), +) + +add_docstr( + torch.randperm, + """ +randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, \ +device=None, requires_grad=False, pin_memory=False) -> Tensor +""" + + r""" +Returns a random permutation of integers from ``0`` to ``n - 1``. + +Args: + n (int): the upper bound (exclusive) + +Keyword args: + {generator} + {out} + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) +""".format(**factory_common_args), +) + +add_docstr( + torch.tensor, + r""" +tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + +Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. + +.. warning:: + + When working with tensors prefer using :func:`torch.Tensor.clone`, + :func:`torch.Tensor.detach`, and :func:`torch.Tensor.requires_grad_` for + readability. Letting `t` be a tensor, ``torch.tensor(t)`` is equivalent to + ``t.detach().clone()``, and ``torch.tensor(t, requires_grad=True)`` + is equivalent to ``t.detach().clone().requires_grad_(True)``. + +.. seealso:: + + :func:`torch.as_tensor` preserves autograd history and avoids copies where possible. + :func:`torch.from_numpy` creates a tensor that shares storage with a NumPy array. + +Args: + {data} + +Keyword args: + {dtype} + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + {requires_grad} + {pin_memory} + + +Example:: + + >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + tensor([[ 0.1000, 1.2000], + [ 2.2000, 3.1000], + [ 4.9000, 5.2000]]) + + >>> torch.tensor([0, 1]) # Type inference on data + tensor([ 0, 1]) + + >>> torch.tensor([[0.11111, 0.222222, 0.3333333]], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) # creates a double tensor on a CUDA device + tensor([[ 0.1111, 0.2222, 0.3333]], dtype=torch.float64, device='cuda:0') + + >>> torch.tensor(3.14159) # Create a zero-dimensional (scalar) tensor + tensor(3.1416) + + >>> torch.tensor([]) # Create an empty tensor (of size (0,)) + tensor([]) +""".format(**factory_data_common_args), +) + +add_docstr( + torch.range, + r""" +range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` +with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is +the gap between two values in the tensor. + +.. math:: + \text{out}_{i+1} = \text{out}_i + \text{step}. +""" + + r""" +.. warning:: + This function is deprecated and will be removed in a future release because its behavior is inconsistent with + Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). + +Args: + start (float, optional): the starting value for the set of points. Default: ``0``. 
+ end (float): the ending value for the set of points + step (float, optional): the gap between each pair of adjacent points. Default: ``1``. + +Keyword args: + {out} + {dtype} If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `step` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.range(1, 4) + tensor([ 1., 2., 3., 4.]) + >>> torch.range(1, 4, 0.5) + tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) +""".format(**factory_common_args), +) + +add_docstr( + torch.arange, + r""" +arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` +with values from the interval ``[start, end)`` taken with common difference +:attr:`step` beginning from `start`. + +Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), +the results may be affected by floating-point rounding behavior. Some values in the sequence +might not be exactly representable in certain floating-point formats, which can lead to +repeated values or unexpected rounding. For precise sequences, it is recommended to use +integer dtypes instead of floating-point dtypes. + +Note that non-integer :attr:`step` is subject to floating point rounding errors when +comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` +in such cases. + +.. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} +""" + + r""" +Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + +Keyword args: + {out} + {dtype} If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) +""".format(**factory_common_args), +) + +add_docstr( + torch.ravel, + r""" +ravel(input) -> Tensor + +Return a contiguous flattened tensor. A copy is made only if needed. + +Args: + {input} + +Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) +""".format(**common_args), +) + +add_docstr( + torch.remainder, + r""" +remainder(input, other, *, out=None) -> Tensor + +Computes +`Python's modulus operation `_ +entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value +is less than that of :attr:`other`. + +It may also be defined in terms of :func:`torch.div` as + +.. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and float inputs. + +.. note:: + Complex inputs are not supported. 
In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + +.. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + +Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + +Keyword args: + {out} + +Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) +""".format(**common_args), +) + +add_docstr( + torch.renorm, + r""" +renorm(input, p, dim, maxnorm, *, out=None) -> Tensor + +Returns a tensor where each sub-tensor of :attr:`input` along dimension +:attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower +than the value :attr:`maxnorm` + +.. note:: If the norm of a row is lower than `maxnorm`, the row is unchanged + +Args: + {input} + p (float): the power for the norm computation + dim (int): the dimension to slice over to get the sub-tensors + maxnorm (float): the maximum norm to keep each sub-tensor under + +Keyword args: + {out} + +Example:: + + >>> x = torch.ones(3, 3) + >>> x[1].fill_(2) + tensor([ 2., 2., 2.]) + >>> x[2].fill_(3) + tensor([ 3., 3., 3.]) + >>> x + tensor([[ 1., 1., 1.], + [ 2., 2., 2.], + [ 3., 3., 3.]]) + >>> torch.renorm(x, 1, 0, 5) + tensor([[ 1.0000, 1.0000, 1.0000], + [ 1.6667, 1.6667, 1.6667], + [ 1.6667, 1.6667, 1.6667]]) +""".format(**common_args), +) + +add_docstr( + torch.reshape, + r""" +reshape(input, shape) -> Tensor + +Returns a tensor with the same data and number of elements as :attr:`input`, +but with the specified shape. When possible, the returned tensor will be a view +of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs +with compatible strides can be reshaped without copying, but you should not +depend on the copying vs. viewing behavior. + +See :meth:`torch.Tensor.view` on when it is possible to return a view. + +A single dimension may be -1, in which case it's inferred from the remaining +dimensions and the number of elements in :attr:`input`. + +Args: + input (Tensor): the tensor to be reshaped + shape (tuple of int): the new shape + +Example:: + + >>> a = torch.arange(4.) + >>> torch.reshape(a, (2, 2)) + tensor([[ 0., 1.], + [ 2., 3.]]) + >>> b = torch.tensor([[0, 1], [2, 3]]) + >>> torch.reshape(b, (-1,)) + tensor([ 0, 1, 2, 3]) +""", +) + + +add_docstr( + torch.result_type, + r""" +result_type(tensor1, tensor2) -> dtype + +Returns the :class:`torch.dtype` that would result from performing an arithmetic +operation on the provided input tensors. See type promotion :ref:`documentation ` +for more information on the type promotion logic. + +Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + +Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 +""", +) + +add_docstr( + torch.row_stack, + r""" +row_stack(tensors, *, out=None) -> Tensor + +Alias of :func:`torch.vstack`. +""", +) + +add_docstr( + torch.round, + r""" +round(input, *, decimals=0, out=None) -> Tensor + +Rounds elements of :attr:`input` to the nearest integer. 
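# Editor's sketch (not part of the upstream docstring; assumes only a standard
# torch install): the torch.reshape documentation above notes that the result may
# be a view or a copy. Comparing data pointers shows which one was returned.
import torch

base = torch.arange(6)
as_view = torch.reshape(base, (2, 3))               # contiguous input: reshaped without copying
print(as_view.data_ptr() == base.data_ptr())        # True here, but do not rely on it

transposed = torch.arange(6).view(2, 3).t()         # non-contiguous input
as_copy = torch.reshape(transposed, (6,))           # this reshape has to copy
print(as_copy.data_ptr() == transposed.data_ptr())  # False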
+ +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. +The output has the same dtype as the input. + +.. note:: + This function implements the "round half to even" rule to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:`decimals` argument is specified, the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + E.g. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + +.. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + +Args: + {input} + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + +Keyword args: + {out} + +Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) +""".format(**common_args), +) + +add_docstr( + torch.rsqrt, + r""" +rsqrt(input, *, out=None) -> Tensor + +Returns a new tensor with the reciprocal of the square-root of each of +the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.0370, 0.2970, 1.5420, -0.9105]) + >>> torch.rsqrt(a) + tensor([ nan, 1.8351, 0.8053, nan]) +""".format(**common_args), +) + +add_docstr( + torch.scatter, + r""" +scatter(input, dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_` +""", +) + +add_docstr( + torch.scatter_add, + r""" +scatter_add(input, dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_add_` +""", +) + +add_docstr( + torch.scatter_reduce, + r""" +scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` +""", +) + +add_docstr( + torch.select, + r""" +select(input, dim, index) -> Tensor + +Slices the :attr:`input` tensor along the selected dimension at the given index. +This function returns a view of the original tensor with the given dimension removed. + +.. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. If this is the case, consider using the + :func:`torch.select_copy` function. + +Args: + {input} + dim (int): the dimension to slice + index (int): the index to select with + +.. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``.
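# Editor's sketch (not part of the upstream docstring; standard torch only):
# the select/slicing equivalence stated in the note above, on a small 3-D tensor.
import torch

t = torch.arange(24).reshape(2, 3, 4)
assert torch.equal(torch.select(t, 0, 1), t[1])        # dim 0: same as t[index]
assert torch.equal(torch.select(t, 2, 3), t[:, :, 3])  # dim 2: same as t[:, :, index]
print(torch.select(t, 0, 1).shape)                     # torch.Size([3, 4]) -- dim 0 removed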
+""".format(**common_args), +) + +add_docstr( + torch.select_scatter, + r""" +select_scatter(input, src, dim, index) -> Tensor + +Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index. +This function returns a tensor with fresh storage; it does not create a view. + + +Args: + {input} + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into. + index (int): the index to select with + +.. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.select(input, dim, index)`` + +Example:: + + >>> a = torch.zeros(2, 2) + >>> b = torch.ones(2) + >>> a.select_scatter(b, 0, 0) + tensor([[1., 1.], + [0., 0.]]) +""".format(**common_args), +) + +add_docstr( + torch.slice_scatter, + r""" +slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor + +Embeds the values of the :attr:`src` tensor into :attr:`input` at the given +dimension. +This function returns a tensor with fresh storage; it does not create a view. + + +Args: + {input} + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into + start (Optional[int]): the start index of where to insert the slice + end (Optional[int]): the end index of where to insert the slice + step (int): the how many elements to skip in + +Example:: + + >>> a = torch.zeros(8, 8) + >>> b = torch.ones(2, 8) + >>> a.slice_scatter(b, start=6) + tensor([[0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1.]]) + + >>> b = torch.ones(8, 2) + >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2) + tensor([[0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.]]) +""".format(**common_args), +) + +add_docstr( + torch.set_flush_denormal, + r""" +set_flush_denormal(mode) -> bool + +Disables denormal floating numbers on CPU. + +Returns ``True`` if your system supports flushing denormal numbers and it +successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal` +is supported on x86 architectures supporting SSE3 and AArch64 architecture. + +Args: + mode (bool): Controls whether to enable flush denormal mode or not + +Example:: + + >>> torch.set_flush_denormal(True) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor([ 0.], dtype=torch.float64) + >>> torch.set_flush_denormal(False) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor(9.88131e-324 * + [ 1.0000], dtype=torch.float64) +""", +) + +add_docstr( + torch.set_num_threads, + r""" +set_num_threads(int) + +Sets the number of threads used for intraop parallelism on CPU. + +.. warning:: + To ensure that the correct number of threads is used, set_num_threads + must be called before running eager, JIT or autograd code. +""", +) + +add_docstr( + torch.set_num_interop_threads, + r""" +set_num_interop_threads(int) + +Sets the number of threads used for interop parallelism +(e.g. in JIT interpreter) on CPU. + +.. warning:: + Can only be called once and before any inter-op parallel work + is started (e.g. 
JIT execution). +""", +) + +add_docstr( + torch.sigmoid, + r""" +sigmoid(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.expit`. +""", +) + +add_docstr( + torch.logit, + r""" +logit(input, eps=None, *, out=None) -> Tensor + +Alias for :func:`torch.special.logit`. +""", +) + +add_docstr( + torch.sign, + r""" +sign(input, *, out=None) -> Tensor + +Returns a new tensor with the signs of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \operatorname{sgn}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> a + tensor([ 0.7000, -1.2000, 0.0000, 2.3000]) + >>> torch.sign(a) + tensor([ 1., -1., 0., 1.]) +""".format(**common_args), +) + +add_docstr( + torch.signbit, + r""" +signbit(input, *, out=None) -> Tensor + +Tests if each element of :attr:`input` has its sign bit set or not. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> torch.signbit(a) + tensor([ False, True, False, False]) + >>> a = torch.tensor([-0.0, 0.0]) + >>> torch.signbit(a) + tensor([ True, False]) + +.. note:: + signbit handles signed zeros, so negative zero (-0) returns True. + +""".format(**common_args), +) + +add_docstr( + torch.sgn, + r""" +sgn(input, *, out=None) -> Tensor + +This function is an extension of torch.sign() to complex tensors. +It computes a new tensor whose elements have +the same angles as the corresponding elements of :attr:`input` and +absolute values (i.e. magnitudes) of one for complex tensors and +is equivalent to torch.sign() for non-complex tensors. + +.. math:: + \text{out}_{i} = \begin{cases} + 0 & |\text{{input}}_i| == 0 \\ + \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} + \end{cases} + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> t.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) +""".format(**common_args), +) + +add_docstr( + torch.sin, + r""" +sin(input, *, out=None) -> Tensor + +Returns a new tensor with the sine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sin(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) +""".format(**common_args), +) + +add_docstr( + torch.sinc, + r""" +sinc(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.sinc`. +""", +) + +add_docstr( + torch.sinh, + r""" +sinh(input, *, out=None) -> Tensor + +Returns a new tensor with the hyperbolic sine of the elements of +:attr:`input`. + +.. math:: + \text{out}_{i} = \sinh(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) + >>> torch.sinh(a) + tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + +.. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. +""".format(**common_args), +) + +add_docstr( + torch.sort, + r""" +sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + +Sorts the elements of the :attr:`input` tensor along a given dimension +in ascending order by value. 
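# Editor's sketch (not part of the upstream docstring; standard torch only):
# the indices returned by torch.sort index back into the original tensor along
# the sorted dimension, so gather reconstructs the sorted values.
import torch

x = torch.tensor([[3., 1., 2.],
                  [9., 7., 8.]])
values, indices = torch.sort(x, dim=1)
assert torch.equal(values, torch.gather(x, 1, indices))
print(values)  # each row sorted in ascending order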
+ +If :attr:`dim` is not given, the last dimension of the `input` is chosen. + +If :attr:`descending` is ``True`` then the elements are sorted in descending +order by value. + +If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving +the order of equivalent elements. + +A namedtuple of (values, indices) is returned, where the `values` are the +sorted values and `indices` are the indices of the elements in the original +`input` tensor. + +Args: + {input} + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + +Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + +Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) +""".format(**common_args), +) + +add_docstr( + torch.argsort, + r""" +argsort(input, dim=-1, descending=False, stable=False) -> Tensor + +Returns the indices that sort a tensor along a given dimension in ascending +order by value. + +This is the second value returned by :meth:`torch.sort`. See its documentation +for the exact semantics of this method. + +If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving +the order of equivalent elements. If ``False``, the relative order of values +which compare equal is not guaranteed. ``True`` is slower. + +Args: + {input} + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) +""".format(**common_args), +) + +add_docstr( + torch.msort, + r""" +msort(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Sorts the elements of the :attr:`input` tensor along its first dimension +in ascending order by value. + +.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. 
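# Editor's sketch (not part of the upstream docstring; standard torch only):
# checking the msort/sort equivalence stated in the note above.
import torch

t = torch.tensor([[2., 0.],
                  [1., 3.]])
assert torch.equal(torch.msort(t), torch.sort(t, dim=0)[0])
print(torch.msort(t))  # each column sorted independently along dim 0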
+ +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) +""".format(**common_args), +) + +add_docstr( + torch.sparse_compressed_tensor, + r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """ + r"""*, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, +CSC, BSR, or BSC - ` with specified values at +the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse +matrix multiplication operations in Compressed Sparse format are +typically faster than that for sparse tensors in COO format. Make you +have a look at :ref:`the note on the data type of the indices +`. + +{sparse_factory_device_note} + +Args: + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. + plain_indices (array_like): Plain dimension (column or row) + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. + + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types. that + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + layout (:class:`torch.layout`, required): the desired layout of + returned tensor: :attr:`torch.sparse_csr`, + :attr:`torch.sparse_csc`, :attr:`torch.sparse_bsr`, or + :attr:`torch.sparse_bsc`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {pin_memory} + {requires_grad} + {check_invariants} + +Example:: + + >>> compressed_indices = [0, 2, 4] + >>> plain_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_compressed_tensor(torch.tensor(compressed_indices, dtype=torch.int64), + ... torch.tensor(plain_indices, dtype=torch.int64), + ... 
torch.tensor(values), dtype=torch.double, layout=torch.sparse_csr) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +""".format(**factory_common_args), +) + +add_docstr( + torch.sparse_csr_tensor, + r"""sparse_csr_tensor(crow_indices, col_indices, values, size=None, """ + r"""*, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified +values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations +in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look +at :ref:`the note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. + col_indices (array_like): Column co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {pin_memory} + {requires_grad} + {check_invariants} + +Example:: + + >>> crow_indices = [0, 2, 4] + >>> col_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +""".format(**factory_common_args), +) + +add_docstr( + torch.sparse_csc_tensor, + r"""sparse_csc_tensor(ccol_indices, row_indices, values, size=None, """ + r"""*, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) +` with specified values at the given +:attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix +multiplication operations in CSC format are typically faster than that +for sparse tensors in COO format. Make you have a look at :ref:`the +note on the data type of the indices `. 
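# Editor's sketch (not part of the upstream docstring; standard torch only):
# for the compressed formats above, consecutive differences of the compressed
# indices give the number of stored elements per compressed row/column, and
# to_dense() recovers the full matrix.
import torch

csr = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                              torch.tensor([0, 1, 0, 1]),
                              torch.tensor([1., 2., 3., 4.]))
print(csr.crow_indices().diff())  # tensor([2, 2]) -> two stored values per row
print(csr.to_dense())             # the dense 2x2 matrix [[1, 2], [3, 4]]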
+ +{sparse_factory_device_note} + +Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. + row_indices (array_like): Row co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {pin_memory} + {requires_grad} + {check_invariants} + +Example:: + + >>> ccol_indices = [0, 2, 4] + >>> row_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +""".format(**factory_common_args), +) + +add_docstr( + torch.sparse_bsr_tensor, + r"""sparse_bsr_tensor(crow_indices, col_indices, values, size=None, """ + r"""*, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) +` with specified 2-dimensional blocks at the given +:attr:`crow_indices` and :attr:`col_indices`. Sparse matrix +multiplication operations in BSR format are typically faster than that +for sparse tensors in COO format. Make you have a look at :ref:`the +note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + blocks in a given row. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. 
+ size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {pin_memory} + {requires_grad} + {check_invariants} + +Example:: + + >>> crow_indices = [0, 1, 2] + >>> col_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsr) +""".format(**factory_common_args), +) + +add_docstr( + torch.sparse_bsc_tensor, + r"""sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, """ + r"""*, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse +Column)) ` with specified 2-dimensional blocks at the +given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix +multiplication operations in BSC format are typically faster than that +for sparse tensors in COO format. Make you have a look at :ref:`the +note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial blocks for the tensor. Can be a list, + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. 
+ {pin_memory} + {requires_grad} + {check_invariants} + +Example:: + + >>> ccol_indices = [0, 1, 2] + >>> row_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 1, 2]), + row_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsc) +""".format(**factory_common_args), +) + +add_docstr( + torch.sparse_coo_tensor, + r"""sparse_coo_tensor(indices, values, size=None, """ + r"""*, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor + +Constructs a :ref:`sparse tensor in COO(rdinate) format +` with specified values at the given +:attr:`indices`. + +.. note:: + + This function returns an :ref:`uncoalesced tensor + ` when :attr:`is_coalesced` is + unspecified or ``None``. + +{sparse_factory_device_note} + +Args: + indices (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor` + internally. The indices are the coordinates of the non-zero values in the matrix, and thus + should be two-dimensional where the first dimension is the number of tensor dimensions and + the second dimension is the number of non-zero values. + values (array_like): Initial values for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not + provided the size will be inferred as the minimum size big enough to hold all non-zero + elements. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if None, infers data type from :attr:`values`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + {pin_memory} + {requires_grad} + {check_invariants} + is_coalesced (bool, optional): When``True``, the caller is + responsible for providing tensor indices that correspond to a + coalesced tensor. If the :attr:`check_invariants` flag is + False, no error will be raised if the prerequisites are not + met and this will lead to silently incorrect results. To force + coalescion please use :meth:`coalesce` on the resulting + Tensor. + Default: None: except for trivial cases (e.g. nnz < 2) the + resulting Tensor has is_coalesced set to ``False```. + +Example:: + + >>> i = torch.tensor([[0, 1, 1], + ... [2, 0, 2]]) + >>> v = torch.tensor([3, 4, 5], dtype=torch.float32) + >>> torch.sparse_coo_tensor(i, v, [2, 4]) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 4), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v) # Shape inference + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v, [2, 4], + ... dtype=torch.float64, + ... 
device=torch.device('cuda:0')) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64, + layout=torch.sparse_coo) + + # Create an empty sparse tensor with the following invariants: + # 1. sparse_dim + dense_dim = len(SparseTensor.shape) + # 2. SparseTensor._indices().shape = (sparse_dim, nnz) + # 3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) + # + # For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and + # sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0)) + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0,)), + size=(1,), nnz=0, layout=torch.sparse_coo) + + # and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and + # sparse_dim = 1 + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0, 2)), + size=(1, 2), nnz=0, layout=torch.sparse_coo) + +.. _torch.sparse: https://pytorch.org/docs/stable/sparse.html +""".format(**factory_common_args), +) + +add_docstr( + torch.sqrt, + r""" +sqrt(input, *, out=None) -> Tensor + +Returns a new tensor with the square-root of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.sqrt(a) + tensor([ nan, 1.0112, 0.2883, 0.6933]) +""".format(**common_args), +) + +add_docstr( + torch.square, + r""" +square(input: Tensor, *, out: Optional[Tensor]) -> Tensor + +Returns a new tensor with the square of the elements of :attr:`input`. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.square(a) + tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) +""".format(**common_args), +) + +add_docstr( + torch.squeeze, + r""" +squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + +Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + +For example, if `input` is of shape: +:math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` +will be of shape: :math:`(A \times B \times C \times D)`. + +When :attr:`dim` is given, a squeeze operation is done only in the given +dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, +``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` +will squeeze the tensor to the shape :math:`(A \times B)`. + +.. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + +.. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + +Args: + {input} + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. 
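# Editor's sketch (not part of the upstream docstring; standard torch only):
# why the warning above recommends passing dim explicitly -- squeeze() with no
# dim also drops a batch dimension that happens to be of size 1.
import torch

x = torch.zeros(1, 5, 1)              # leading batch dimension of size 1
print(torch.squeeze(x).shape)         # torch.Size([5])    -- batch dim removed too
print(torch.squeeze(x, dim=2).shape)  # torch.Size([1, 5]) -- only the intended dim removed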
+ +Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) +""".format(**common_args), +) + +add_docstr( + torch.std, + r""" +std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + +Calculates the standard deviation over the dimensions specified by :attr:`dim`. +:attr:`dim` can be a single dimension, list of dimensions, or ``None`` to +reduce over all dimensions. + +The standard deviation (:math:`\sigma`) is calculated as + +.. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {opt_keepdim} + {out} + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format(**multi_dim_common), +) + +add_docstr( + torch.std_mean, + r""" +std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + +Calculates the standard deviation and mean over the dimensions specified by +:attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or +``None`` to reduce over all dimensions. + +The standard deviation (:math:`\sigma`) is calculated as + +.. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. + +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {opt_keepdim} + {out} + +Returns: + A tuple (std, mean) containing the standard deviation and mean. + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + +.. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format(**multi_dim_common), +) + +add_docstr( + torch.sub, + r""" +sub(input, other, *, alpha=1, out=None) -> Tensor + +Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + +.. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + +Keyword args: + alpha (Number): the multiplier for :attr:`other`. + {out} + +Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) +""".format(**common_args), +) + +add_docstr( + torch.subtract, + r""" +subtract(input, other, *, alpha=1, out=None) -> Tensor + +Alias for :func:`torch.sub`. +""", +) + +add_docstr( + torch.sum, + r""" +sum(input, *, dtype=None) -> Tensor + +Returns the sum of all elements in the :attr:`input` tensor. + +Args: + {input} + +Keyword args: + {dtype} + +.. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + +.. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + +Returns the sum of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, +reduce over all of them. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) +""".format(**multi_dim_common), +) + +add_docstr( + torch.nansum, + r""" +nansum(input, *, dtype=None) -> Tensor + +Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. + +Args: + {input} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.tensor([1., 2., float('nan'), 4.]) + >>> torch.nansum(a) + tensor(7.) + +.. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + +Returns the sum of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. +If :attr:`dim` is a list of dimensions, reduce over all of them. + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + {opt_keepdim} + +Keyword args: + {dtype} + +Example:: + + >>> torch.nansum(torch.tensor([1., float("nan")])) + tensor(1.) + >>> a = torch.tensor([[1, 2], [3., float("nan")]]) + >>> torch.nansum(a) + tensor(6.) + >>> torch.nansum(a, dim=0) + tensor([4., 2.]) + >>> torch.nansum(a, dim=1) + tensor([3., 3.]) +""".format(**multi_dim_common), +) + +add_docstr( + torch.svd, + r""" +svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`. 
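# Editor's sketch (not part of the upstream docstring; standard torch only):
# the dtype promotion behaviour mentioned in the torch.sum note further above --
# integer inputs are summed in int64 unless a dtype is requested explicitly.
import torch

a = torch.ones(3, dtype=torch.int32)
print(torch.sum(a).dtype)                     # torch.int64 (promoted)
print(torch.sum(a, dtype=torch.int32).dtype)  # torch.int32 (as requested)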
The singular value decomposition is represented as a +namedtuple `(U, S, V)`, such that :attr:`input` :math:`= U \text{diag}(S) V^{\text{H}}`. +where :math:`V^{\text{H}}` is the transpose of `V` for real inputs, +and the conjugate transpose of `V` for complex inputs. +If :attr:`input` is a batch of matrices, then `U`, `S`, and `V` are also +batched with the same batch dimensions as :attr:`input`. + +If :attr:`some` is `True` (default), the method returns the reduced singular +value decomposition. In this case, if the last two dimensions of :attr:`input` are +`m` and `n`, then the returned `U` and `V` matrices will contain only +`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is `False`, the returned `U` and `V` will be +zero-filled matrices of shape `(m, m)` and `(n, n)` +respectively, and the same device as :attr:`input`. The argument :attr:`some` +has no effect when :attr:`compute_uv` is `False`. + +Supports :attr:`input` of float, double, cfloat and cdouble data types. +The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will +always be real-valued, even if :attr:`input` is complex. + +.. warning:: + + :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd` + and will be removed in a future PyTorch release. + + ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with + + .. code:: python + + U, S, Vh = torch.linalg.svd(A, full_matrices=not some) + V = Vh.mH + + ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with + + .. code:: python + + S = torch.linalg.svdvals(A) + +.. note:: Differences with :func:`torch.linalg.svd`: + + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that + default value for both is `True`, so the default behavior is + effectively the opposite. + * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns + `Vh`, that is, :math:`V^{\text{H}}`. + * If :attr:`compute_uv` is `False`, :func:`torch.svd` returns zero-filled + tensors for `U` and `Vh`, whereas :func:`torch.linalg.svd` returns + empty tensors. + +.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch are returned in descending order. + +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is `True`. + +.. note:: When :attr:`some` is `False`, the gradients on `U[..., :, min(m, n):]` + and `V[..., :, min(m, n):]` will be ignored in the backward pass, as those vectors + can be arbitrary bases of the corresponding subspaces. + +.. note:: The implementation of :func:`torch.linalg.svd` on CPU uses LAPACK's routine `?gesdd` + (a divide-and-conquer algorithm) instead of `?gesvd` for speed. Analogously, + on GPU, it uses cuSOLVER's routines `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 + and later, and MAGMA's routine `gesdd` on earlier versions of CUDA. + +.. note:: The returned `U` will not be contiguous. The matrix (or batch of matrices) will + be represented as a column-major matrix (i.e. Fortran-contiguous). + +.. warning:: The gradients with respect to `U` and `V` will only be finite when the input does not + have zero nor repeated singular values. + +.. warning:: If the distance between any two singular values is close to zero, the gradients with respect to + `U` and `V` will be numerically unstable, as they depends on + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. 
The same happens when the matrix + has small singular values, as these gradients also depend on `S^{-1}`. + +.. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique, + as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column. + The same happens when :attr:`input` has repeated singular values, where one may multiply + the columns of the spanning subspace in `U` and `V` by a rotation matrix + and `the resulting vectors will span the same subspace`_. + Different platforms, like NumPy, or inputs on different device types, + may produce different `U` and `V` tensors. + +Args: + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m, n)` matrices. + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently, the shape of returned `U` and `V`. Default: `True`. + compute_uv (bool, optional): controls whether to compute `U` and `V`. Default: `True`. + +Keyword args: + out (tuple, optional): the output tuple of tensors + +Example:: + + >>> a = torch.randn(5, 3) + >>> a + tensor([[ 0.2364, -0.7752, 0.6372], + [ 1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [ 0.3550, -0.4022, 1.5569], + [ 0.2445, -0.0158, 1.1414]]) + >>> u, s, v = torch.svd(a) + >>> u + tensor([[ 0.4027, 0.0287, 0.5434], + [-0.1946, 0.8833, 0.3679], + [ 0.4296, -0.2890, 0.5261], + [ 0.6604, 0.2717, -0.2618], + [ 0.4234, 0.2481, -0.4733]]) + >>> s + tensor([2.3289, 2.0315, 0.7806]) + >>> v + tensor([[-0.0199, 0.8766, 0.4809], + [-0.5080, 0.4054, -0.7600], + [ 0.8611, 0.2594, -0.4373]]) + >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) + tensor(8.6531e-07) + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, v = torch.svd(a_big) + >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) + tensor(2.6503e-06) + +.. _the resulting vectors will span the same subspace: + (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) +""", +) + + +add_docstr( + torch.t, + r""" +t(input) -> Tensor + +Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 +and 1. + +0-D and 1-D tensors are returned as is. When input is a 2-D tensor this +is equivalent to ``transpose(input, 0, 1)``. + +Args: + {input} + +Example:: + + >>> x = torch.randn(()) + >>> x + tensor(0.1995) + >>> torch.t(x) + tensor(0.1995) + >>> x = torch.randn(3) + >>> x + tensor([ 2.4320, -0.4608, 0.7702]) + >>> torch.t(x) + tensor([ 2.4320, -0.4608, 0.7702]) + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.4875, 0.9158, -0.5872], + [ 0.3938, -0.6929, 0.6932]]) + >>> torch.t(x) + tensor([[ 0.4875, 0.3938], + [ 0.9158, -0.6929], + [-0.5872, 0.6932]]) + +See also :func:`torch.transpose`. +""".format(**common_args), +) + +add_docstr( + torch.flip, + r""" +flip(input, dims) -> Tensor + +Reverse the order of an n-D tensor along given axis in dims. + +.. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. 
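# Editor's sketch (not part of the upstream docstring; standard torch only):
# torch.flip returns a copy, as the note above says, so writing into the result
# leaves the original tensor untouched (unlike NumPy's view-based np.flip).
import torch

x = torch.arange(4)
y = torch.flip(x, dims=[0])
y[0] = 99
print(x)  # tensor([0, 1, 2, 3]) -- unchanged
print(y)  # tensor([99, 2, 1, 0])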
+ +Args: + {input} + dims (a list or tuple): axis to flip on + +Example:: + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]]]) + >>> torch.flip(x, [0, 1]) + tensor([[[ 6, 7], + [ 4, 5]], + + [[ 2, 3], + [ 0, 1]]]) +""".format(**common_args), +) + +add_docstr( + torch.fliplr, + r""" +fliplr(input) -> Tensor + +Flip tensor in the left/right direction, returning a new tensor. + +Flip the entries in each row in the left/right direction. +Columns are preserved, but appear in a different order than before. + +Note: + Requires the tensor to be at least 2-D. + +.. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. + +Args: + input (Tensor): Must be at least 2-dimensional. + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.fliplr(x) + tensor([[1, 0], + [3, 2]]) +""".format(**common_args), +) + +add_docstr( + torch.flipud, + r""" +flipud(input) -> Tensor + +Flip tensor in the up/down direction, returning a new tensor. + +Flip the entries in each column in the up/down direction. +Rows are preserved, but appear in a different order than before. + +Note: + Requires the tensor to be at least 1-D. + +.. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. + +Args: + input (Tensor): Must be at least 1-dimensional. + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flipud(x) + tensor([[2, 3], + [0, 1]]) +""".format(**common_args), +) + +add_docstr( + torch.roll, + r""" +roll(input, shifts, dims=None) -> Tensor + +Roll the tensor :attr:`input` along the given dimension(s). Elements that are +shifted beyond the last position are re-introduced at the first position. If +:attr:`dims` is `None`, the tensor will be flattened before rolling and then +restored to the original shape. + +Args: + {input} + shifts (int or tuple of ints): The number of places by which the elements + of the tensor are shifted. If shifts is a tuple, dims must be a tuple of + the same size, and each dimension will be rolled by the corresponding + value + dims (int or tuple of ints): Axis along which to roll + +Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + >>> x + tensor([[1, 2], + [3, 4], + [5, 6], + [7, 8]]) + >>> torch.roll(x, 1) + tensor([[8, 1], + [2, 3], + [4, 5], + [6, 7]]) + >>> torch.roll(x, 1, 0) + tensor([[7, 8], + [1, 2], + [3, 4], + [5, 6]]) + >>> torch.roll(x, -1, 0) + tensor([[3, 4], + [5, 6], + [7, 8], + [1, 2]]) + >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) + tensor([[6, 5], + [8, 7], + [2, 1], + [4, 3]]) +""".format(**common_args), +) + +add_docstr( + torch.rot90, + r""" +rot90(input, k=1, dims=(0, 1)) -> Tensor + +Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. +Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. + +Args: + {input} + k (int): number of times to rotate. Default value is 1 + dims (a list or tuple): axis to rotate. 
Default value is [0, 1] + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.rot90(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.rot90(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) +""".format(**common_args), +) + +add_docstr( + torch.take, + r""" +take(input, index) -> Tensor + +Returns a new tensor with the elements of :attr:`input` at the given indices. +The input tensor is treated as if it were viewed as a 1-D tensor. The result +takes the same shape as the indices. + +Args: + {input} + index (LongTensor): the indices into tensor + +Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> torch.take(src, torch.tensor([0, 2, 5])) + tensor([ 4, 5, 8]) +""".format(**common_args), +) + +add_docstr( + torch.take_along_dim, + r""" +take_along_dim(input, indices, dim=None, *, out=None) -> Tensor + +Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. + +If :attr:`dim` is None, the input array is treated as if it has been flattened to 1d. + +Functions that return indices along a dimension, like :func:`torch.argmax` and :func:`torch.argsort`, +are designed to work with this function. See the examples below. + +.. note:: + This function is similar to NumPy's `take_along_axis`. + See also :func:`torch.gather`. + +Args: + {input} + indices (LongTensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. Default: 0 + +Keyword args: + {out} + +Example:: + + >>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) + >>> max_idx = torch.argmax(t) + >>> torch.take_along_dim(t, max_idx) + tensor([60]) + >>> sorted_idx = torch.argsort(t, dim=1) + >>> torch.take_along_dim(t, sorted_idx, dim=1) + tensor([[10, 20, 30], + [40, 50, 60]]) +""".format(**common_args), +) + +add_docstr( + torch.tan, + r""" +tan(input, *, out=None) -> Tensor + +Returns a new tensor with the tangent of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \tan(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.2027, -1.7687, 0.4412, -1.3856]) + >>> torch.tan(a) + tensor([-2.5930, 4.9859, 0.4722, -5.3366]) +""".format(**common_args), +) + +add_docstr( + torch.tanh, + r""" +tanh(input, *, out=None) -> Tensor + +Returns a new tensor with the hyperbolic tangent of the elements +of :attr:`input`. + +.. math:: + \text{out}_{i} = \tanh(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) + >>> torch.tanh(a) + tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) +""".format(**common_args), +) + +add_docstr( + # torch.softmax doc str. Point this to torch.nn.functional.softmax + torch.softmax, + r""" +softmax(input, dim, *, dtype=None) -> Tensor + +Alias for :func:`torch.nn.functional.softmax`. +""", +) + +add_docstr( + torch.topk, + r""" +topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) + +Returns the :attr:`k` largest elements of the given :attr:`input` tensor along +a given dimension. + +If :attr:`dim` is not given, the last dimension of the `input` is chosen. + +If :attr:`largest` is ``False`` then the `k` smallest elements are returned. 
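# Editor's sketch (not part of the upstream docstring; standard torch only):
# retrieving the k smallest entries via largest=False, as described above.
import torch

x = torch.tensor([5., 1., 4., 2., 3.])
values, indices = torch.topk(x, k=2, largest=False)
print(values)   # tensor([1., 2.])
print(indices)  # tensor([1, 3])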
+ +A namedtuple of `(values, indices)` is returned with the `values` and +`indices` of the largest `k` elements of each row of the `input` tensor in the +given dimension `dim`. + +The boolean option :attr:`sorted` if ``True``, will make sure that the returned +`k` elements are themselves sorted + +.. note:: + When using `torch.topk`, the indices of tied elements are not guaranteed to be stable + and may vary across different invocations. + +Args: + {input} + k (int): the k in "top-k" + dim (int, optional): the dimension to sort along + largest (bool, optional): controls whether to return largest or + smallest elements + sorted (bool, optional): controls whether to return the elements + in sorted order + +Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be + optionally given to be used as output buffers + +Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.topk(x, 3) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) +""".format(**common_args), +) + +add_docstr( + torch.trace, + r""" +trace(input) -> Tensor + +Returns the sum of the elements of the diagonal of the input 2-D matrix. + +Example:: + + >>> x = torch.arange(1., 10.).view(3, 3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]]) + >>> torch.trace(x) + tensor(15.) +""", +) + +add_docstr( + torch.transpose, + r""" +transpose(input, dim0, dim1) -> Tensor + +Returns a tensor that is a transposed version of :attr:`input`. +The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + +If :attr:`input` is a strided tensor then the resulting :attr:`out` +tensor shares its underlying storage with the :attr:`input` tensor, so +changing the content of one would change the content of the other. + +If :attr:`input` is a :ref:`sparse tensor ` then the +resulting :attr:`out` tensor *does not* share the underlying storage +with the :attr:`input` tensor. + +If :attr:`input` is a :ref:`sparse tensor ` with compressed +layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments +:attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must +both be sparse dimensions. The batch dimensions of a sparse tensor are the +dimensions preceding the sparse dimensions. + +.. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + +Args: + {input} + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + +Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + +See also :func:`torch.t`. +""".format(**common_args), +) + +add_docstr( + torch.triangular_solve, + r""" +triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) + +Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` +and multiple right-hand sides :math:`b`. + +In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular +(or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal. 
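+
+As a small hand-checkable sketch (the matrix and right-hand side are chosen so the
+exact solution is obvious)::
+
+    >>> A = torch.tensor([[2., 1.],
+    ...                   [0., 1.]])
+    >>> b = torch.tensor([[3.],
+    ...                   [1.]])
+    >>> torch.triangular_solve(b, A).solution
+    tensor([[1.],
+            [1.]])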
+ +`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are +batches of 2D matrices. If the inputs are batches, then returns +batched outputs `X` + +If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and +:attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, +the result may contain `NaN` s. + +Supports input of float, double, cfloat and cdouble data types. + +.. warning:: + + :func:`torch.triangular_solve` is deprecated in favor of :func:`torch.linalg.solve_triangular` + and will be removed in a future PyTorch release. + :func:`torch.linalg.solve_triangular` has its arguments reversed and does not return a + copy of one of the inputs. + + ``X = torch.triangular_solve(B, A).solution`` should be replaced with + + .. code:: python + + X = torch.linalg.solve_triangular(A, B) + +Args: + b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where + :math:`*` is zero of more batch dimensions + A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` + where :math:`*` is zero or more batch dimensions + upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``. + transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``, + and `op(A) = A` if it is ``False``. Default: ``False``. + unitriangular (bool, optional): whether :math:`A` is unit triangular. + If True, the diagonal elements of :math:`A` are assumed to be + 1 and not referenced from :math:`A`. Default: ``False``. + +Keyword args: + out ((Tensor, Tensor), optional): tuple of two tensors to write + the output to. Ignored if `None`. Default: `None`. + +Returns: + A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient` + is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b` + (or whatever variant of the system of equations, depending on the keyword arguments.) + +Examples:: + + >>> A = torch.randn(2, 2).triu() + >>> A + tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]]) + >>> b = torch.randn(2, 3) + >>> b + tensor([[-0.0210, 2.3513, -1.5492], + [ 1.5429, 0.7403, -1.0243]]) + >>> torch.triangular_solve(b, A) + torch.return_types.triangular_solve( + solution=tensor([[ 1.7841, 2.9046, -2.5405], + [ 1.9320, 0.9270, -1.2826]]), + cloned_coefficient=tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]])) +""", +) + +add_docstr( + torch.tril, + r""" +tril(input, diagonal=0, *, out=None) -> Tensor + +Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices +:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + +The lower triangular part of the matrix is defined as the elements on and +below the diagonal. + +The argument :attr:`diagonal` controls which diagonal to consider. If +:attr:`diagonal` = 0, all elements on and below the main diagonal are +retained. A positive value includes just as many diagonals above the main +diagonal, and similarly a negative value excludes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where +:math:`d_{1}, d_{2}` are the dimensions of the matrix. 
+""" + + r""" +Args: + {input} + diagonal (int, optional): the diagonal to consider + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0813, -0.8619, 0.7105], + [ 0.0935, 0.1380, 2.2112], + [-0.3409, -0.9828, 0.0289]]) + >>> torch.tril(a) + tensor([[-1.0813, 0.0000, 0.0000], + [ 0.0935, 0.1380, 0.0000], + [-0.3409, -0.9828, 0.0289]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], + [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) + >>> torch.tril(b, diagonal=1) + tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) + >>> torch.tril(b, diagonal=-1) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) +""".format(**common_args), +) + +# docstr is split in two parts to avoid format mis-captureing :math: braces '{}' +# as common args. +add_docstr( + torch.tril_indices, + r""" +tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + +Returns the indices of the lower triangular part of a :attr:`row`-by- +:attr:`col` matrix in a 2-by-N Tensor, where the first row contains row +coordinates of all indices and the second row contains column coordinates. +Indices are ordered based on rows and then columns. + +The lower triangular part of the matrix is defined as the elements on and +below the diagonal. + +The argument :attr:`offset` controls which diagonal to consider. If +:attr:`offset` = 0, all elements on and below the main diagonal are +retained. A positive value includes just as many diagonals above the main +diagonal, and similarly a negative value excludes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` +where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + +.. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. +""" + + r""" +Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. + {device} + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. 
+ +Example:: + + >>> a = torch.tril_indices(3, 3) + >>> a + tensor([[0, 1, 1, 2, 2, 2], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, -1) + >>> a + tensor([[1, 2, 2, 3, 3, 3], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) +""".format(**factory_common_args), +) + +add_docstr( + torch.triu, + r""" +triu(input, diagonal=0, *, out=None) -> Tensor + +Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices +:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + +The upper triangular part of the matrix is defined as the elements on and +above the diagonal. + +The argument :attr:`diagonal` controls which diagonal to consider. If +:attr:`diagonal` = 0, all elements on and above the main diagonal are +retained. A positive value excludes just as many diagonals above the main +diagonal, and similarly a negative value includes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where +:math:`d_{1}, d_{2}` are the dimensions of the matrix. +""" + + r""" +Args: + {input} + diagonal (int, optional): the diagonal to consider + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.3480, -0.5211, -0.4573]]) + >>> torch.triu(a) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.0000, -1.0680, 0.6602], + [ 0.0000, 0.0000, -0.4573]]) + >>> torch.triu(a, diagonal=1) + tensor([[ 0.0000, 0.5207, 2.0049], + [ 0.0000, 0.0000, 0.6602], + [ 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(a, diagonal=-1) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.0000, -0.5211, -0.4573]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) +""".format(**common_args), +) + +# docstr is split in two parts to avoid format mis-capturing :math: braces '{}' +# as common args. +add_docstr( + torch.triu_indices, + r""" +triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + +Returns the indices of the upper triangular part of a :attr:`row` by +:attr:`col` matrix in a 2-by-N Tensor, where the first row contains row +coordinates of all indices and the second row contains column coordinates. +Indices are ordered based on rows and then columns. + +The upper triangular part of the matrix is defined as the elements on and +above the diagonal. + +The argument :attr:`offset` controls which diagonal to consider. If +:attr:`offset` = 0, all elements on and above the main diagonal are +retained. 
A positive value excludes just as many diagonals above the main +diagonal, and similarly a negative value includes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` +where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + +.. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. +""" + + r""" +Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. + {device} + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + +Example:: + + >>> a = torch.triu_indices(3, 3) + >>> a + tensor([[0, 0, 0, 1, 1, 2], + [0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, -1) + >>> a + tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], + [0, 1, 2, 0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1], + [1, 2, 2]]) +""".format(**factory_common_args), +) + +add_docstr( + torch.true_divide, + r""" +true_divide(dividend, divisor, *, out) -> Tensor + +Alias for :func:`torch.div` with ``rounding_mode=None``. +""", +) + +add_docstr( + torch.trunc, + r""" +trunc(input, *, out=None) -> Tensor + +Returns a new tensor with the truncated integer values of +the elements of :attr:`input`. + +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) + >>> torch.trunc(a) + tensor([ 3., 0., -0., -0.]) +""".format(**common_args), +) + +add_docstr( + torch.fake_quantize_per_tensor_affine, + r""" +fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + +Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, +:attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + +.. 
math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + +Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + +Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + +Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) +""", +) + +add_docstr( + torch.fake_quantize_per_channel_affine, + r""" +fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor + +Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, +:attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`. + +.. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + +Args: + input (Tensor): the input value(s), in ``torch.float32`` + scale (Tensor): quantization scale, per channel in ``torch.float32`` + zero_point (Tensor): quantization zero_point, per channel in ``torch.int32`` or ``torch.half`` or ``torch.float32`` + axis (int32): channel axis + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + +Returns: + Tensor: A newly fake_quantized per channel ``torch.float32`` tensor + +Example:: + + >>> x = torch.randn(2, 2, 2) + >>> x + tensor([[[-0.2525, -0.0466], + [ 0.3491, -0.2168]], + + [[-0.5906, 1.6258], + [ 0.6444, -0.0542]]]) + >>> scales = (torch.randn(2) + 1) * 0.05 + >>> scales + tensor([0.0475, 0.0486]) + >>> zero_points = torch.zeros(2).to(torch.int32) + >>> zero_points + tensor([0, 0]) + >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255) + tensor([[[0.0000, 0.0000], + [0.3405, 0.0000]], + + [[0.0000, 1.6134], + [0.6323, 0.0000]]]) +""", +) + +add_docstr( + torch.fix, + r""" +fix(input, *, out=None) -> Tensor + +Alias for :func:`torch.trunc` +""", +) + +add_docstr( + torch.unsqueeze, + r""" +unsqueeze(input, dim) -> Tensor + +Returns a new tensor with a dimension of size one inserted at the +specified position. + +The returned tensor shares the same underlying data with this tensor. + +A :attr:`dim` value within the range ``[-input.dim() - 1, input.dim() + 1)`` +can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` +applied at :attr:`dim` = ``dim + input.dim() + 1``. 
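+
+For example (a minimal sketch of the negative-:attr:`dim` rule), for a 1-D input
+``dim = -1`` resolves to ``-1 + input.dim() + 1 = 1``, so the new singleton
+dimension is appended at the end::
+
+    >>> x = torch.tensor([1, 2, 3, 4])
+    >>> torch.unsqueeze(x, -1).shape
+    torch.Size([4, 1])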
+ +Args: + {input} + dim (int): the index at which to insert the singleton dimension + +Example:: + + >>> x = torch.tensor([1, 2, 3, 4]) + >>> torch.unsqueeze(x, 0) + tensor([[ 1, 2, 3, 4]]) + >>> torch.unsqueeze(x, 1) + tensor([[ 1], + [ 2], + [ 3], + [ 4]]) +""".format(**common_args), +) + +add_docstr( + torch.var, + r""" +var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + +Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` +can be a single dimension, list of dimensions, or ``None`` to reduce over all +dimensions. + +The variance (:math:`\sigma^2`) is calculated as + +.. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {opt_keepdim} + {out} + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format(**multi_dim_common), +) + +add_docstr( + torch.var_mean, + r""" +var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + +Calculates the variance and mean over the dimensions specified by :attr:`dim`. +:attr:`dim` can be a single dimension, list of dimensions, or ``None`` to +reduce over all dimensions. + +The variance (:math:`\sigma^2`) is calculated as + +.. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim_all_reduce} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {opt_keepdim} + {out} + +Returns: + A tuple (var, mean) containing the variance and mean. + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + +.. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format(**multi_dim_common), +) + +add_docstr( + torch.zeros, + r""" +zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a tensor filled with the scalar value `0`, with the shape defined +by the variable argument :attr:`size`. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) +""".format(**factory_common_args), +) + +add_docstr( + torch.zeros_like, + r""" +zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor filled with the scalar value `0`, with the same size as +:attr:`input`. ``torch.zeros_like(input)`` is equivalent to +``torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +.. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.zeros_like(input, out=output)`` is equivalent to + ``torch.zeros(input.size(), out=output)``. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +Example:: + + >>> input = torch.empty(2, 3) + >>> torch.zeros_like(input) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) +""".format(**factory_like_common_args), +) + +add_docstr( + torch.empty, + """ +empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, \ +memory_format=torch.contiguous_format) -> Tensor + +Returns a tensor filled with uninitialized data. The shape of the tensor is +defined by the variable argument :attr:`size`. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + {memory_format} + +Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) +""".format(**factory_common_args), +) + +add_docstr( + torch.empty_like, + r""" +empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns an uninitialized tensor with the same size as :attr:`input`. +``torch.empty_like(input)`` is equivalent to +``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. 
+ Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +Example:: + + >>> a=torch.empty((2,3), dtype=torch.int32, device = 'cuda') + >>> torch.empty_like(a) + tensor([[0, 0, 0], + [0, 0, 0]], device='cuda:0', dtype=torch.int32) +""".format(**factory_like_common_args), +) + +add_docstr( + torch.empty_strided, + r""" +empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + +Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. + +.. warning:: + If the constructed tensor is "overlapped" (with multiple indices referring to the same element + in memory) its behavior is undefined. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + +Args: + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> a = torch.empty_strided((2, 3), (1, 2)) + >>> a + tensor([[8.9683e-44, 4.4842e-44, 5.1239e+07], + [0.0000e+00, 0.0000e+00, 3.0705e-41]]) + >>> a.stride() + (1, 2) + >>> a.size() + torch.Size([2, 3]) +""".format(**factory_common_args), +) + +add_docstr( + torch.empty_permuted, + r""" +empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + +Creates an uninitialized, non-overlapping and dense tensor with the +specified :attr:`size`, with :attr:`physical_layout` specifying how the +dimensions are physically laid out in memory (each logical dimension is listed +from outermost to innermost). :attr:`physical_layout` is a generalization +of NCHW/NHWC notation: if each dimension is assigned a number according to +what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)`` +while NHWC is ``(0, 2, 3, 1)``. Equivalently, the strides of the output +tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]`` +(notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``). + +Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense +tensor with no overlaps. If possible, prefer using this function over +:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. 
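+
+As an illustrative (non-normative) check that :attr:`physical_layout` ``(0, 2, 3, 1)``
+produces a channels-last-dense tensor for a 4-D size::
+
+    >>> t = torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1))
+    >>> t.is_contiguous(memory_format=torch.channels_last)
+    True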
+ +Args: + size (tuple of int): the shape of the output tensor + physical_layout (tuple of int): the ordering of dimensions physically in memory + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Examples: + + >>> torch.empty((2, 3, 5, 7)).stride() + (105, 35, 7, 1) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride() + (105, 35, 7, 1) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order() + (0, 2, 3, 1) +""".format(**factory_common_args), +) + +add_docstr( + torch.full, + r""" +full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The +tensor's dtype is inferred from :attr:`fill_value`. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + +Keyword args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) +""".format(**factory_common_args), +) + +add_docstr( + torch.full_like, + """ +full_like(input, fill_value, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. +``torch.full_like(input, fill_value)`` is equivalent to +``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``. + +Args: + {input} + fill_value: the number to fill the output tensor with. + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} +""".format(**factory_like_common_args), +) + +add_docstr( + torch.det, + r""" +det(input) -> Tensor + +Alias for :func:`torch.linalg.det` +""", +) + +add_docstr( + torch.where, + r""" +where(condition, input, other, *, out=None) -> Tensor + +Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + +The operation is defined as: + +.. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} +""" + + r""" +.. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. 
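+
+For example (a minimal broadcasting sketch with hand-picked shapes), a ``(3, 1)``
+condition broadcasts against ``(3, 2)`` operands::
+
+    >>> cond = torch.tensor([[True], [False], [True]])
+    >>> torch.where(cond, torch.zeros(3, 2), torch.ones(3, 2))
+    tensor([[0., 0.],
+            [1., 1.],
+            [0., 0.]])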
+ +Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + +Keyword args: + {out} + +Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + +Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + +.. function:: where(condition) -> tuple of LongTensor + :noindex: + +``torch.where(condition)`` is identical to +``torch.nonzero(condition, as_tuple=True)``. + +.. note:: + See also :func:`torch.nonzero`. +""".format(**common_args), +) + +add_docstr( + torch.logdet, + r""" +logdet(input) -> Tensor + +Calculates log determinant of a square matrix or batches of square matrices. + +It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has +a negative determinant. + +.. note:: + Backward through :meth:`logdet` internally uses SVD results when :attr:`input` + is not invertible. In this case, double backward through :meth:`logdet` will + be unstable in when :attr:`input` doesn't have distinct singular values. See + :func:`torch.linalg.svd` for details. + +.. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. complex) square matrices. + +Arguments: + input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more + batch dimensions. + +Example:: + + >>> A = torch.randn(3, 3) + >>> torch.det(A) + tensor(0.2611) + >>> torch.logdet(A) + tensor(-1.3430) + >>> A + tensor([[[ 0.9254, -0.6213], + [-0.5787, 1.6843]], + + [[ 0.3242, -0.9665], + [ 0.4539, -0.0887]], + + [[ 1.1336, -0.4025], + [-0.7089, 0.9032]]]) + >>> A.det() + tensor([1.1990, 0.4099, 0.7386]) + >>> A.det().log() + tensor([ 0.1815, -0.8917, -0.3031]) +""", +) + +add_docstr( + torch.slogdet, + r""" +slogdet(input) -> (Tensor, Tensor) + +Alias for :func:`torch.linalg.slogdet` +""", +) + +add_docstr( + torch.pinverse, + r""" +pinverse(input, rcond=1e-15) -> Tensor + +Alias for :func:`torch.linalg.pinv` +""", +) + +add_docstr( + torch.hann_window, + """ +hann_window(window_length, periodic=True, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Hann window function. + +.. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. 
Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.hann_window(L, periodic=True)`` equal to +``torch.hann_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. +""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + +""".format(**factory_common_args), +) + + +add_docstr( + torch.hamming_window, + """ +hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Hamming window function. + +.. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.hamming_window(L, periodic=True)`` equal to +``torch.hamming_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + +.. note:: + This is a generalized version of :meth:`torch.hann_window`. +""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window. + +""".format(**factory_common_args), +) + + +add_docstr( + torch.bartlett_window, + """ +bartlett_window(window_length, periodic=True, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Bartlett window function. + +.. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. 
Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.bartlett_window(L, periodic=True)`` equal to +``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. +""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + +""".format(**factory_common_args), +) + + +add_docstr( + torch.blackman_window, + """ +blackman_window(window_length, periodic=True, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Blackman window function. + +.. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.blackman_window(L, periodic=True)`` equal to +``torch.blackman_window(L + 1, periodic=False)[:-1]``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. +""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + +""".format(**factory_common_args), +) + + +add_docstr( + torch.kaiser_window, + """ +kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + +Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and +``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, +where ``L`` is the :attr:`window_length`. This function computes: + +.. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + +Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling +``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. +The :attr:`periodic` argument is intended as a helpful shorthand +to produce a periodic window as input to functions like :func:`torch.stft`. + +.. 
note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + +""" + + r""" +Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + +Keyword args: + {dtype} + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +""".format(**factory_common_args), +) + + +add_docstr( + torch.vander, + """ +vander(x, N=None, increasing=False) -> Tensor +""" + + r""" +Generates a Vandermonde matrix. + +The columns of the output matrix are elementwise powers of the input vector :math:`x^{{(N-1)}}, x^{{(N-2)}}, ..., x^0`. +If increasing is True, the order of the columns is reversed :math:`x^0, x^1, ..., x^{{(N-1)}}`. Such a +matrix with a geometric progression in each row is named for Alexandre-Theophile Vandermonde. + +Arguments: + x (Tensor): 1-D input tensor. + N (int, optional): Number of columns in the output. If N is not specified, + a square array is returned :math:`(N = len(x))`. + increasing (bool, optional): Order of the powers of the columns. If True, + the powers increase from left to right, if False (the default) they are reversed. + +Returns: + Tensor: Vandermonde matrix. If increasing is False, the first column is :math:`x^{{(N-1)}}`, + the second :math:`x^{{(N-2)}}` and so forth. If increasing is True, the columns + are :math:`x^0, x^1, ..., x^{{(N-1)}}`. + +Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> torch.vander(x) + tensor([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + >>> torch.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 4, 2, 1], + [ 9, 3, 1], + [25, 5, 1]]) + >>> torch.vander(x, N=3, increasing=True) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) + +""".format(**factory_common_args), +) + + +add_docstr( + torch.unbind, + r""" +unbind(input, dim=0) -> seq + +Removes a tensor dimension. + +Returns a tuple of all slices along a given dimension, already without it. + +Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + +Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) +""", +) + + +add_docstr( + torch.combinations, + r""" +combinations(input: Tensor, r: int = 2, with_replacement: bool = False) -> seq + +Compute combinations of length :math:`r` of the given tensor. The behavior is similar to +python's `itertools.combinations` when `with_replacement` is set to `False`, and +`itertools.combinations_with_replacement` when `with_replacement` is set to `True`. + +Arguments: + input (Tensor): 1D vector. + r (int, optional): number of elements to combine + with_replacement (bool, optional): whether to allow duplication in combination + +Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, do + `itertools.combinations` or `itertools.combinations_with_replacement` on these + lists, and finally convert the resulting list into tensor. 
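+
+The result is therefore a 2-D tensor with one row per combination; without
+replacement an input of length :math:`n` yields :math:`\binom{n}{r}` rows. A quick
+illustrative shape check::
+
+    >>> torch.combinations(torch.arange(4), r=2).shape
+    torch.Size([6, 2])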
+ +Example:: + + >>> a = [1, 2, 3] + >>> list(itertools.combinations(a, r=2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(itertools.combinations(a, r=3)) + [(1, 2, 3)] + >>> list(itertools.combinations_with_replacement(a, r=2)) + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + >>> tensor_a = torch.tensor(a) + >>> torch.combinations(tensor_a) + tensor([[1, 2], + [1, 3], + [2, 3]]) + >>> torch.combinations(tensor_a, r=3) + tensor([[1, 2, 3]]) + >>> torch.combinations(tensor_a, with_replacement=True) + tensor([[1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3]]) + +""", +) + +add_docstr( + torch.trapezoid, + r""" +trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + +Computes the `trapezoidal rule `_ along +:attr:`dim`. By default the spacing between elements is assumed to be 1, but +:attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be +used to specify arbitrary spacing along :attr:`dim`. Only one of :attr:`x` or :attr:`dx` should be specified. + + +Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, +the default computation is + +.. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + +When :attr:`dx` is specified the computation becomes + +.. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + +effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, +assuming :attr:`x` is also a one-dimensional tensor with +elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + +.. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + +When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. +The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` +and :attr:`y`, the function computes the difference between consecutive elements along +dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have +the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. +After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. +See the examples below for details. + +.. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. The approximation becomes more accurate as + the resolution of the partition increases. + +Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + +Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. 
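+
+For uniformly spaced samples the sums above collapse to
+:math:`\Delta x \cdot \left(\sum_i y_i - \tfrac{y_0 + y_n}{2}\right)`, which gives a
+quick way to sanity-check results by hand (an illustrative sketch)::
+
+    >>> y = torch.tensor([1., 5., 10.])
+    >>> torch.trapezoid(y, dx=2)
+    tensor(21.)
+    >>> 2 * (y.sum() - (y[0] + y[-1]) / 2)
+    tensor(21.)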
+ +Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) +""", +) + +add_docstr( + torch.trapz, + r""" +trapz(y, x, *, dim=-1) -> Tensor + +Alias for :func:`torch.trapezoid`. +""", +) + +add_docstr( + torch.cumulative_trapezoid, + r""" +cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + +Cumulatively computes the `trapezoidal rule `_ +along :attr:`dim`. By default the spacing between elements is assumed to be 1, but +:attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be +used to specify arbitrary spacing along :attr:`dim`. + +For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` +and this function is that, :func:`torch.trapezoid` returns a value for each integration, +where as this function returns a cumulative value for every spacing within the integration. This +is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + +Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + +Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + +Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) +""", +) + +add_docstr( + torch.repeat_interleave, + r""" +repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + +Repeat elements of a tensor. + +.. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + +Args: + {input} + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + +Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + +Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + +If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be +`tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, +`1` appears `n2` times, `2` appears `n3` times, etc. + +.. 
function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + +Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + +Args: + repeats (Tensor): The number of repetitions for each element. + +Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + +Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + +""".format(**common_args), +) + +add_docstr( + torch.tile, + r""" +tile(input, dims) -> Tensor + +Constructs a tensor by repeating the elements of :attr:`input`. +The :attr:`dims` argument specifies the number of repetitions +in each dimension. + +If :attr:`dims` specifies fewer dimensions than :attr:`input` has, then +ones are prepended to :attr:`dims` until all dimensions are specified. +For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`dims` +is (2, 2), then :attr:`dims` is treated as (1, 1, 2, 2). + +Analogously, if :attr:`input` has fewer dimensions than :attr:`dims` +specifies, then :attr:`input` is treated as if it were unsqueezed at +dimension zero until it has as many dimensions as :attr:`dims` specifies. +For example, if :attr:`input` has shape (4, 2) and :attr:`dims` +is (3, 3, 2, 2), then :attr:`input` is treated as if it had the +shape (1, 1, 4, 2). + +.. note:: + + This function is similar to NumPy's tile function. + +Args: + input (Tensor): the tensor whose elements to repeat. + dims (tuple): the number of repetitions per dimension. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) +""", +) + +add_docstr( + torch.quantize_per_tensor, + r""" +quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + +Converts a float tensor to a quantized tensor with given scale and zero point. + +Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + +Returns: + Tensor: A newly quantized tensor or list of quantized tensors. 
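+
+A tensor quantized this way can be mapped back to floating point with
+:meth:`~torch.Tensor.dequantize`; a minimal round-trip sketch (the values are chosen
+to be exactly representable at ``scale=0.1``, ``zero_point=10``)::
+
+    >>> q = torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8)
+    >>> q.dequantize()
+    tensor([-1.,  0.,  1.,  2.])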
+ +Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) +""", +) + +add_docstr( + torch.quantize_per_tensor_dynamic, + r""" +quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor + +Converts a float tensor to a quantized tensor with scale and zero_point calculated +dynamically based on the input. + +Arguments: + input (Tensor): float tensor or list of tensors to quantize + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8`` + reduce_range (bool): a flag to indicate whether to reduce the range of quantized + data by 1 bit, it's required to avoid instruction overflow for some hardwares + +Returns: + Tensor: A newly (dynamically) quantized tensor + +Example:: + + >>> t = torch.quantize_per_tensor_dynamic(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.quint8, False) + >>> print(t) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.011764705882352941, + zero_point=85) + >>> t.int_repr() + tensor([ 0, 85, 170, 255], dtype=torch.uint8) +""", +) + +add_docstr( + torch.quantize_per_channel, + r""" +quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor + +Converts a float tensor to a per-channel quantized tensor with given scales and zero points. + +Arguments: + input (Tensor): float tensor to quantize + scales (Tensor): float 1D tensor of scales to use, size should match ``input.size(axis)`` + zero_points (int): integer 1D tensor of offset to use, size should match ``input.size(axis)`` + axis (int): dimension on which apply per-channel quantization + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + +Returns: + Tensor: A newly quantized tensor + +Example:: + + >>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8) + tensor([[-1., 0.], + [ 1., 2.]], size=(2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_channel_affine, + scale=tensor([0.1000, 0.0100], dtype=torch.float64), + zero_point=tensor([10, 0]), axis=0) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() + tensor([[ 0, 10], + [100, 200]], dtype=torch.uint8) +""", +) + + +add_docstr( + torch.quantized_batch_norm, + r""" +quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor + +Applies batch normalization on a 4D (NCHW) quantized tensor. + +.. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + +Arguments: + input (Tensor): quantized tensor + weight (Tensor): float tensor that corresponds to the gamma, size C + bias (Tensor): float tensor that corresponds to the beta, size C + mean (Tensor): float mean value in batch normalization, size C + var (Tensor): float tensor for variance, size C + eps (float): a value added to the denominator for numerical stability. + output_scale (float): output quantized tensor scale + output_zero_point (int): output quantized tensor zero_point + +Returns: + Tensor: A quantized tensor with batch normalization applied. + +Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_batch_norm(qx, torch.ones(2), torch.zeros(2), torch.rand(2), torch.rand(2), 0.00001, 0.2, 2) + tensor([[[[-0.2000, -0.2000], + [ 1.6000, -0.2000]], + + [[-0.4000, -0.4000], + [-0.4000, 0.6000]]], + + + [[[-0.2000, -0.2000], + [-0.2000, -0.2000]], + + [[ 0.6000, -0.4000], + [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2) +""", +) + + +add_docstr( + torch.quantized_max_pool1d, + r""" +quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + +Applies a 1D max pooling over an input quantized tensor composed of several input planes. + +Arguments: + input (Tensor): quantized tensor + kernel_size (list of int): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + +Returns: + Tensor: A quantized tensor with max_pool1d applied. + +Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool1d(qx, [2]) + tensor([[0.0000], + [1.5000]], size=(2, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) +""", +) + + +add_docstr( + torch.quantized_max_pool2d, + r""" +quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + +Applies a 2D max pooling over an input quantized tensor composed of several input planes. 
+ +Arguments: + input (Tensor): quantized tensor + kernel_size (``list of int``): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + +Returns: + Tensor: A quantized tensor with max_pool2d applied. + +Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool2d(qx, [2,2]) + tensor([[[[1.5000]], + + [[1.5000]]], + + + [[[0.0000]], + + [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) +""", +) + + +add_docstr( + torch.Stream, + r""" +Stream(device, *, priority) -> Stream + +An in-order queue of executing the respective tasks asynchronously in first in first out (FIFO) order. +It can control or synchronize the execution of other Stream or block the current host thread to ensure +the correct task sequencing. It supports with statement as a context manager to ensure the operators +within the with block are running on the corresponding stream. + +See in-depth description of the CUDA behavior at :ref:`cuda-semantics` for details +on the exact semantic that applies to all devices. + +Arguments: + device (:class:`torch.device`, optional): the desired device for the Stream. + If not given, the current :ref:`accelerator` type will be used. + priority (int, optional): priority of the stream, should be 0 or negative, where negative + numbers indicate higher priority. By default, streams have priority 0. + +Returns: + Stream: An torch.Stream object. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> with torch.Stream(device='cuda') as s_cuda: + >>> a = torch.randn(10, 5, device='cuda') + >>> b = torch.randn(5, 10, device='cuda') + >>> c = torch.mm(a, b) +""", +) + + +add_docstr( + torch.Stream.query, + r""" +Stream.query() -> bool + +Check if all the work submitted has been completed. + +Returns: + bool: A boolean indicating if all kernels in this stream are completed. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> s_cuda.query() + True +""", +) + + +add_docstr( + torch.Stream.record_event, + r""" +Stream.record_event(event) -> Event + +Record an event. En-queuing it into the Stream to allow further synchronization from the current point in the FIFO queue. + +Arguments: + event (:class:`torch.Event`, optional): event to record. If not given, a new one will be allocated. + +Returns: + Event: Recorded event. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> e_cuda = s_cuda.record_event() +""", +) + + +add_docstr( + torch.Stream.synchronize, + r""" +Stream.synchronize() -> None + +Wait for all the kernels in this stream to complete. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> s_cuda.synchronize() +""", +) + + +add_docstr( + torch.Stream.wait_event, + r""" +Stream.wait_event(event) -> None + +Make all future work submitted to the stream wait for an event. + +Arguments: + event (:class:`torch.Event`): an event to wait for. 
+ +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s1_cuda = torch.Stream(device='cuda') + >>> s2_cuda = torch.Stream(device='cuda') + >>> e_cuda = s1_cuda.record_event() + >>> s2_cuda.wait_event(e_cuda) +""", +) + + +add_docstr( + torch.Stream.wait_stream, + r""" +Stream.wait_stream(stream) -> None + +Synchronize with another stream. All future work submitted to this stream will wait until all kernels +already submitted to the given stream are completed. + +Arguments: + stream (:class:`torch.Stream`): a stream to synchronize. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s1_cuda = torch.Stream(device='cuda') + >>> s2_cuda = torch.Stream(device='cuda') + >>> s2_cuda.wait_stream(s1_cuda) +""", +) + + +add_docstr( + torch.Event, + r""" +Event(device=None, *, enable_timing=False, blocking=False, interprocess=False) + +Query and record Stream status to identify or control dependencies across Stream and measure timing. + +Arguments: + device (:class:`torch.device`, optional): the desired device for the Event. + If not given, the current :ref:`accelerator` type will be used. + enable_timing (bool, optional): indicates if the event should measure time (default: ``False``) + blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) + interprocess (bool): if ``True``, the event can be shared between processes (default: ``False``) + +.. warning:: + + Both blocking and interprocess are not supported right now and are noops. + +Returns: + Event: An torch.Event object. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> event = torch.Event() + >>> e_cuda = torch.Event(device='cuda') +""", +) + + +add_docstr( + torch.Event.elapsed_time, + r""" +Event.elapsed_time(end_event) -> float + +Returns the elapsed time in milliseconds between when this event and the :attr:`end_event` are +each recorded via :func:`torch.Stream.record_event`. + +Arguments: + end_event (:class:`torch.Event`): The ending event has been recorded. + +Returns: + float: Time between starting and ending event in milliseconds. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> e1_cuda = s_cuda.record_event() + >>> e2_cuda = s_cuda.record_event() + >>> ms = e1_cuda.elapsed_time(e2_cuda) +""", +) + + +add_docstr( + torch.Event.query, + r""" +Event.query() -> bool + +Check if the stream where this event was recorded already moved past the point where the event was recorded. +Always returns ``True`` if the Event was not recorded. + +Returns: + bool: A boolean indicating if all work currently captured by event has completed. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> e_cuda = s_cuda.record_event() + >>> e_cuda.query() + True +""", +) + + +add_docstr( + torch.Event.record, + r""" +Event.record(stream=None) -> None + +Record the event in a given stream. The stream's device must match the event's device. +This function is equivalent to ``stream.record_event(self)``. + +Arguments: + stream (:class:`torch.Stream`, optional): A stream to be recorded. + If not given, the current stream will be used. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> e_cuda = torch.Event(device='cuda') + >>> e_cuda.record() +""", +) + + +add_docstr( + torch.Event.synchronize, + r""" +Event.synchronize() -> None + +Wait for the event to complete. 
This prevents the CPU thread from proceeding until the event completes. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> e_cuda = s_cuda.record_event() + >>> e_cuda.synchronize() +""", +) + + +add_docstr( + torch.Event.wait, + r""" +Event.wait(stream=None) -> None + +Make all future work submitted to the given stream wait for this event. + +Arguments: + stream (:class:`torch.Stream`, optional): A stream to synchronize. + If not given, the current stream will be used. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s1_cuda = torch.Stream(device='cuda') + >>> s2_cuda = torch.Stream(device='cuda') + >>> e_cuda = s1_cuda.record_event() + >>> e_cuda.wait(s2_cuda) +""", +) + + +add_docstr( + torch.Generator, + r""" +Generator(device='cpu') -> Generator + +Creates and returns a generator object that manages the state of the algorithm which +produces pseudo random numbers. Used as a keyword argument in many :ref:`inplace-random-sampling` +functions. + +Arguments: + device (:class:`torch.device`, optional): the desired device for the generator. + +Returns: + Generator: A torch.Generator object. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> g_cpu = torch.Generator() + >>> g_cuda = torch.Generator(device='cuda') +""", +) + + +add_docstr( + torch.Generator.set_state, + r""" +Generator.set_state(new_state) -> void + +Sets the Generator state. + +Arguments: + new_state (torch.ByteTensor): The desired state. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu_other = torch.Generator() + >>> g_cpu.set_state(g_cpu_other.get_state()) +""", +) + + +add_docstr( + torch.Generator.get_state, + r""" +Generator.get_state() -> Tensor + +Returns the Generator state as a ``torch.ByteTensor``. + +Returns: + Tensor: A ``torch.ByteTensor`` which contains all the necessary bits + to restore a Generator to a specific point in time. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.get_state() +""", +) + +add_docstr( + torch.Generator.graphsafe_set_state, + r""" +Generator.graphsafe_set_state(state) -> None + +Sets the state of the generator to the specified state in a manner that is safe for use in graph capture. +This method is crucial for ensuring that the generator's state can be captured in the CUDA graph. + +Arguments: + state (torch.Generator): A Generator pointing to the new state for the generator, typically obtained from `graphsafe_get_state`. + +Example: + >>> g_cuda = torch.Generator(device='cuda') + >>> g_cuda_other = torch.Generator(device='cuda') + >>> current_state = g_cuda_other.graphsafe_get_state() + >>> g_cuda.graphsafe_set_state(current_state) +""", +) + +add_docstr( + torch.Generator.graphsafe_get_state, + r""" +Generator.graphsafe_get_state() -> torch.Generator + +Retrieves the current state of the generator in a manner that is safe for graph capture. +This method is crucial for ensuring that the generator's state can be captured in the CUDA graph. + +Returns: + torch.Generator: A Generator pointing to the current state of the generator + +Example: + >>> g_cuda = torch.Generator(device='cuda') + >>> current_state = g_cuda.graphsafe_get_state() +""", +) + +add_docstr( + torch.Generator.clone_state, + r""" +Generator.clone_state() -> torch.Generator + +Clones the current state of the generator and returns a new generator pointing to this cloned state. +This method is beneficial for preserving a particular state of a generator to restore at a later point.
+ +Returns: + torch.Generator: A Generator pointing to the newly cloned state. + +Example: + >>> g_cuda = torch.Generator(device='cuda') + >>> cloned_state = g_cuda.clone_state() +""", +) + +add_docstr( + torch.Generator.manual_seed, + r""" +Generator.manual_seed(seed) -> Generator + +Sets the seed for generating random numbers. Returns a `torch.Generator` object. Any 32-bit integer is a valid seed. + +Arguments: + seed (int): The desired seed. Value must be within the inclusive range + `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError + is raised. Negative inputs are remapped to positive values with the formula + `0xffff_ffff_ffff_ffff + seed`. + +Returns: + Generator: An torch.Generator object. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.manual_seed(2147483647) +""", +) + + +add_docstr( + torch.Generator.initial_seed, + r""" +Generator.initial_seed() -> int + +Returns the initial seed for generating random numbers. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.initial_seed() + 2147483647 +""", +) + + +add_docstr( + torch.Generator.seed, + r""" +Generator.seed() -> int + +Gets a non-deterministic random number from std::random_device or the current +time and uses it to seed a Generator. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.seed() + 1516516984916 +""", +) + + +add_docstr( + torch.Generator.device, + r""" +Generator.device -> device + +Gets the current device of the generator. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.device + device(type='cpu') +""", +) + +add_docstr( + torch._assert_async, + r""" +_assert_async(tensor) -> void + +Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, +this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for +CUDA tensors, we DO NOT synchronize and you may only find out the assertion +failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for +testing invariants in CUDA tensors without giving up performance. This function +is NOT intended to be used for regular error checking, as it will trash your CUDA +context if the assert fails (forcing you to restart your PyTorch process.) + +Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. +""", +) + +add_docstr( + torch.searchsorted, + r""" +searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + +Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the +corresponding values in :attr:`values` were inserted before the indices, when sorted, the order +of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. +Return a new tensor with the same size as :attr:`values`. More formally, +the returned index satisfies the following rules: + +.. 
list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + +Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + +Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. "left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. + sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + +Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) +""", +) + +add_docstr( + torch.bucketize, + r""" +bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + +Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the +boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size +as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that +this behavior is opposite the behavior of +`numpy.digitize `_. +More formally, the returned index satisfies the following rules: + +.. 
list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + +Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + +Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): determines the behavior for values in :attr:`boundaries`. See the table above. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + +Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) +""", +) + +add_docstr( + torch.view_as_real_copy, + r""" +Performs the same operation as :func:`torch.view_as_real`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.view_as_complex_copy, + r""" +Performs the same operation as :func:`torch.view_as_complex`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.as_strided_copy, + r""" +Performs the same operation as :func:`torch.as_strided`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.diagonal_copy, + r""" +Performs the same operation as :func:`torch.diagonal`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.expand_copy, + r""" +Performs the same operation as :func:`torch.Tensor.expand`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.permute_copy, + r""" +Performs the same operation as :func:`torch.permute`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.select_copy, + r""" +Performs the same operation as :func:`torch.select`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.detach_copy, + r""" +Performs the same operation as :func:`torch.detach`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.slice_copy, + r""" +Performs the same operation as :func:`torch.slice`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.split_copy, + r""" +Performs the same operation as :func:`torch.split`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.split_with_sizes_copy, + r""" +Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.squeeze_copy, + r""" +Performs the same operation as :func:`torch.squeeze`, but all output tensors +are freshly created instead of aliasing the input. 
+""", +) + +add_docstr( + torch.t_copy, + r""" +Performs the same operation as :func:`torch.t`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.transpose_copy, + r""" +Performs the same operation as :func:`torch.transpose`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.unsqueeze_copy, + r""" +Performs the same operation as :func:`torch.unsqueeze`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.indices_copy, + r""" +Performs the same operation as :func:`torch.indices`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.values_copy, + r""" +Performs the same operation as :func:`torch.values`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.crow_indices_copy, + r""" +Performs the same operation as :func:`torch.crow_indices`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.col_indices_copy, + r""" +Performs the same operation as :func:`torch.col_indices`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.unbind_copy, + r""" +Performs the same operation as :func:`torch.unbind`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.view_copy, + r""" +Performs the same operation as :func:`torch.view`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.unfold_copy, + r""" +Performs the same operation as :func:`torch.unfold`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.alias_copy, + r""" +Performs the same operation as :func:`torch.alias`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +for unary_base_func_name in ( + "exp", + "sqrt", + "abs", + "acos", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", + "erfc", + "expm1", + "floor", + "log", + "log10", + "log1p", + "log2", + "neg", + "tan", + "tanh", + "sin", + "sinh", + "round", + "lgamma", + "frac", + "reciprocal", + "sigmoid", + "trunc", + "zero", +): + unary_foreach_func_name = f"_foreach_{unary_base_func_name}" + if hasattr(torch, unary_foreach_func_name): + add_docstr( + getattr(torch, unary_foreach_func_name), + rf""" +{unary_foreach_func_name}(self: List[Tensor]) -> List[Tensor] + +Apply :func:`torch.{unary_base_func_name}` to each Tensor of the input list. + """, + ) + unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_" + if hasattr(torch, unary_inplace_foreach_func_name): + add_docstr( + getattr(torch, unary_inplace_foreach_func_name), + rf""" +{unary_inplace_foreach_func_name}(self: List[Tensor]) -> None + +Apply :func:`torch.{unary_base_func_name}` to each Tensor of the input list. 
+ """, + ) diff --git a/phivenv/Lib/site-packages/torch/_utils.py b/phivenv/Lib/site-packages/torch/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..30b7cca811cded54d8182f6e884c25f7fc9e2fd6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_utils.py @@ -0,0 +1,1111 @@ +# mypy: allow-untyped-defs +import copyreg +import functools +import importlib +import logging +import sys +import traceback +import warnings +from collections import defaultdict +from types import ModuleType +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING +from typing_extensions import deprecated, ParamSpec + +import torch + + +def _type(self, dtype=None, non_blocking=False, **kwargs): + """Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. + + Args: + dtype (type or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. + """ + non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs) + if dtype is None: + return self.__module__ + "." + self.__class__.__name__ + + if isinstance(dtype, str): + dtype = _import_dotted_name(dtype) + if dtype == type(self): + return self + if self.is_sparse: + if not dtype.is_sparse: + raise RuntimeError("Cannot cast sparse tensor to dense tensor") + new_module_name = dtype.__module__.replace(".sparse", "") + new_values_type_name = new_module_name + "." + dtype.__name__ + new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) + new_indices_type_name = new_module_name + ".LongTensor" + new_indices = torch.Tensor._indices(self).type( + new_indices_type_name, non_blocking + ) + return dtype(new_indices, new_values, self.size()) + if dtype.is_sparse: + raise RuntimeError("Cannot cast dense tensor to sparse tensor") + return dtype(self.size()).copy_(self, non_blocking) + + +def _to(self, device, non_blocking=False): + """Returns a copy of this object in device memory. + + If this object is already on the correct device, then no copy is performed + and the original object is returned. + + Args: + device (int): The destination device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. 
+ """ + if self.device == device: + return self + + if device.type == "cpu": + pin_memory = non_blocking and self.device.type in ( + "cuda", + torch._C._get_privateuse1_backend_name(), + ) + untyped_storage = torch.empty( + self.nbytes(), dtype=torch.uint8, device=device, pin_memory=pin_memory + ).untyped_storage() + untyped_storage.copy_(self, non_blocking) + return untyped_storage + + device_module = getattr(torch, device.type, None) + assert device_module is not None, ( + f"{device.type.upper()} device module is not loaded" + ) + with device_module.device(device): + if self.is_sparse and hasattr(device_module, "sparse"): + new_type = getattr(device_module.sparse, self.__class__.__name__) + indices = getattr(torch.Tensor._indices(self), device.type)( + device, non_blocking + ) + values = getattr(torch.Tensor._values(self), device.type)( + device, non_blocking + ) + return new_type(indices, values, self.size()) + else: + assert not self.is_sparse, ( + f"sparse storage is not supported for {device.type.upper()} tensors" + ) + untyped_storage = torch.UntypedStorage(self.size(), device=device) + untyped_storage.copy_(self, non_blocking) + return untyped_storage + + +def _get_async_or_non_blocking(function_name, non_blocking, kwargs): + """Return the non-blocking flag given the function name and kwargs. + + Args: + function_name (str): the name of the function being used. + non_blocking (bool): the default value. + **kwargs (dict): the kwargs passed to the function. + """ + if not kwargs: + return non_blocking + if len(kwargs) != 1 or "async" not in kwargs: + message = "{}() got an unexpected keyword argument '{}'" + argument = list(kwargs.keys()).pop() + raise TypeError(message.format(function_name, argument)) + warnings.warn("'async' is deprecated; use 'non_blocking'") + return kwargs["async"] + + +def _get_restore_location(device): + """Return the map_location location. + + Used for rebuild functions where the tensor device is distinct from the storage + """ + + map_location = torch.serialization._serialization_tls.map_location + if map_location is None: + return device + else: + if isinstance(map_location, dict): + return map_location.get(device, device) + elif isinstance(map_location, (str, torch.device)): + return map_location + else: + assert callable(map_location) + raise RuntimeError( + "Callable map_location not supported with _rebuild_wrapper_subclass " + "or _rebuild_device_tensor_from_numpy" + ) + + +# Note [Don't serialize hooks] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Since time immemorial, we have serialized the backward hooks associated with +# variables. This kind of half-worked--Python can pickle global functions +# (but not closures!)--but there were problems. +# +# - It's fragile. If you serialize a backward hook into a saved +# model, and then you rename the function associated with the hook, +# now your saved model is broken and you can't load it anymore. +# +# - It's not actually used. The standard recommendation is to +# serialize the *state_dict* of a model, not the model itself +# (since this is more stable to code changes affecting the model +# serialization), and the state dict saves "data" only, thus +# stripping the backward hooks. In some cases, hooks are +# essential to the well-functioning of a model (e.g., DDP), +# but DDP already manages readding the hooks! +# +# - We didn't serialize them in many cases. Prior to #10220, we +# were dropping backward hooks in ForkingPickler. 
We "fixed" this +# to be convenient with other serialization sites, but lack of +# serializing backward hooks wasn't actually the root cause of +# the bug. +# +# With these cases in mind, we have decided that a better strategy +# is to just NOT serialize hooks at all. +# +# Since this is a BC-breaking change, we should warn when we previously +# serialized a hook, but no longer do so. This will be done by adding a special +# sentinel property to hooks will be used to suppress this warning. If a hook +# has the property _torch_serialize_ignore, we will not emit a warning if we +# attempt to serialize a Tensor with this hook attached to it. +# +# By the way, when _backward_hooks is skipped, we must give an EMPTY +# OrderedDict(), if you pass a None you'll run afoul #12219. + + +# TODO: Once we decide to break serialization FC, `storage` no longer needs to +# be a TypedStorage +def _rebuild_tensor(storage, storage_offset, size, stride): + # first construct a tensor with the correct dtype/device + t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device) + return t.set_(storage._untyped_storage, storage_offset, size, stride) + + +def get_tensor_metadata(tensor): + # Tensor's Metadata for serializing. + # Currently, this only returns a dict[string, bool] specifing whether + # `conj` or `neg` bit is set. + assert isinstance(tensor, torch.Tensor) + return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined] + + +def set_tensor_metadata(tensor, metadata): + # See `get_tensor_metadata` above + assert isinstance(metadata, dict) + assert isinstance(tensor, torch.Tensor) + torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined] + + +def _restore_device_fake_mode(tensor): + if torch._guards.detect_fake_mode(None) is not None: + if tensor.untyped_storage()._fake_device is not None: + device = _get_restore_location(tensor.untyped_storage()._fake_device) + if not isinstance(device, torch.device): + device = torch.device(device) + tensor.fake_device = torch.device(device) + return tensor + + +def _rebuild_tensor_v2( + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata=None, +): + tensor = _rebuild_tensor(storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + if metadata: + set_tensor_metadata(tensor, metadata) + + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + tensor._backward_hooks = backward_hooks + + tensor = _restore_device_fake_mode(tensor) + return tensor + + +def _rebuild_tensor_v3( + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + dtype, + metadata=None, +): + t = torch.empty( + (0,), + dtype=dtype, + device=storage._untyped_storage.device, + requires_grad=requires_grad, + ) + t.set_(storage._untyped_storage, storage_offset, size, stride) + if metadata: + set_tensor_metadata(t, metadata) + t._backward_hooks = backward_hooks + t = _restore_device_fake_mode(t) + return t + + +_sparse_tensors_to_validate: list["torch.Tensor"] = [] + + +# In _legacy_load() in serialization.py we unpickle storages after the sparse +# tensors have been already unpickled. Those storages contain data necessary for +# validating sparse tensors: indices and values. 
That's why sparse tensors are +# first unpickled without any validation, and then this function is called just +# before _legacy_load() returns, so that all the sparse tensors can be validated +# in bulk. +# +# The same procedure must be followed by _load() in serialization.py because due +# to Pickler semantics, we have to use the same (non-validating) function for +# unpickling sparse tensors, regardless of the caller. +def _validate_loaded_sparse_tensors(): + if not torch.sparse.check_sparse_tensor_invariants().is_enabled(): + # Skip sparse tensor invariants validation for better + # performance. See check_sparse_tensor_invariants + # documentation for how to control sparse tensor invariants + # checking. + _sparse_tensors_to_validate.clear() + return + try: + # We disable pinning check (see check_pinning=False below) to + # avoid gh-153143. In fact, pinning check is unnecessary + # anywhy when loading sparse data from external sources. + for t in _sparse_tensors_to_validate: + if t.layout is torch.sparse_coo: + torch._validate_sparse_coo_tensor_args( + t._indices(), + t._values(), + t.size(), + t.is_coalesced(), + check_pinning=False, + ) + elif t.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + # TODO: Validation currently involves an expensive traversal + # on CPU, which may include a device transfer. + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = ( + t.crow_indices(), + t.col_indices(), + ) + else: + compressed_indices, plain_indices = ( + t.ccol_indices(), + t.row_indices(), + ) + torch._validate_sparse_compressed_tensor_args( + compressed_indices, + plain_indices, + t.values(), + t.size(), + t.layout, + check_pinning=False, + ) + else: + raise NotImplementedError( + f"_validate_loaded_sparse_tensors for layout `{t.layout}`" + ) + + finally: + _sparse_tensors_to_validate.clear() + + +def _rebuild_sparse_tensor(layout, data): + """ + Rebuilds a sparse tensor from its sparse storage representation. + + Args: + layout (str): The sparse storage layout of the tensor. + data (tuple): The tensor's sparse storage representation. 
+ """ + if layout == torch.sparse_coo: + if len(data) == 3: + # For BC: + indices, values, size = data + is_coalesced = None + else: + indices, values, size, is_coalesced = data + result = torch.sparse_coo_tensor( + indices, values, size, check_invariants=False, is_coalesced=is_coalesced + ) + _sparse_tensors_to_validate.append(result) + return result + + elif layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + compressed_indices, plain_indices, values, size = data + result = torch.sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + size, + layout=layout, + check_invariants=False, + ) + _sparse_tensors_to_validate.append(result) + return result + + raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}") + + +def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets): + return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets) + + +def _rebuild_device_tensor_from_cpu_tensor(data, dtype, device, requires_grad): + device = _get_restore_location(device) + tensor = data.to(dtype=dtype, device=device) + tensor.requires_grad = requires_grad + return tensor + + +def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): + device = _get_restore_location(device) + tensor = torch.from_numpy(data).to(dtype=dtype, device=device) + tensor.requires_grad = requires_grad + return tensor + + +# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch +_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy + + +def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad): + return torch.empty_strided( + size, stride, dtype=dtype, device="meta", requires_grad=requires_grad + ) + + +def _rebuild_wrapper_subclass( + cls, + dtype, + size, + stride, + storage_offset, + layout, + device, + requires_grad, +): + device = _get_restore_location(device) + return torch.Tensor._make_wrapper_subclass( + cls, + size, + strides=stride, + dtype=dtype, + storage_offset=storage_offset, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +# TODO: Once we decide to break serialization FC, `storage` no longer needs to +# be a TypedStorage +def _rebuild_qtensor( + storage, + storage_offset, + size, + stride, + quantizer_params, + requires_grad, + backward_hooks, +): + qscheme = quantizer_params[0] + if qscheme == torch.per_tensor_affine: + _, scale, zero_point = quantizer_params + tensor = torch._empty_affine_quantized( + size, + scale=scale, + zero_point=zero_point, + dtype=storage.dtype, + device=storage.device, + ) + elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): + _, scales, zero_points, axis = quantizer_params + if type(scales) is list and type(zero_points) is list: + if qscheme == torch.per_channel_affine: + scales = torch.tensor(scales, dtype=torch.double, device=storage.device) + zero_points = torch.tensor( + zero_points, dtype=torch.long, device=storage.device + ) + else: + scales = torch.tensor(scales, dtype=torch.float, device=storage.device) + zero_points = torch.tensor( + zero_points, dtype=torch.float, device=storage.device + ) + tensor = torch._empty_per_channel_affine_quantized( + size, + scales=scales, + zero_points=zero_points, + axis=axis, + dtype=storage.dtype, + device=storage.device, + ) + else: + raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}") + tensor.set_(storage, storage_offset, size, stride) + tensor.requires_grad = 
requires_grad + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + tensor._backward_hooks = backward_hooks + return tensor + + +def _rebuild_parameter(data, requires_grad, backward_hooks): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + return param + + +def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + # Restore state on Parameter like python attr. + param = _set_obj_state(param, state) + return param + + +def _get_obj_state(obj): + # Get the state of the python subclass + # This loosely mimicks the function on the object class but since Tensor do not inherit + # from it, we cannot call that function directly + # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 + # Note that starting with Python 3.11, this `__getstate__` is always defined and thus + # the else branch will never be taken. + getstate_fn = getattr(obj, "__getstate__", None) + if getstate_fn: + state = getstate_fn() + else: + slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] + if slots_to_save: + state = ( + obj.__dict__, + { + name: getattr(obj, name) + for name in slots_to_save + if hasattr(obj, name) + }, + ) + else: + state = obj.__dict__ + + return state + + +def _set_obj_state(obj, state): + if isinstance(state, tuple): + if not len(state) == 2: + raise RuntimeError(f"Invalid serialized state: {state}") + dict_state = state[0] + slots_state = state[1] + else: + dict_state = state + slots_state = None + + # Starting with Python 3.11, the __dict__ attribute is lazily created + # and is serialized as None when not needed. + if dict_state: + for k, v in dict_state.items(): + setattr(obj, k, v) + + if slots_state: + for k, v in slots_state.items(): + setattr(obj, k, v) + return obj + + +def _import_dotted_name(name): + components = name.split(".") + obj = __import__(components[0]) + for component in components[1:]: + obj = getattr(obj, component) + return obj + + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + + Returns: + A contiguous 1D buffer containing input tensors. + """ + return torch._C._nn.flatten_dense_tensors(tensors) + + +def _flatten_sparse_tensors(tensors): + """Flatten sparse tensors into two contiguous 1D buffers, one of indices and + one of values. Assume tensors are of same sparse type. + + Args: + tensors (Iterable[Tensor]): sparse tensors to flatten. + + Returns: + A tuple of two contiguous 1D buffers, one containing input tensors' + indices and the other containing the values. 
+ """ + flat_indices = torch._C._nn.flatten_dense_tensors( + [torch.Tensor._indices(t) for t in tensors] + ) + flat_values = torch._C._nn.flatten_dense_tensors( + [torch.Tensor._values(t) for t in tensors] + ) + return flat_indices, flat_values + + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + return torch._C._nn.unflatten_dense_tensors(flat, tensors) + + +def _unflatten_sparse_tensors(flat, tensors): + """View flat buffer (containing indices and values) using the sizes of + tensors. Assume that tensors are of same sparse type, and that flat is given + by _flatten_sparse_tensors. + + Args: + flat (tuple(Tensor, Tensor)): flattened indices and values of sparse + tensors to unflatten. + tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened sparse tensors with sizes same as tensors and values from + flat. + """ + flat_indices, flat_values = flat + indices = torch._C._nn.unflatten_dense_tensors( + flat_indices, [torch.Tensor._indices(t) for t in tensors] + ) + values = torch._C._nn.unflatten_dense_tensors( + flat_values, [torch.Tensor._values(t) for t in tensors] + ) + outputs = [] + for t, i, v in zip(tensors, indices, values): + outputs.append(t.new(i, v, t.size())) + return tuple(outputs) + + +def _reorder_tensors_as(tensors, ordered_tensors): + """Assume that tensors are of same order as ordered_tensors within their + types, e.g., from _take_tensors. Reorder them to be of same order as + ordered_tensors. + + Args: + tensors (Iterable[Tensor]): tensors to be reordered. They should be of + the same order as ordered_tensors within their own types. + ordered_tensors (Iterable[Tensor]): tensors whose order will be the + reference. + + Returns: + Ordered tuple of tensors with contents from tensors and order of + ordered_tensors. + """ + type_dict = defaultdict(list) + for tensor in tensors: + type_dict[tensor.type()].append(tensor) + type_dict_ = {t: iter(coll) for t, coll in type_dict.items()} + return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors) + + +def _take_tensors(tensors, size_limit): + """Group tensors into chunks. This generator yields a chunk at each time, + each containing tensors of same type up to certain byte limit in total size. + + Args: + tensors (Sequence): A sequence of tensors to be separated into chunks. + size_limit (int): The limit of each chunk in bytes. + + Yields: + Blocks of tensors of same type and within size_limit. The yielded + tensors are only ordered as the original sequence within its types. 
+ """ + buf_dict: defaultdict[str, list] = defaultdict(lambda: [[], 0]) + for tensor in tensors: + t = tensor.type() + if tensor.is_sparse: + indices = torch.Tensor._indices(tensor) + values = torch.Tensor._values(tensor) + size = ( + indices.numel() * indices.element_size() + + values.numel() * values.element_size() + ) + else: + size = tensor.numel() * tensor.element_size() + buf_and_size = buf_dict[t] + if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: + yield buf_and_size[0] + buf_and_size = buf_dict[t] = [[], 0] + buf_and_size[0].append(tensor) + buf_and_size[1] += size + for buf, _ in buf_dict.values(): + if len(buf) > 0: + yield buf + + +# annotation decorator to get annotations in a way that is compatible +# with both Python 2 and 3 +def annotate(ret, **kwargs): + def dec(fun): + fun.__annotations__ = dict(kwargs) + fun.__annotations__["return"] = ret + return fun + + return dec + + +def render_call(fn, args, kwargs): + str_fn = torch.overrides.resolve_name(fn) + if str_fn is None: + str_fn = str(fn) + + str_args: list[str] = [] + with torch._tensor_str.printoptions(threshold=0, edgeitems=0): + str_args.extend(repr(a) for a in args) + str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items()) + r = f"{str_fn}({', '.join(str_args)})" + return r + + +# NOTE [ Python Traceback Reference Cycle Problem ] +# +# When using sys.exc_info(), it is important to **not** store the exc_info[2], +# which is the traceback, because otherwise you will run into the traceback +# reference cycle problem, i.e., the traceback holding reference to the frame, +# and the frame (which holds reference to all the object in its temporary scope) +# holding reference the traceback. + + +class KeyErrorMessage(str): + r"""str subclass that returns itself in repr""" + + __slots__ = () + + def __repr__(self): + return self + + +class ExceptionWrapper: + r"""Wraps an exception plus traceback to communicate across threads""" + + def __init__(self, exc_info=None, where="in background"): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = where + + def reraise(self): + r"""Reraises the wrapped exception in the current thread""" + # Format a message such as: "Caught ValueError in DataLoader worker + # process 2. Original Traceback:", followed by the traceback. + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" + if self.exc_type == KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. 
+ msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + try: + exception = self.exc_type(msg) + except Exception: + # If the exception takes multiple arguments or otherwise can't + # be constructed, don't try to instantiate since we don't know how to + raise RuntimeError(msg) from None + raise exception + + +def _get_available_device_type(): + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined] + return "xpu" + if hasattr(torch, "mtia") and torch.mtia.is_available(): + return "mtia" + custom_backend_name = torch._C._get_privateuse1_backend_name() + custom_device_mod = getattr(torch, custom_backend_name, None) + if custom_device_mod and custom_device_mod.is_available(): + return custom_backend_name + # add more available device types here + return None + + +def _get_device_attr(get_member): + device_type = _get_available_device_type() + if device_type and device_type.lower() == "cuda": + return get_member(torch.cuda) + if device_type and device_type.lower() == "mps": + return get_member(torch.mps) + if device_type and device_type.lower() == "xpu": + return get_member(torch.xpu) # type: ignore[attr-defined] + if device_type and device_type.lower() == "mtia": + return get_member(torch.mtia) + if device_type == torch._C._get_privateuse1_backend_name(): + return get_member(getattr(torch, device_type)) + # add more available device types here + return None + + +def _get_current_device_index(): + # current device index + return _get_device_attr(lambda m: m.current_device()) + + +def _get_all_device_indices(): + # all device index + return _get_device_attr(lambda m: list(range(m.device_count()))) + + +def _get_devices_properties(device_ids): + # all device properties + return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids] + + +def get_current_device_index() -> int: + r"""Checks if there are CUDA devices available and + returns the device index of the current default CUDA device. + Returns -1 in case there are no CUDA devices available. + Arguments: ``None`` + """ + if torch.cuda.device_count() > 0: + return torch.cuda.current_device() + return -1 + + +def _get_device_index( + device: Any, + optional: bool = False, + allow_cpu: bool = False, +) -> int: + r"""Gets the device index from :attr:`device`, which can be a torch.device + object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + has index. Note that for a device without a specified index, + i.e., ``torch.device('xxx')``, this will return the current default + device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default + device of the supported runtime platform if :attr:`optional` is ``True``. + i.e., the current default CUDA device will be returned if CUDA runtime is supported. 
+ """ + if isinstance(device, str): + device = torch.device(device) + device_idx: Optional[int] = None + if isinstance(device, torch.device): + if not allow_cpu and device.type == "cpu": + raise ValueError(f"Expected a non cpu device, but got: {device}") + device_idx = -1 if device.type == "cpu" else device.index + if isinstance(device, int): + device_idx = device + if device_idx is None: + if optional: + # The eager API _get_current_device_index uses `lambda` functions which are + # not supported in JIT and hence not scriptable. The JIT equivalent API to get + # the current device index is `get_current_device_index()` which can + # be scripted. We use is_scripting to check the mode we are in and call the + # appropriate API. + if torch.jit.is_scripting(): + device_idx = get_current_device_index() + else: + device_idx = _get_current_device_index() + else: + raise ValueError( + f"Expected a torch.device with a specified index or an integer, but got:{device}" + ) + return device_idx + + +def _handle_complex(tensor): + """ + Returns a real view of a tensor if complex dtype else just the tensor + need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule + """ + return ( + torch.view_as_real(tensor) + if not isinstance(tensor, torch.nn.UninitializedParameter) + and tensor.is_complex() + else tensor + ) + + +def _element_size(dtype): + """ + Returns the element size for a dtype, in bytes + """ + if not isinstance(dtype, torch.dtype): + raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}") + + if dtype.is_complex: + return torch.finfo(dtype).bits >> 2 + elif dtype.is_floating_point: + return torch.finfo(dtype).bits >> 3 + elif dtype == torch.bool: + # NOTE: torch.bool is not supported in torch.iinfo() + return 1 + else: + return torch.iinfo(dtype).bits >> 3 + + +class _ClassPropertyDescriptor: + def __init__(self, fget, fset=None): + self.fget = fget + + def __get__(self, instance, owner=None): + if owner is None: + owner = type(instance) + return self.fget.__get__(instance, owner)() + + +def classproperty(func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + return _ClassPropertyDescriptor(func) + + +if TYPE_CHECKING: + # TorchScript does not support `@deprecated` + # This is a workaround to avoid breaking TorchScript + @deprecated( + "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", + category=FutureWarning, + ) + def is_compiling() -> bool: + return torch.compiler.is_compiling() + +else: + + def is_compiling() -> bool: + """ + Indicates whether we are tracing/compiling with torch.compile() or torch.export(). + """ + warnings.warn( # use `warnings.warn` instead of `@deprecated` + "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", + # FutureWarning, # TorchScript does not support Warning type + stacklevel=2, + ) + return torch.compiler.is_compiling() + + +def _functionalize_sync(t): + # This code lives in python instead of C++ since conditioning on a certain python subclass + # is much more of a pain in C++. + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(t, FunctionalTensor): + # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called + # when we sync our inner tensor. + # Why? 
+ # (1) If there are input mutations in the graph, then they will be re-applied during + # AOTAutograd when we call _sync() from inside of our functionalization kernels. + # (2) _sync() causes us to regenerate our updated the tensor from the updated base, + # which dispatches to a bunch of view ops + # (3) The input to these view ops is our inner FunctionalTensorWrapper + # (since the sync was called from C++), not the python FunctionalTensor + # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts + # the view op, since it will see an input that is a C++ FunctionalTensorWrapper + # (aka a normal torch.Tensor) instead of a python `FunctionalTensor). + maybe_functional_mode = torch._C._unset_dispatch_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + try: + torch._functionalize_sync(t.elem) # type: ignore[attr-defined] + finally: + if maybe_functional_mode is not None: + torch._C._set_dispatch_mode(maybe_functional_mode) + else: + torch._functionalize_sync(t) # type: ignore[attr-defined] + + +@functools.lru_cache(2) +def _get_device_module(device_type: str): + device_module = getattr(torch, device_type, None) + if device_module is None: + raise RuntimeError( + f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'." + ) + return device_module + + +def _dummy_type(name: str) -> type: + def get_err_fn(is_init: bool): + def err_fn(obj, *args, **kwargs): + if is_init: + class_name = obj.__class__.__name__ + else: + class_name = obj.__name__ + raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") + + return err_fn + + return type( + name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} + ) + + +class _LazySeedTracker: + # Since seeding is memory-less, only track the latest seed. + # Note: `manual_seed_all` followed by `manual_seed` overwrites + # the seed on current device. We track the order of **latest** + # calls between these two API. 
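+    # For example, ``queue_seed_all(cb_a, tb_a)`` followed by ``queue_seed(cb_b, tb_b)``
+    # leaves ``call_order == [manual_seed_all_cb, manual_seed_cb]``, so replaying the
+    # list in order applies the device-specific seed last.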
+ def __init__(self): + self.manual_seed_all_cb = None + self.manual_seed_cb = None + self.call_order = [] + + def queue_seed_all(self, cb, traceback): + self.manual_seed_all_cb = (cb, traceback) + # update seed_all to be latest + self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] + + def queue_seed(self, cb, traceback): + self.manual_seed_cb = (cb, traceback) + # update seed to be latest + self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] + + def get_calls(self) -> list: + return self.call_order + + +logger = logging.getLogger(__name__) +P = ParamSpec("P") + + +class CallbackRegistry(Generic[P]): + def __init__(self, name: str): + self.name = name + self.callback_list: list[Callable[P, None]] = [] + + def add_callback(self, cb: Callable[P, None]) -> None: + self.callback_list.append(cb) + + def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None: + for cb in self.callback_list: + try: + cb(*args, **kwargs) + except Exception: + logger.exception( + "Exception in callback for %s registered with gpu trace", self.name + ) + + +def try_import(module_name: str) -> Optional[ModuleType]: + # Implementation based on + # https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported + if (module := sys.modules.get(module_name, None)) is not None: + return module + + if (spec := importlib.util.find_spec(module_name)) is not None: + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + # https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader + # "The finder should always set this attribute" + assert spec.loader is not None, "The loader attribute should always be set" + spec.loader.exec_module(module) + return module + + return None + + +# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py +# for use in the weights_only Unpickler. + +IMPORT_MAPPING = { + "__builtin__": "builtins", + "copy_reg": "copyreg", + "Queue": "queue", + "repr": "reprlib", + "_abcoll": "collections.abc", + # Non-mutual mappings. + "UserDict": "collections", + "UserList": "collections", + "UserString": "collections", + "whichdb": "dbm", + "StringIO": "io", + "cStringIO": "io", +} + + +# This contains rename rules that are easy to handle. We ignore the more +# complex stuff (e.g. mapping the names in the urllib and types modules). +# These rules should be run before import names are fixed. +NAME_MAPPING = { + ("__builtin__", "xrange"): ("builtins", "range"), + ("__builtin__", "reduce"): ("functools", "reduce"), + ("__builtin__", "intern"): ("sys", "intern"), + ("__builtin__", "unichr"): ("builtins", "chr"), + ("__builtin__", "unicode"): ("builtins", "str"), + ("__builtin__", "long"): ("builtins", "int"), + ("itertools", "izip"): ("builtins", "zip"), + ("itertools", "imap"): ("builtins", "map"), + ("itertools", "ifilter"): ("builtins", "filter"), + ("itertools", "ifilterfalse"): ("itertools", "filterfalse"), + ("itertools", "izip_longest"): ("itertools", "zip_longest"), + ("UserDict", "IterableUserDict"): ("collections", "UserDict"), + ("UserList", "UserList"): ("collections", "UserList"), + ("UserString", "UserString"): ("collections", "UserString"), + # Non-mutual mappings. 
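+    # ("non-mutual" means the mapping only applies in the Python 2 -> 3 direction;
+    # the Python 3 name maps back to a different Python 2 name listed above)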
+ ("__builtin__", "basestring"): ("builtins", "str"), + ("exceptions", "StandardError"): ("builtins", "Exception"), + ("UserDict", "UserDict"): ("collections", "UserDict"), +} diff --git a/phivenv/Lib/site-packages/torch/_utils_internal.py b/phivenv/Lib/site-packages/torch/_utils_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdeddf7c9c0b545038dff30c1d92e6df298348b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_utils_internal.py @@ -0,0 +1,284 @@ +# mypy: allow-untyped-defs +import functools +import logging +import os +import sys +import tempfile +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +log = logging.getLogger(__name__) + +if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): + import shutil + + if not shutil.which("strobeclient"): + log.info( + "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine." + ) + else: + log.info("Strobelight profiler is enabled via environment variable") + StrobelightCompileTimeProfiler.enable() + +# this arbitrary-looking assortment of functionality is provided here +# to have a central place for overrideable behavior. The motivating +# use is the FB build environment, where this source file is replaced +# by an equivalent. + +if torch._running_with_deploy(): + # __file__ is meaningless in the context of frozen torch used in torch deploy. + # setting empty torch_parent should allow below functions to operate without crashing, + # but it's unclear if there is a valid use case for them in the context of deploy. + torch_parent = "" +else: + if os.path.basename(os.path.dirname(__file__)) == "shared": + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + else: + torch_parent = os.path.dirname(os.path.dirname(__file__)) + + +def get_file_path(*path_components: str) -> str: + return os.path.join(torch_parent, *path_components) + + +def get_file_path_2(*path_components: str) -> str: + return os.path.join(*path_components) + + +def get_writable_path(path: str) -> str: + if os.access(path, os.W_OK): + return path + return tempfile.mkdtemp(suffix=os.path.basename(path)) + + +def prepare_multiprocessing_environment(path: str) -> None: + pass + + +def resolve_library_path(path: str) -> str: + return os.path.realpath(path) + + +def throw_abstract_impl_not_imported_error(opname, module, context): + if module in sys.modules: + raise NotImplementedError( + f"{opname}: We could not find the fake impl for this operator. " + ) + else: + raise NotImplementedError( + f"{opname}: We could not find the fake impl for this operator. " + f"The operator specified that you may need to import the '{module}' " + f"Python module to load the fake impl. {context}" + ) + + +# NB! This treats "skip" kwarg specially!! +def compile_time_strobelight_meta( + phase_name: str, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def compile_time_strobelight_meta_inner( + function: Callable[_P, _T], + ) -> Callable[_P, _T]: + @functools.wraps(function) + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T: + if "skip" in kwargs and isinstance(skip := kwargs["skip"], int): + kwargs["skip"] = skip + 1 + + # This is not needed but we have it here to avoid having profile_compile_time + # in stack traces when profiling is not enabled. 
+ if not StrobelightCompileTimeProfiler.enabled: + return function(*args, **kwargs) + + return StrobelightCompileTimeProfiler.profile_compile_time( + function, phase_name, *args, **kwargs + ) + + return wrapper_function + + return compile_time_strobelight_meta_inner + + +# Meta only, see +# https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/ +# +# This will cause an event to get logged to Scuba via the signposts API. You +# can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs +# we log to subsystem "torch", and the category and name you provide here. +# Each of the arguments translate into a Scuba column. We're still figuring +# out local conventions in PyTorch, but category should be something like +# "dynamo" or "inductor", and name should be a specific string describing what +# kind of event happened. +# +# Killswitch is at +# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event +def signpost_event(category: str, name: str, parameters: dict[str, Any]): + log.info("%s %s: %r", category, name, parameters) + + +def log_compilation_event(metrics): + log.info("%s", metrics) + + +def upload_graph(graph): + pass + + +def set_pytorch_distributed_envs_from_justknobs(): + pass + + +def log_export_usage(**kwargs): + pass + + +def log_trace_structured_event(*args, **kwargs) -> None: + pass + + +def log_cache_bypass(*args, **kwargs) -> None: + pass + + +def log_torchscript_usage(api: str, **kwargs): + _ = api + return + + +def check_if_torch_exportable(): + return False + + +def export_training_ir_rollout_check() -> bool: + return True + + +def full_aoti_runtime_assert() -> bool: + return True + + +def log_torch_jit_trace_exportability( + api: str, + type_of_export: str, + export_outcome: str, + result: str, +): + _, _, _, _ = api, type_of_export, export_outcome, result + return + + +def justknobs_check(name: str, default: bool = True) -> bool: + """ + This function can be used to killswitch functionality in FB prod, + where you can toggle this value to False in JK without having to + do a code push. In OSS, we always have everything turned on all + the time, because downstream users can simply choose to not update + PyTorch. (If more fine-grained enable/disable is needed, we could + potentially have a map we lookup name in to toggle behavior. But + the point is that it's all tied to source code in OSS, since there's + no live server to query.) + + This is the bare minimum functionality I needed to do some killswitches. + We have a more detailed plan at + https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit + In particular, in some circumstances it may be necessary to read in + a knob once at process start, and then use it consistently for the + rest of the process. Future functionality will codify these patterns + into a better high level API. + + WARNING: Do NOT call this function at module import time, JK is not + fork safe and you will break anyone who forks the process and then + hits JK again. + """ + return default + + +def justknobs_getval_int(name: str) -> int: + """ + Read warning on justknobs_check + """ + return 0 + + +def is_fb_unit_test() -> bool: + return False + + +@functools.cache +def max_clock_rate(): + if not torch.version.hip: + from triton.testing import nvsmi + + return nvsmi(["clocks.max.sm"])[0] + else: + # Manually set max-clock speeds on ROCm until equivalent nvmsi + # functionality in triton.testing or via pyamdsmi enablement. 
Required + # for test_snode_runtime unit tests. + gcn_arch = str(torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0]) + if "gfx94" in gcn_arch: + return 1700 + elif "gfx90a" in gcn_arch: + return 1700 + elif "gfx908" in gcn_arch: + return 1502 + elif "gfx12" in gcn_arch: + return 1700 + elif "gfx11" in gcn_arch: + return 1700 + elif "gfx103" in gcn_arch: + return 1967 + elif "gfx101" in gcn_arch: + return 1144 + elif "gfx95" in gcn_arch: + return 1700 # TODO: placeholder, get actual value + else: + return 1100 + + +def get_mast_job_name_version() -> Optional[tuple[str, int]]: + return None + + +TEST_MASTER_ADDR = "127.0.0.1" +TEST_MASTER_PORT = 29500 +# USE_GLOBAL_DEPS controls whether __init__.py tries to load +# libtorch_global_deps, see Note [Global dependencies] +USE_GLOBAL_DEPS = True +# USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load +# _C.so with RTLD_GLOBAL during the call to dlopen. +USE_RTLD_GLOBAL_WITH_LIBTORCH = False +# If an op was defined in C++ and extended from Python using the +# torch.library.register_fake, returns if we require that there be a +# m.set_python_module("mylib.ops") call from C++ that associates +# the C++ op with a python module. +REQUIRES_SET_PYTHON_MODULE = False + + +def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: + print("Uploading profile stats (fb-only otherwise no-op)") + return None + + +def log_chromium_event_internal( + event: dict[str, Any], + stack: list[str], + logger_uuid: str, + start_time_ns: int, +): + return None + + +def record_chromium_event_internal( + event: dict[str, Any], +): + return None + + +def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12(): + return True diff --git a/phivenv/Lib/site-packages/torch/_vmap_internals.py b/phivenv/Lib/site-packages/torch/_vmap_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..98083c76fd531eeb88eda22cc1230ecd1459b7de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_vmap_internals.py @@ -0,0 +1,245 @@ +# mypy: allow-untyped-defs +import functools +from typing import Any, Callable, Optional, Union +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten + + +in_dims_t = Union[int, tuple] +out_dims_t = Union[int, tuple[int, ...]] + + +# Checks that all args-to-be-batched have the same batch dim size +def _validate_and_get_batch_size( + flat_in_dims: list[Optional[int]], + flat_args: list, +) -> int: + batch_sizes = [ + arg.size(in_dim) + for in_dim, arg in zip(flat_in_dims, flat_args) + if in_dim is not None + ] + if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes): + raise ValueError( + f"vmap: Expected all tensors to have the same size in the mapped " + f"dimension, got sizes {batch_sizes} for the mapped dimension" + ) + return batch_sizes[0] + + +def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int: + if isinstance(batched_outputs, tuple): + return len(batched_outputs) + return 1 + + +# If value is a tuple, check it has length `num_elements`. 
+# If value is not a tuple, make a tuple with `value` repeated `num_elements` times +def _as_tuple( + value: Any, + num_elements: int, + error_message_lambda: Callable[[], str], +) -> tuple: + if not isinstance(value, tuple): + return (value,) * num_elements + if len(value) != num_elements: + raise ValueError(error_message_lambda()) + return value + + +# Creates BatchedTensors for every Tensor in arg that should be batched. +# Returns the (potentially) batched arguments and the batch_size. +def _create_batched_inputs( + in_dims: in_dims_t, + args: tuple, + vmap_level: int, + func: Callable, +) -> tuple[tuple, int]: + if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"expected `in_dims` to be int or a (potentially nested) tuple " + f"matching the structure of inputs, got: {type(in_dims)}." + ) + if len(args) == 0: + raise ValueError( + f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add " + f"inputs, or you are trying to vmap over a function with no inputs. " + f"The latter is unsupported." + ) + + flat_args, args_spec = tree_flatten(args) + flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) + if flat_in_dims is None: + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"in_dims is not compatible with the structure of `inputs`. " + f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs " + f"has structure {args_spec}." + ) + + for arg, in_dim in zip(flat_args, flat_in_dims): + if not isinstance(in_dim, int) and in_dim is not None: + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but in_dim must be either " + f"an integer dimension or None." + ) + if isinstance(in_dim, int) and not isinstance(arg, Tensor): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but the input is of type " + f"{type(arg)}. We cannot vmap over non-Tensor arguments, " + f"please use None as the respective in_dim" + ) + if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for some input, but that input is a Tensor " + f"of dimensionality {arg.dim()} so expected in_dim to satisfy " + f"0 <= in_dim < {arg.dim()}." + ) + + batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) + # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] + batched_inputs = [ + arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level) + for in_dim, arg in zip(flat_in_dims, flat_args) + ] + return tree_unflatten(batched_inputs, args_spec), batch_size + + +# Undos the batching (and any batch dimensions) associated with the `vmap_level`. +def _unwrap_batched( + batched_outputs: Union[Tensor, tuple[Tensor, ...]], + out_dims: out_dims_t, + vmap_level: int, + batch_size: int, + func: Callable, + allow_none_pass_through: bool = False, +) -> tuple: + num_outputs = _num_outputs(batched_outputs) + out_dims_as_tuple = _as_tuple( + out_dims, + num_outputs, + lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must " + f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.", + ) + + # NOTE [Ignored _remove_batch_dim, _add_batch_dim] + # There is something wrong with our type bindings for functions that begin + # with '_', see #40397. 
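+    # (hence the ``type: ignore`` annotations on the ``torch._remove_batch_dim``
+    # calls below)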
+ if isinstance(batched_outputs, Tensor): + out_dim = out_dims_as_tuple[0] + return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore[return-value] + if allow_none_pass_through: + return tuple( + ( + torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) + if out is not None + else None + ) + for out, out_dim in zip(batched_outputs, out_dims_as_tuple) + ) + else: + return tuple( + torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) + for out, out_dim in zip(batched_outputs, out_dims_as_tuple) + ) + + +# Checks that `fn` returned one or more Tensors and nothing else. +# NB: A python function that return multiple arguments returns a single tuple, +# so we are effectively checking that `outputs` is a single Tensor or a tuple of +# Tensors. +def _validate_outputs(outputs: Any, func: Callable) -> None: + if isinstance(outputs, Tensor): + return + if not isinstance(outputs, tuple): + raise ValueError( + f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return " + f"Tensors, got type {type(outputs)} as the return." + ) + for idx, output in enumerate(outputs): + if isinstance(output, Tensor): + continue + raise ValueError( + f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return " + f"Tensors, got type {type(output)} for return {idx}." + ) + + +def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None: + if isinstance(out_dims, int): + return + if not isinstance(out_dims, tuple) or not all( + isinstance(out_dim, int) for out_dim in out_dims + ): + raise ValueError( + f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be " + f"an int or a tuple of int representing where in the outputs the " + f"vmapped dimension should appear." + ) + + +def _get_name(func: Callable): + if hasattr(func, "__name__"): + return func.__name__ + + # Not all callables have __name__, in fact, only static functions/methods do. + # A callable created via functools.partial or an nn.Module, to name some + # examples, don't have a __name__. + return repr(func) + + +# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, +# sends those into func, and then unwraps the output BatchedTensors. Operations +# on BatchedTensors perform the batched operations that the user is asking for. +@deprecated( + "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.", + category=FutureWarning, +) +def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable: + """ + Please use torch.vmap instead of this API. + """ + return _vmap(func, in_dims, out_dims) + + +# A version of vmap but without the initial "experimental prototype" warning +def _vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + allow_none_pass_through: bool = False, +) -> Callable: + # The `allow_none_pass_through` argument is a temporary workaround may be removed. + # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine, + # which may return None if any of the inputs are unused. See the issue discussing this: + # https://github.com/pytorch/functorch/issues/159. 
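+    # Concretely: with ``allow_none_pass_through=True`` the Tensor-only check in
+    # ``_validate_outputs`` is skipped and ``_unwrap_batched`` forwards any ``None``
+    # outputs unchanged instead of raising.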
+ @functools.wraps(func) + def wrapped(*args): + _check_out_dims_is_int_or_int_tuple(out_dims, func) + vmap_level = torch._C._vmapmode_increment_nesting() + try: + batched_inputs, batch_size = _create_batched_inputs( + in_dims, args, vmap_level, func + ) + batched_outputs = func(*batched_inputs) + if not allow_none_pass_through: + _validate_outputs(batched_outputs, func) + return _unwrap_batched( + batched_outputs, + out_dims, + vmap_level, + batch_size, + func, + allow_none_pass_through=allow_none_pass_through, + ) + finally: + torch._C._vmapmode_decrement_nesting() + + return wrapped diff --git a/phivenv/Lib/site-packages/torch/_weights_only_unpickler.py b/phivenv/Lib/site-packages/torch/_weights_only_unpickler.py new file mode 100644 index 0000000000000000000000000000000000000000..7949a54f6a3e953e0ba8fa750108b585589f45b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_weights_only_unpickler.py @@ -0,0 +1,573 @@ +# mypy: allow-untyped-defs +# Unpickler restricted to loading only state dicts +# Restrict constructing types to a list defined in _get_allowed_globals() +# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only +# Restrict APPEND/APPENDS to `list` +# In `GLOBALS` operation do not do class lookup by name, but rather rely on dictionary +# defined by `_get_allowed_globals()` method, that contains: +# - torch types (Storage, dtypes, Tensor, `torch.Size`), +# - `torch._utils._rebuild` functions. +# - `torch.nn.Parameter` +# - `collections.Counter` +# - `collections.OrderedDict` +# Additionally, users can use an allowlist for adding classes they have deemed as safe using +# `_add_safe_globals()` (`torch.serialization.add_safe_globals`) +# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`) +# `_get_safe_globals()` (`torch.serialization.get_safe_globals`) + +# Based of https://github.com/python/cpython/blob/main/Lib/pickle.py +# Expected to be useful for loading PyTorch model weights +# For example: +# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read() +# buf = io.BytesIO(data) +# weights = torch.load(buf, weights_only = True) + +import functools as _functools +import warnings + +from _codecs import encode +from collections import Counter, OrderedDict +from pickle import ( + APPEND, + APPENDS, + BINFLOAT, + BINGET, + BININT, + BININT1, + BININT2, + BINPERSID, + BINPUT, + BINUNICODE, + BUILD, + bytes_types, + decode_long, + EMPTY_DICT, + EMPTY_LIST, + EMPTY_SET, + EMPTY_TUPLE, + GLOBAL, + LONG1, + LONG_BINGET, + LONG_BINPUT, + MARK, + NEWFALSE, + NEWOBJ, + NEWTRUE, + NONE, + PROTO, + REDUCE, + SETITEM, + SETITEMS, + SHORT_BINSTRING, + STOP, + TUPLE, + TUPLE1, + TUPLE2, + TUPLE3, + UnpicklingError, +) +from struct import unpack +from sys import maxsize +from typing import Any, Callable, Union + +import torch +from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING + + +# modules in this list are never allowed, even if the user attempts to allowlist +# functions/classes from them +_blocklisted_modules = [ + "sys", + "os", + "posix", + "nt", +] + +_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set() + + +def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]): + global _marked_safe_globals_set + _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals)) + + +def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: + global _marked_safe_globals_set + return list(_marked_safe_globals_set) + + 
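+# Illustrative flow (the class name and file path below are hypothetical): a
+# caller allowlists a class it trusts via the public wrapper, after which a
+# checkpoint containing pickled instances of it can be loaded with
+# ``weights_only=True``:
+#
+#     import torch
+#
+#     class MyMetadata:  # hypothetical user-defined class
+#         pass
+#
+#     torch.serialization.add_safe_globals([MyMetadata])
+#     obj = torch.load("checkpoint.pt", weights_only=True)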
+def _clear_safe_globals(): + global _marked_safe_globals_set + _marked_safe_globals_set = set() + + +def _remove_safe_globals( + globals_to_remove: list[Union[Callable, tuple[Callable, str]]], +): + global _marked_safe_globals_set + _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove) + + +class _safe_globals: + def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]): + self.safe_globals = safe_globals + + def __enter__(self): + _add_safe_globals(self.safe_globals) + + def __exit__(self, type, value, tb): + _remove_safe_globals(self.safe_globals) + + +# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals +# For example if user had a script like +# torch.load(file_a) +# torch.serialization._add_safe_globals([torch.foo]) +# torch.load(file_b) +# the dynamic additions to safe_globals would not be picked up by +# _get_allowed_globals due to the lru_cache +def _get_user_allowed_globals(): + rc: dict[str, Any] = {} + for f in _marked_safe_globals_set: + if isinstance(f, tuple): + if len(f) != 2: + raise ValueError( + f"Expected tuple of length 2 (global, str of callable full path), but got tuple of length: {len(f)}" + ) + if type(f[1]) is not str: + raise TypeError( + f"Expected second item in tuple to be str of callable full path, but got: {type(f[1])}" + ) + f, name = f + rc[name] = f + else: + module, name = f.__module__, f.__qualname__ + rc[f"{module}.{name}"] = f + return rc + + +def _tensor_rebuild_functions(): + return { + torch._utils._rebuild_parameter, + torch._utils._rebuild_parameter_with_state, + torch._utils._rebuild_qtensor, + torch._utils._rebuild_tensor, + torch._utils._rebuild_tensor_v2, + torch._utils._rebuild_tensor_v3, + torch._utils._rebuild_sparse_tensor, + torch._utils._rebuild_meta_tensor_no_storage, + torch._utils._rebuild_nested_tensor, + torch._utils._rebuild_wrapper_subclass, + # Allowlisting this, but not allowlisting the numpy functions by default + # Reasoning is that we don't have control over the numpy functions, but + # this utility is provided by pytorch + torch._utils._rebuild_device_tensor_from_numpy, + # In 2.6, we should no longer have a dependency on numpy and the above + # _rebuild_device_tensor_from_numpy function. 
+ torch._utils._rebuild_device_tensor_from_cpu_tensor, + } + + +# Unpickling machinery +@_functools.lru_cache(maxsize=1) +def _get_allowed_globals(): + rc: dict[str, Any] = { + "collections.OrderedDict": OrderedDict, + "collections.Counter": Counter, + "torch.nn.parameter.Parameter": torch.nn.Parameter, + "torch.serialization._get_layout": torch.serialization._get_layout, + "torch.Size": torch.Size, + "torch.Tensor": torch.Tensor, + "torch.device": torch.device, + "_codecs.encode": encode, # for bytes + "builtins.bytearray": bytearray, # for bytearray + "builtins.set": set, # for set + "builtins.complex": complex, # for complex + } + + # dtype + for t in torch.storage._dtype_to_storage_type_map().keys(): + rc[str(t)] = t + for t in torch.storage._new_dtypes(): + rc[str(t)] = t + for t in [getattr(torch, f"uint{x}") for x in range(1, 8)]: + rc[str(t)] = t + for t in [getattr(torch, f"int{x}") for x in range(1, 8)]: + rc[str(t)] = t + + # Tensor classes + for tt in torch._tensor_classes: + rc[f"{tt.__module__}.{tt.__name__}"] = tt + # Storage classes + for ts in torch._storage_classes: + if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage): + # Wrap legacy storage types in a dummy class + rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType( + ts.__name__ + ) + else: + rc[f"{ts.__module__}.{ts.__name__}"] = ts + # Quantization specific + for qt in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + ]: + rc[str(qt)] = qt + # Rebuild functions + for f in _tensor_rebuild_functions(): + rc[f"torch._utils.{f.__name__}"] = f + + # Handles Tensor Subclasses, Tensor's with attributes. + # NOTE: It calls into above rebuild functions for regular Tensor types. 
+ rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2 + return rc + + +def _read_global_instruction(readline: Callable) -> tuple[str, str]: + module = readline()[:-1].decode("utf-8") + name = readline()[:-1].decode("utf-8") + # Patch since torch.save default protocol is 2 + # users will be running this code in python > 3 + if (module, name) in NAME_MAPPING: + module, name = NAME_MAPPING[(module, name)] + elif module in IMPORT_MAPPING: + module = IMPORT_MAPPING[module] + return module, name + + +def get_globals_in_pkl(file) -> set[str]: + globals_in_checkpoint = set() + read = file.read + readline = file.readline + op_to_bytes_to_read = { + NEWOBJ[0]: 0, + REDUCE[0]: 0, + BUILD[0]: 0, + APPEND[0]: 0, + APPENDS[0]: 0, + SETITEM[0]: 0, + SETITEMS[0]: 0, + MARK[0]: 0, + TUPLE[0]: 0, + TUPLE1[0]: 0, + TUPLE2[0]: 0, + TUPLE3[0]: 0, + NONE[0]: 0, + NEWFALSE[0]: 0, + NEWTRUE[0]: 0, + EMPTY_TUPLE[0]: 0, + EMPTY_LIST[0]: 0, + EMPTY_DICT[0]: 0, + EMPTY_SET[0]: 0, + BINPERSID[0]: 0, + BININT[0]: 4, + BININT1[0]: 1, + BININT2[0]: 2, + BINFLOAT[0]: 8, + BINGET[0]: 1, + LONG_BINGET[0]: 4, + BINPUT[0]: 1, + LONG_BINPUT[0]: 4, + } + while True: + key = read(1) + if not key: + raise EOFError + assert isinstance(key, bytes_types) + if key[0] == GLOBAL[0]: + module, name = _read_global_instruction(readline) + globals_in_checkpoint.add(f"{module}.{name}") + elif key[0] in op_to_bytes_to_read: + bytes_to_read = op_to_bytes_to_read[key[0]] + if bytes_to_read: + read(bytes_to_read) + # ops where bytes to read depends on the data + elif key[0] == BINUNICODE[0]: + strlen = unpack(" maxsize: + raise UnpicklingError("String is too long") + read(strlen) + elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}: + strlen = read(1)[0] + read(strlen) + # first and last op + elif key[0] == PROTO[0]: + read(1)[0] + elif key[0] == STOP[0]: + return globals_in_checkpoint + else: + raise UnpicklingError(f"Unsupported operand {key[0]}") + + +class Unpickler: + def __init__(self, file, *, encoding: str = "bytes"): + self.encoding = encoding + self.readline = file.readline + self.read = file.read + self.memo: dict[int, Any] = {} + self.proto: int = -1 + + def load(self): + """Read a pickled object representation from the open file. + + Return the reconstituted object hierarchy specified in the file. + """ + self.metastack = [] + self.stack: list[Any] = [] + self.append = self.stack.append + read = self.read + while True: + key = read(1) + if not key: + raise EOFError + assert isinstance(key, bytes_types) + # Risky operators + if key[0] == GLOBAL[0]: + module, name = _read_global_instruction(self.readline) + full_path = f"{module}.{name}" + if module in _blocklisted_modules: + raise UnpicklingError( + f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked." 
+ ) + if full_path in _get_allowed_globals(): + self.append(_get_allowed_globals()[full_path]) + elif full_path in _get_user_allowed_globals(): + self.append(_get_user_allowed_globals()[full_path]) + elif full_path in ( + [ + "torch.nested._internal.nested_tensor.NestedTensor", + "torch.nested._internal.nested_tensor._rebuild_njt", + "torch._dynamo.decorators._DimRange", + ] + ): + raise UnpicklingError( + "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)" + ) + elif full_path in ( + [ + "torch.distributed.device_mesh.DeviceMesh", + "torch.distributed.tensor._dtensor_spec.DTensorSpec", + "torch.distributed.tensor._dtensor_spec.TensorMeta", + "torch.distributed.tensor.DTensor", + "torch.distributed.tensor.placement_types.Partial", + "torch.distributed.tensor.placement_types.Replicate", + "torch.distributed.tensor.placement_types.Shard", + ] + ): + raise UnpicklingError( + "``torch.distributed.tensor`` must be imported to load DTensors" + ) + else: + builtins_name = "builtins" + if ( + builtins_name in full_path + and builtins_name == full_path[: len(builtins_name)] + ): + full_path = full_path[len(builtins_name) :] + full_path = ( + full_path[1:] + if len(full_path) > 0 and full_path[0] == "." + else builtins_name + full_path + ) + raise UnpicklingError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the " + f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global " + "if you trust this class/function." + ) + elif key[0] == NEWOBJ[0]: + args = self.stack.pop() + cls = self.stack.pop() + if cls is torch.nn.Parameter: + self.append(torch.nn.Parameter(*args)) + elif ( + cls in _get_user_allowed_globals().values() + or cls in _get_allowed_globals().values() + ): + result = cls.__new__(cls, *args) + if cls in torch._tensor_classes and "sparse" in cls.__module__: + _sparse_tensors_to_validate.append(result) + self.append(result) + else: + raise UnpicklingError( + "Can only create new object for nn.Parameter or classes allowlisted " + f"via `add_safe_globals` but got {cls}" + ) + elif key[0] == REDUCE[0]: + args = self.stack.pop() + func = self.stack[-1] + if ( + func not in _get_allowed_globals().values() + and func not in _get_user_allowed_globals().values() + ): + raise UnpicklingError( + f"Trying to call reduce for unrecognized function {func}" + ) + result = func(*args) + if func in torch._tensor_classes and "sparse" in func.__module__: + _sparse_tensors_to_validate.append(result) + self.stack[-1] = result + elif key[0] == BUILD[0]: + state = self.stack.pop() + inst = self.stack[-1] + if type(inst) is torch.Tensor: + # Legacy unpickling + inst.set_(*state) + elif type(inst) is torch.nn.Parameter: + inst.__setstate__(state) + elif type(inst) is OrderedDict: + inst.__dict__.update(state) + elif ( + type(inst) in _get_user_allowed_globals().values() + or type(inst) in _get_allowed_globals().values() + ): + if hasattr(inst, "__setstate__"): + inst.__setstate__(state) + else: + # mimics load_build in pickle + # https://github.com/python/cpython/blob/f0c6fccd08904787a39269367f09f263d496114c/Lib/pickle.py#L1854-L1867 + slotstate = None + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + if state: + inst.__dict__.update(state) + if slotstate: + for k, v in slotstate.items(): + setattr(inst, k, v) + else: + raise UnpicklingError( + "Can only build Tensor, Parameter, OrderedDict or types 
allowlisted " + f"via `add_safe_globals`, but got {type(inst)}" + ) + # Stack manipulation + elif key[0] == APPEND[0]: + item = self.stack.pop() + list_obj = self.stack[-1] + if type(list_obj) is not list: + raise UnpicklingError( + f"Can only append to lists, but got {type(list_obj)}" + ) + list_obj.append(item) + elif key[0] == APPENDS[0]: + items = self.pop_mark() + list_obj = self.stack[-1] + if type(list_obj) is not list: + raise UnpicklingError( + f"Can only extend lists, but got {type(list_obj)}" + ) + list_obj.extend(items) + elif key[0] == SETITEM[0]: + (v, k) = (self.stack.pop(), self.stack.pop()) + self.stack[-1][k] = v + elif key[0] == SETITEMS[0]: + items = self.pop_mark() + for i in range(0, len(items), 2): + self.stack[-1][items[i]] = items[i + 1] + elif key[0] == MARK[0]: + self.metastack.append(self.stack) + self.stack = [] + self.append = self.stack.append + elif key[0] == TUPLE[0]: + items = self.pop_mark() + self.append(tuple(items)) + elif key[0] == TUPLE1[0]: + self.stack[-1] = (self.stack[-1],) + elif key[0] == TUPLE2[0]: + self.stack[-2:] = [(self.stack[-2], self.stack[-1])] + elif key[0] == TUPLE3[0]: + self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])] + # Basic types construction + elif key[0] == NONE[0]: + self.append(None) + elif key[0] == NEWFALSE[0]: + self.append(False) + elif key[0] == NEWTRUE[0]: + self.append(True) + elif key[0] == EMPTY_TUPLE[0]: + self.append(()) + elif key[0] == EMPTY_LIST[0]: + self.append([]) + elif key[0] == EMPTY_DICT[0]: + self.append({}) + elif key[0] == EMPTY_SET[0]: + self.append(set()) + elif key[0] == BININT[0]: + self.append(unpack("d", self.read(8))[0]) + elif key[0] == BINUNICODE[0]: + strlen = unpack(" maxsize: + raise UnpicklingError("String is too long") + strval = str(read(strlen), "utf-8", "surrogatepass") + self.append(strval) + elif key[0] == SHORT_BINSTRING[0]: + strlen = read(1)[0] + strdata = read(strlen) + if self.encoding != "bytes": + strdata = strdata.decode(self.encoding, "strict") + self.append(strdata) + elif key[0] == BINPERSID[0]: + pid = self.stack.pop() + # Only allow persistent load of storage + if type(pid) is not tuple and not type(pid) is not int: + raise UnpicklingError( + f"persistent_load id must be tuple or int, but got {type(pid)}" + ) + if ( + type(pid) is tuple + and len(pid) > 0 + and torch.serialization._maybe_decode_ascii(pid[0]) != "storage" + ): + raise UnpicklingError( + f"Only persistent_load of storage is allowed, but got {pid[0]}" + ) + self.append(self.persistent_load(pid)) + elif key[0] in [BINGET[0], LONG_BINGET[0]]: + idx = (read(1) if key[0] == BINGET[0] else unpack("` in python. +""" + +from typing import Optional +from typing_extensions import deprecated + +import torch + +from ._utils import _device_t, _get_device_index + + +__all__ = [ + "current_accelerator", + "current_device_idx", # deprecated + "current_device_index", + "current_stream", + "device_count", + "device_index", + "is_available", + "set_device_idx", # deprecated + "set_device_index", + "set_stream", + "synchronize", +] + + +def device_count() -> int: + r"""Return the number of current :ref:`accelerator` available. + + Returns: + int: the number of the current :ref:`accelerator` available. + If there is no available accelerators, return 0. + + .. note:: This API delegates to the device-specific version of `device_count`. + On CUDA, this API will NOT poison fork if NVML discovery succeeds. + Otherwise, it will. For more details, see :ref:`multiprocessing-poison-fork-note`. 
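+
+    Example (illustrative)::
+
+        >>> if torch.accelerator.device_count() == 0:
+        ...     print("no accelerator is visible; falling back to CPU")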
+ """ + acc = current_accelerator() + if acc is None: + return 0 + + mod = torch.get_device_module(acc) + return mod.device_count() + + +def is_available() -> bool: + r"""Check if the current accelerator is available at runtime: it was build, all the + required drivers are available and at least one device is visible. + See :ref:`accelerator` for details. + + Returns: + bool: A boolean indicating if there is an available :ref:`accelerator`. + + .. note:: This API delegates to the device-specific version of `is_available`. + On CUDA, when the environment variable ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` is set, + this function will NOT poison fork. Otherwise, it will. For more details, see + :ref:`multiprocessing-poison-fork-note`. + + Example:: + + >>> assert torch.accelerator.is_available() "No available accelerators detected." + """ + # Why not just check "device_count() > 0" like other is_available call? + # Because device like CUDA have a python implementation of is_available that is + # non-poisoning and some features like Dataloader rely on it. + # So we are careful to delegate to the Python version of the accelerator here + acc = current_accelerator() + if acc is None: + return False + + mod = torch.get_device_module(acc) + return mod.is_available() + + +def current_accelerator(check_available: bool = False) -> Optional[torch.device]: + r"""Return the device of the accelerator available at compilation time. + If no accelerator were available at compilation time, returns None. + See :ref:`accelerator` for details. + + Args: + check_available (bool, optional): if True, will also do a runtime check to see + if the device :func:`torch.accelerator.is_available` on top of the compile-time + check. + Default: ``False`` + + Returns: + torch.device: return the current accelerator as :class:`torch.device`. + + .. note:: The index of the returned :class:`torch.device` will be ``None``, please use + :func:`torch.accelerator.current_device_index` to know the current index being used. + This API does NOT poison fork. For more details, see :ref:`multiprocessing-poison-fork-note`. + + Example:: + + >>> # xdoctest: + >>> # If an accelerator is available, sent the model to it + >>> model = torch.nn.Linear(2, 2) + >>> if (current_device := current_accelerator(check_available=True)) is not None: + >>> model.to(current_device) + """ + if (acc := torch._C._accelerator_getAccelerator()) is not None: + if (not check_available) or (check_available and is_available()): + return acc + return None + + +def current_device_index() -> int: + r"""Return the index of a currently selected device for the current :ref:`accelerator`. + + Returns: + int: the index of a currently selected device. + """ + return torch._C._accelerator_getDeviceIndex() + + +current_device_idx = deprecated( + "Use `current_device_index` instead.", + category=FutureWarning, +)(current_device_index) + + +def set_device_index(device: _device_t, /) -> None: + r"""Set the current device index to a given device. + + Args: + device (:class:`torch.device`, str, int): a given device that must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if this device index is negative. 
+ """ + device_index = _get_device_index(device, optional=False) + torch._C._accelerator_setDeviceIndex(device_index) + + +set_device_idx = deprecated( + "Use `set_device_index` instead.", + category=FutureWarning, +)(set_device_index) + + +def current_stream(device: _device_t = None, /) -> torch.Stream: + r"""Return the currently selected stream for a given device. + + Args: + device (:class:`torch.device`, str, int, optional): a given device that must match the current + :ref:`accelerator` device type. If not given, + use :func:`torch.accelerator.current_device_index` by default. + + Returns: + torch.Stream: the currently selected stream for a given device. + """ + device_index = _get_device_index(device, optional=True) + return torch._C._accelerator_getStream(device_index) + + +def set_stream(stream: torch.Stream) -> None: + r"""Set the current stream to a given stream. + + Args: + stream (torch.Stream): a given stream that must match the current :ref:`accelerator` device type. + + .. note:: This function will set the current device index to the device index of the given stream. + """ + torch._C._accelerator_setStream(stream) + + +def synchronize(device: _device_t = None, /) -> None: + r"""Wait for all kernels in all streams on the given device to complete. + + Args: + device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match + the current :ref:`accelerator` device type. If not given, + use :func:`torch.accelerator.current_device_index` by default. + + .. note:: This function is a no-op if the current :ref:`accelerator` is not initialized. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> assert torch.accelerator.is_available() "No available accelerators detected." + >>> start_event = torch.Event(enable_timing=True) + >>> end_event = torch.Event(enable_timing=True) + >>> start_event.record() + >>> tensor = torch.randn(100, device=torch.accelerator.current_accelerator()) + >>> sum = torch.sum(tensor) + >>> end_event.record() + >>> torch.accelerator.synchronize() + >>> elapsed_time_ms = start_event.elapsed_time(end_event) + """ + device_index = _get_device_index(device, optional=True) + torch._C._accelerator_synchronizeDevice(device_index) + + +class device_index: + r"""Context manager to set the current device index for the current :ref:`accelerator`. + Temporarily changes the current device index to the specified value for the duration + of the context, and automatically restores the previous device index when exiting + the context. + + Args: + device (Optional[int]): a given device index to temporarily set. If None, + no device index switching occurs. + + Examples: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # Set device 0 as the current device temporarily + >>> with torch.accelerator.device_index(0): + ... # Code here runs with device 0 as the current device + ... pass + >>> # Original device is now restored + >>> # No-op when None is passed + >>> with torch.accelerator.device_index(None): + ... # No device switching occurs + ... 
pass + """ + + def __init__(self, device: Optional[int], /) -> None: + self.idx = device + self.prev_idx = -1 + + def __enter__(self) -> None: + if self.idx is not None: + self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx) + + def __exit__(self, *exc_info: object) -> None: + if self.idx is not None: + torch._C._accelerator_maybeExchangeDevice(self.prev_idx) diff --git a/phivenv/Lib/site-packages/torch/accelerator/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/accelerator/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c60e72917f83741a5a765aa08678ef08215e137b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/accelerator/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/accelerator/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/accelerator/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..904917870ee63ef368769c8924ad07b01ec3888c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/accelerator/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/accelerator/_utils.py b/phivenv/Lib/site-packages/torch/accelerator/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79df5b85fbbe0212977ea5548bd7fba6bab94580 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/accelerator/_utils.py @@ -0,0 +1,28 @@ +from typing import Optional + +import torch +from torch.types import Device as _device_t + + +def _get_device_index(device: _device_t, optional: bool = False) -> int: + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + device_index: Optional[int] = None + if isinstance(device, torch.device): + acc = torch.accelerator.current_accelerator() + if acc is None: + raise RuntimeError("Accelerator expected") + if acc.type != device.type: + raise ValueError( + f"{device.type} doesn't match the current accelerator {acc}." 
+ ) + device_index = device.index + if device_index is None: + if not optional: + raise ValueError( + f"Expected a torch.device with a specified index or an integer, but got:{device}" + ) + return torch.accelerator.current_device_index() + return device_index diff --git a/phivenv/Lib/site-packages/torch/amp/__init__.py b/phivenv/Lib/site-packages/torch/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83fd1cd1484034e19b407f6f665a6132b562dde5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/amp/__init__.py @@ -0,0 +1,9 @@ +from .autocast_mode import ( + _enter_autocast, + _exit_autocast, + autocast, + custom_bwd, + custom_fwd, + is_autocast_available, +) +from .grad_scaler import GradScaler diff --git a/phivenv/Lib/site-packages/torch/amp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/amp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e386956737d749c7ead3b8b68b73c222cae81072 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/amp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/amp/__pycache__/autocast_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/amp/__pycache__/autocast_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a43790e162d3aa40f0a4f4132356bf5bc4e9661f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/amp/__pycache__/autocast_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/amp/__pycache__/grad_scaler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/amp/__pycache__/grad_scaler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf854a42f0ff439777143601ab7961b765375a7f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/amp/__pycache__/grad_scaler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/amp/autocast_mode.py b/phivenv/Lib/site-packages/torch/amp/autocast_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..8c54c41079f888242c9e7cb8e5ad224103be8858 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/amp/autocast_mode.py @@ -0,0 +1,565 @@ +# mypy: allow-untyped-defs +import collections +import functools +import warnings +from typing import Any, Optional + +import torch +from torch.types import _dtype + + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + +__all__ = [ + "autocast_decorator", + "autocast", + "is_autocast_available", + "custom_fwd", + "custom_bwd", +] + + +def is_autocast_available(device_type: str) -> bool: + r""" + Return a bool indicating if autocast is available on :attr:`device_type`. + + Args: + device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and so on. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 
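+
+    Example (illustrative)::
+
+        >>> torch.amp.is_autocast_available("cpu")
+        True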
+ """ + return torch._C._is_autocast_available(device_type) + + +def autocast_decorator(autocast_instance, func): + @functools.wraps(func) + def decorate_autocast(*args, **kwargs): + with autocast_instance: + return func(*args, **kwargs) + + decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined] + return decorate_autocast + + +class autocast: + r""" + Instances of :class:`autocast` serve as context managers or decorators that + allow regions of your script to run in mixed precision. + + In these regions, ops run in an op-specific dtype chosen by autocast + to improve performance while maintaining accuracy. + See the :ref:`Autocast Op Reference` for details. + + When entering an autocast-enabled region, Tensors may be any type. + You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting. + + :class:`autocast` should wrap only the forward pass(es) of your network, including the loss + computation(s). Backward passes under autocast are not recommended. + Backward ops run in the same type that autocast used for corresponding forward ops. + + Example for CUDA Devices:: + + # Creates model and optimizer in default precision + model = Net().cuda() + optimizer = optim.SGD(model.parameters(), ...) + + for input, target in data: + optimizer.zero_grad() + + # Enables autocasting for the forward pass (model + loss) + with torch.autocast(device_type="cuda"): + output = model(input) + loss = loss_fn(output, target) + + # Exits the context manager before backward() + loss.backward() + optimizer.step() + + See the :ref:`Automatic Mixed Precision examples` for usage (along with gradient scaling) + in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions). + + :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model:: + + class AutocastModel(nn.Module): + ... + @torch.autocast(device_type="cuda") + def forward(self, input): + ... + + Floating-point Tensors produced in an autocast-enabled region may be ``float16``. + After returning to an autocast-disabled region, using them with floating-point + Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) + produced in the autocast region back to ``float32`` (or other dtype if desired). + If a Tensor from the autocast region is already ``float32``, the cast is a no-op, + and incurs no additional overhead. + CUDA Example:: + + # Creates some tensors in default dtype (here assumed to be float32) + a_float32 = torch.rand((8, 8), device="cuda") + b_float32 = torch.rand((8, 8), device="cuda") + c_float32 = torch.rand((8, 8), device="cuda") + d_float32 = torch.rand((8, 8), device="cuda") + + with torch.autocast(device_type="cuda"): + # torch.mm is on autocast's list of ops that should run in float16. + # Inputs are float32, but the op runs in float16 and produces float16 output. + # No manual casts are required. + e_float16 = torch.mm(a_float32, b_float32) + # Also handles mixed input types + f_float16 = torch.mm(d_float32, e_float16) + + # After exiting autocast, calls f_float16.float() to use with d_float32 + g_float32 = torch.mm(d_float32, f_float16.float()) + + CPU Training Example:: + + # Creates model and optimizer in default precision + model = Net() + optimizer = optim.SGD(model.parameters(), ...) + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + + # Runs the forward pass with autocasting. 
+ with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + output = model(input) + loss = loss_fn(output, target) + + loss.backward() + optimizer.step() + + + CPU Inference Example:: + + # Creates model in default precision + model = Net().eval() + + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + for input in data: + # Runs the forward pass with autocasting. + output = model(input) + + CPU Inference Example with Jit Trace:: + + class TestModel(nn.Module): + def __init__(self, input_size, num_classes): + super().__init__() + self.fc1 = nn.Linear(input_size, num_classes) + def forward(self, x): + return self.fc1(x) + + input_size = 2 + num_classes = 2 + model = TestModel(input_size, num_classes).eval() + + # For now, we suggest to disable the Jit Autocast Pass, + # As the issue: https://github.com/pytorch/pytorch/issues/75956 + torch._C._jit_set_autocast_mode(False) + + with torch.cpu.amp.autocast(cache_enabled=False): + model = torch.jit.trace(model, torch.randn(1, input_size)) + model = torch.jit.freeze(model) + # Models Run + for _ in range(3): + model(torch.randn(1, input_size)) + + Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe, + please file an issue. + + ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions. + Locally disabling autocast can be useful, for example, if you want to force a subregion + to run in a particular ``dtype``. Disabling autocast gives you explicit control over + the execution type. In the subregion, inputs from the surrounding region + should be cast to ``dtype`` before use:: + + # Creates some tensors in default dtype (here assumed to be float32) + a_float32 = torch.rand((8, 8), device="cuda") + b_float32 = torch.rand((8, 8), device="cuda") + c_float32 = torch.rand((8, 8), device="cuda") + d_float32 = torch.rand((8, 8), device="cuda") + + with torch.autocast(device_type="cuda"): + e_float16 = torch.mm(a_float32, b_float32) + with torch.autocast(device_type="cuda", enabled=False): + # Calls e_float16.float() to ensure float32 execution + # (necessary because e_float16 was created in an autocasted region) + f_float32 = torch.mm(c_float32, e_float16.float()) + + # No manual casts are required when re-entering the autocast-enabled region. + # torch.mm again runs in float16 and produces float16 output, regardless of input types. + g_float16 = torch.mm(d_float32, f_float32) + + The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator + must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and + :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process + (see :ref:`Working with Multiple GPUs`). + + Args: + device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + enabled(bool, optional): Whether autocasting should be enabled in the region. + Default: ``True`` + dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value + (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by + :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``. + Default: ``None`` + cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled. 
+ Default: ``True`` + """ + + def __init__( + self, + device_type: str, + dtype: Optional[_dtype] = None, + enabled: bool = True, + cache_enabled: Optional[bool] = None, + ): + if not isinstance(device_type, str): + raise ValueError( + f"Expected `device_type` of type `str`, got: `{type(device_type)}`" + ) + if dtype is None: + dtype = torch.get_autocast_dtype(device_type) + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = device_type + self.fast_dtype = dtype + assert dtype is not None + return + self.device = device_type + if not is_autocast_available(self.device): + raise RuntimeError( + f"User specified an unsupported autocast device_type '{self.device}'" + ) + self.custom_backend_name = torch._C._get_privateuse1_backend_name() + self.fast_dtype = torch.get_autocast_dtype(self.device) + if self.device == self.custom_backend_name: + necessary_funcs = [ + "get_amp_supported_dtype", + ] + message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not " + message += "registered a module or the module miss some necessary funcs. The backend should register " + message += "a module by `torch._register_device_module`, and the module must have these funcs: \n" + message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n" + + assert hasattr(torch, self.custom_backend_name), message + self.custom_device_mod = getattr(torch, self.custom_backend_name) + for func in necessary_funcs: + assert hasattr(self.custom_device_mod, func), ( + message + f"But the func `{func}` is missing. \n" + ) + + self._cache_enabled = torch.is_autocast_cache_enabled() + if ( + enabled + and self.device == "cuda" + and torch.cuda.amp.common.amp_definitely_not_available() + ): + warnings.warn( + "User provided device_type of 'cuda', but CUDA is not available. Disabling" + ) + enabled = False + if dtype is not None: + self.fast_dtype = dtype + if cache_enabled is not None: + self._cache_enabled = cache_enabled + + if self.device == "cpu": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype and enabled: + error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "CPU Autocast only supports dtype of " + error_message += ( + ", ".join(str(dtype) for dtype in supported_dtype) + " currently." + ) + warnings.warn(error_message) + enabled = False + elif self.device == "mtia": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == "maia": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == "xpu": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." 
+ warnings.warn(error_message) + enabled = False + elif self.device == "ipu": + supported_dtypes = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtypes: + error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == "hpu": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == self.custom_backend_name: + supported_dtype = self.custom_device_mod.get_amp_supported_dtype() + if self.fast_dtype not in supported_dtype: + error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. " + error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of " + error_message += ( + ", ".join(str(dtype) for dtype in supported_dtype) + " currently." + ) + warnings.warn(error_message) + enabled = False + elif self.device == "cuda": + if ( + enabled + and self.fast_dtype == torch.bfloat16 + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError( + "Current CUDA Device does not support bfloat16. Please switch dtype to float16." + ) + elif self.device == "mps": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = ( + "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n" + "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently." + ) + warnings.warn(error_message) + enabled = False + elif self.fast_dtype == torch.bfloat16: + if not torch.backends.mps.is_macos_or_newer(14, 0): + error_message = ( + "In MPS autocast, but the target dtype torch.bfloat16 is not supported " + "on macOS versions below 14. Disabling autocast." + ) + warnings.warn(error_message) + enabled = False + elif self.device == "xla": + supported_dtype = [torch.float16, torch.bfloat16] + if self.fast_dtype not in supported_dtype: + error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "XLA Autocast only supports dtype of torch.bfloat16 currently." + ) + warnings.warn(error_message) + enabled = False + self._enabled = enabled + + def __enter__(self): + if torch._jit_internal.is_scripting(): + assert self.fast_dtype is not None + return self + + self.prev_cache_enabled = torch.is_autocast_cache_enabled() + self.prev = torch.is_autocast_enabled(self.device) + self.prev_fastdtype = torch.get_autocast_dtype(self.device) + torch.set_autocast_enabled(self.device, self._enabled) + torch.set_autocast_dtype(self.device, self.fast_dtype) # type: ignore[arg-type] + torch.autocast_increment_nesting() + torch.set_autocast_cache_enabled(self._cache_enabled) + + # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this + # API to other functional modes. We only expose to PreDispatchTorchFunctionMode + # for preserving autocast in torch.export.export. 
+ if torch._C._is_torch_function_mode_enabled(): + stacks = torch.overrides._get_current_function_mode_stack() + for mode in stacks: + if isinstance( + mode, + torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, + ): + args = ( + self.device, + self.fast_dtype, + self._enabled, + self._cache_enabled, + ) + return mode.__torch_function__(torch.amp._enter_autocast, (), args) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if torch._jit_internal.is_scripting(): + return + + # Drop the cache when we exit to a nesting level that's outside any instance of autocast. + if torch.autocast_decrement_nesting() == 0: + torch.clear_autocast_cache() + torch.set_autocast_enabled(self.device, self.prev) + torch.set_autocast_dtype(self.device, self.prev_fastdtype) + torch.set_autocast_cache_enabled(self.prev_cache_enabled) + + # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this + # API to other functional modes. We only expose to PreDispatchTorchFunctionMode + # for preserving autocast in torch.export.export. + if torch._C._is_torch_function_mode_enabled(): + stacks = torch.overrides._get_current_function_mode_stack() + for mode in stacks: + if isinstance( + mode, + torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, + ): + return mode.__torch_function__(torch.amp._exit_autocast, (), ()) + return False + + def __call__(self, func): + if torch._jit_internal.is_scripting(): + return func + return autocast_decorator(self, func) + + +# These functions aren't meant for public usage. +# They are what we trace into a graph during pre_dispatch tracing +# when we encounter an autocast context manager. +def _enter_autocast(*vals): + # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph. + if torch._C._is_torch_function_mode_enabled(): + return torch.overrides.handle_torch_function( + torch.amp._enter_autocast, [], *vals + ) + mode = torch.amp.autocast(*vals) + mode.__enter__() + return mode + + +def _exit_autocast(mode): + if torch._C._is_torch_function_mode_enabled(): + return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode) + mode.__exit__(None, None, None) + + +# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which +# may be falsely detected as "Iterables." +def _cast(value, device_type: str, dtype: _dtype): + if isinstance(value, torch.Tensor): + is_eligible = ( + value.is_floating_point() + and value.device.type == device_type + and (value.dtype is not torch.float64) + ) + return value.to(dtype) if is_eligible else value + elif isinstance(value, (str, bytes)): + return value + elif HAS_NUMPY and isinstance(value, np.ndarray): + return value + elif isinstance(value, collections.abc.Mapping): + return { + _cast(k, device_type, dtype): _cast(v, device_type, dtype) + for k, v in value.items() + } + elif isinstance(value, collections.abc.Iterable): + iterable = (_cast(v, device_type, dtype) for v in value) + if isinstance(value, (list, tuple)): + return type(value)(iterable) + else: + return iterable + else: + return value + + +def custom_fwd( + fwd=None, + *, + device_type: str, + cast_inputs: Optional[_dtype] = None, +): + """ + Create a helper decorator for ``forward`` methods of custom autograd functions. + + Autograd functions are subclasses of :class:`torch.autograd.Function`. + See the :ref:`example page` for more detail. + + Args: + device_type(str): Device type to use. 
'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, + when ``forward`` runs in an autocast-enabled region, casts incoming + floating-point Tensors to the target dtype (non-floating-point Tensors are not affected), + then executes ``forward`` with autocast disabled. + If ``None``, ``forward``'s internal ops execute with the current autocast state. + + .. note:: + If the decorated ``forward`` is called outside an autocast-enabled region, + :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect. + """ + if not isinstance(device_type, str): + raise ValueError( + f"Expected `device_type` of type `str`, got: `{type(device_type)}`" + ) + if fwd is None: + return functools.partial( + custom_fwd, device_type=device_type, cast_inputs=cast_inputs + ) + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + args[0]._dtype = torch.get_autocast_dtype(device_type) + if cast_inputs is None: + args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type) + return fwd(*args, **kwargs) + else: + autocast_context = torch.is_autocast_enabled(device_type) + args[0]._fwd_used_autocast = False + if autocast_context: + with autocast(device_type=device_type, enabled=False): + return fwd( + *_cast(args, device_type, cast_inputs), + **_cast(kwargs, device_type, cast_inputs), + ) + else: + return fwd(*args, **kwargs) + + return decorate_fwd + + +# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate +# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match +# cast_inputs supplied to custom_fwd. +def custom_bwd(bwd=None, *, device_type: str): + """Create a helper decorator for backward methods of custom autograd functions. + + Autograd functions are subclasses of :class:`torch.autograd.Function`. + Ensures that ``backward`` executes with the same autocast state as ``forward``. + See the :ref:`example page` for more detail. + + Args: + device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 
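# A minimal sketch of the custom_fwd/custom_bwd pair described above, applied to
# a hypothetical autograd Function; the matmul body is illustrative only and
# device_type="cuda" is an example value.
import torch

class ScaledMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # inputs are cast to float32 and autocast is disabled inside forward
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # backward runs with the same autocast state that forward observed
        a, b = ctx.saved_tensors
        return grad_output.mm(b.t()), a.t().mm(grad_output)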
+ """ + + if not isinstance(device_type, str): + raise ValueError( + f"Expected `device_type` of type `str`, got: `{type(device_type)}`" + ) + if bwd is None: + return functools.partial(custom_bwd, device_type=device_type) + + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with autocast( + device_type=device_type, + enabled=args[0]._fwd_used_autocast, + dtype=args[0]._dtype, + ): + return bwd(*args, **kwargs) + + return decorate_bwd diff --git a/phivenv/Lib/site-packages/torch/amp/grad_scaler.py b/phivenv/Lib/site-packages/torch/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..de8bcc8068d44dea9727a2e74e5a0e5768a25e66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/amp/grad_scaler.py @@ -0,0 +1,693 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import warnings +from collections import abc, defaultdict +from enum import Enum +from typing import Any, cast, Optional, overload, TYPE_CHECKING, Union + +import torch + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +__all__ = ["OptState", "GradScaler"] + + +class _MultiDeviceReplicator: + """Lazily serves copies of a tensor to requested devices. + + Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + self.master = master_tensor + self._per_device_tensors: dict[torch.device, torch.Tensor] = {} + + def get(self, device: torch.device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state() -> dict[str, Any]: + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler: + """An instance ``scaler`` of :class:`GradScaler`. + + Helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. 
To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + """ + + def __init__( + self, + device: str = "cuda", + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + self._device = device + self._enabled = enabled + if self._device == "cuda": + if enabled and torch.cuda.amp.common.amp_definitely_not_available(): + warnings.warn( + "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling." + ) + self._enabled = False + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." 
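# A minimal end-to-end sketch combining torch.autocast with GradScaler, as
# outlined in the docstring above; the tiny model, optimizer, and data below
# are placeholders and a CUDA device is assumed to be available.
import torch
from torch import nn

model = nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
data = [(torch.randn(4, 8, device="cuda"), torch.randn(4, 8, device="cuda"))]

scaler = torch.amp.GradScaler("cuda")
for input, target in data:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # backward() on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale for the next iteration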
+ + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale: Optional[torch.Tensor] = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker: Optional[torch.Tensor] = None + self._per_optimizer_states: dict[int, dict[str, Any]] = defaultdict( + _refresh_per_optimizer_state + ) + + def _check_scale_growth_tracker( + self, funcname: str + ) -> tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, ( + f"Attempted {funcname} but _scale is None. " + fix + ) + assert self._growth_tracker is not None, ( + f"Attempted {funcname} but _growth_tracker is None. " + fix + ) + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None: + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.full( + (), self._init_growth_tracker, dtype=torch.int32, device=dev + ) + + @overload + def scale(self, outputs: torch.Tensor) -> torch.Tensor: + ... + + @overload + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: + ... + + @overload + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: + ... + + @overload + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: + ... + + def scale( + self, + outputs: Union[torch.Tensor, Iterable[torch.Tensor]], + ) -> Union[torch.Tensor, Iterable[torch.Tensor]]: + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. 
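# A short sketch of the multi-output path noted in the comment above: scale()
# also accepts an iterable of tensors and preserves list/tuple container types.
# A CUDA device is assumed; the two scalar losses are placeholders.
import torch

scaler = torch.amp.GradScaler("cuda")
loss_a = torch.randn((), device="cuda", requires_grad=True)
loss_b = torch.randn((), device="cuda", requires_grad=True)
scaled_a, scaled_b = scaler.scale((loss_a, loss_b))  # tuple in, tuple out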
+ stash: list[ + _MultiDeviceReplicator + ] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): + if isinstance(val, torch.Tensor): + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + if isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterable) + return iterable + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_( + self, + optimizer: torch.optim.Optimizer, + inv_scale: torch.Tensor, + found_inf: torch.Tensor, + allow_fp16: bool, + ) -> dict[torch.device, torch.Tensor]: + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads: dict[ + torch.device, dict[torch.dtype, list[torch.Tensor]] + ] = defaultdict(lambda: defaultdict(list)) + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + assert isinstance(param, torch.Tensor) + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer: torch.optim.Optimizer) -> None: + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. 
note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = ( + self._scale.double().reciprocal().float() + if self._scale.device != torch.device("mps:0") + else self._scale.reciprocal() + ) + found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step( + self, + optimizer: torch.optim.Optimizer, + optimizer_state: dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> Optional[float]: + retval: Optional[float] = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step( + self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any + ) -> Optional[float]: + """Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN. + + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if not self._enabled: + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError( + "Closure use is not currently supported if GradScaler is enabled." + ) + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError( + "step() has already been called since the last update()." + ) + + retval: Optional[float] = None + + if getattr(optimizer, "_step_supports_amp_scaling", False): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. 
We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument + # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` + # and `found_inf` to the passed optimizer so that the optimizer can utilize those + # to skip the parameter updates or unscale gradients before updating parameters in + # the fused kernel, e.g. `FusedAdamMathFunctor`. + # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, + # while the method is expected to be called by users side, i.e. their optimizers. + kwargs_ = kwargs + has_grad_scaler_kwarg = ( + "grad_scaler" in inspect.signature(optimizer.step).parameters + ) + if has_grad_scaler_kwarg: + warnings.warn( + "GradScaler is going to stop passing itself as a keyword argument to the passed " + "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " + "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", + FutureWarning, + ) + kwargs_.update({"grad_scaler": self}) + else: + if optimizer_state["stage"] is OptState.READY: + self._check_inf_per_device(optimizer) + scaler = self._get_scale_async() + assert scaler is not None + found_inf = cast( + torch.Tensor, + sum( + [ # noqa: C419 + t.to(scaler.device, non_blocking=True) + for t in optimizer_state["found_inf_per_device"].values() + ] + ), + ) + # Take the product of the scales, if the user has already set `optimizer.grad_scale`. + optimizer.grad_scale = ( # type: ignore[attr-defined] + getattr(optimizer, "grad_scale", None) + if optimizer_state["stage"] == OptState.UNSCALED + else scaler * getattr(optimizer, "grad_scale", 1) + ) + optimizer.found_inf = found_inf # type: ignore[attr-defined] + retval = optimizer.step(*args, **kwargs_) + optimizer_state["stage"] = OptState.STEPPED + if not has_grad_scaler_kwarg: + del optimizer.grad_scale # type: ignore[attr-defined] + del optimizer.found_inf # type: ignore[attr-defined] + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert ( + len(optimizer_state["found_inf_per_device"]) > 0 + ), "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None: + """Update the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + + .. 
warning:: + For performance reasons, we do not check the scale factor value to avoid synchronizations, + so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or + you are seeing NaNs in your gradients or loss, something is likely wrong. For example, + bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + assert self._scale is not None + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ + torch.FloatTensor with requires_grad=False." + assert new_scale.device.type == self._device, reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + torch._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self) -> Optional[torch.Tensor]: + return self._scale + + def get_scale(self) -> float: + """Return a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return ( + self._init_scale + if (scale := self._get_scale_async()) is None + else cast(float, scale.item()) + ) + return 1.0 + + def get_growth_factor(self) -> float: + r"""Return a Python float containing the scale growth factor.""" + return self._growth_factor + + def set_growth_factor(self, new_factor: float) -> None: + r"""Set a new scale growth factor. + + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self) -> float: + r"""Return a Python float containing the scale backoff factor.""" + return self._backoff_factor + + def set_backoff_factor(self, new_factor: float) -> None: + r"""Set a new scale backoff factor. + + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self) -> int: + r"""Return a Python int containing the growth interval.""" + return self._growth_interval + + def set_growth_interval(self, new_interval: int) -> None: + r"""Set a new growth interval. + + Args: + new_interval (int): Value to use as the new growth interval. 
+ """ + self._growth_interval = new_interval + + def _get_growth_tracker(self) -> int: + if self._enabled: + return ( + self._init_growth_tracker + if self._growth_tracker is None + else cast(int, self._growth_tracker.item()) + ) + return 0 + + def is_enabled(self) -> bool: + r"""Return a bool indicating whether this instance is enabled.""" + return self._enabled + + def state_dict(self) -> dict[str, Any]: + r"""Return the state of the scaler as a :class:`dict`. + + It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + if self._enabled: + return { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + } + return {} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r"""Load the scaler state. + + If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) + + self._init_scale = cast(float, state_dict["scale"]) + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = cast(float, state_dict["growth_factor"]) + self._backoff_factor = cast(float, state_dict["backoff_factor"]) + self._growth_interval = cast(int, state_dict["growth_interval"]) + self._init_growth_tracker = cast(int, state_dict["_growth_tracker"]) + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. 
+ state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, Any]: + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)][ + "found_inf_per_device" + ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, Any]: + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/phivenv/Lib/site-packages/torch/ao/__init__.py b/phivenv/Lib/site-packages/torch/ao/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63e5f3868deb7e9596c81ae8e5fdd59853bba61a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/__init__.py @@ -0,0 +1,31 @@ +# torch.ao is a package with a lot of interdependencies. +# We will use lazy import to avoid cyclic dependencies here. + +from typing import TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from types import ModuleType + + from torch.ao import ( # noqa: TC004 + nn as nn, + ns as ns, + pruning as pruning, + quantization as quantization, + ) + + +__all__ = [ + "nn", + "ns", + "pruning", + "quantization", +] + + +def __getattr__(name: str) -> "ModuleType": + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/phivenv/Lib/site-packages/torch/ao/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..879aab4a9491878b8aff08479662acfa373d0753 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e312889ec39a9a0569acdb87266098b254f49dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/__init__.py @@ -0,0 +1,35 @@ +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ + +from typing import TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from types import ModuleType + + from torch.ao.nn import ( # noqa: TC004 + intrinsic as intrinsic, + qat as qat, + quantizable as quantizable, + quantized as quantized, + sparse as sparse, + ) + + +__all__ = [ + "intrinsic", + "qat", + "quantizable", + "quantized", + "sparse", +] + + +def __getattr__(name: str) -> "ModuleType": + if name in __all__: + import importlib + + return importlib.import_module("." 
+ name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/phivenv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4603c491e201f2ab560f0c4beee3c608848a72af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92a1f48add061fa2bd14c5dad2cdb9bbfd6ff5a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py @@ -0,0 +1,41 @@ +import types + +from .modules import * # noqa: F403 +from .modules.fused import _FusedModule # noqa: F403 + + +# # Subpackages +# from . import qat # noqa: F403 +# from . import quantized # noqa: F403 + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] + + +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ +def __getattr__(name: str) -> types.ModuleType: + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b990751a4f36f9e1f8b4ecb1284053c71dcfbd98 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..964b6eae13bad23cc73a2329fa91b8aed31d11ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py @@ -0,0 +1,41 @@ +from .fused import ( # noqa: F401 + _FusedModule, + BNReLU2d, + BNReLU3d, + ConvAdd2d, + ConvAddReLU2d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearLeakyReLU, + LinearReLU, + LinearTanh, +) + + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d631a3f181fbd7cf21d6bdf76f2a59d83978b43c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc differ diff 
--git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b16b1d8f05acec45181349883413c7c17474e46 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..2c47ff93825eec75c8896470de678654ad3e354a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py @@ -0,0 +1,287 @@ +# mypy: allow-untyped-defs +import torch +from torch.nn import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + Conv1d, + Conv2d, + Conv3d, + Linear, + ReLU, +) +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "ConvBn1d", + "ConvBn2d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] + + +# Used for identifying intrinsic modules used in quantization +class _FusedModule(torch.nn.Sequential): + pass + + +class ConvReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv1d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class ConvReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class ConvReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class LinearReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and ReLU modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, relu): + assert ( + type_before_parametrizations(linear) == Linear + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(linear, relu) + + +class ConvBn1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(bn) == BatchNorm1d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBn2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(bn) == BatchNorm2d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBnReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(bn) == BatchNorm1d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class ConvBnReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(bn) == BatchNorm2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class ConvBn3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(bn) == BatchNorm3d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBnReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(bn) == BatchNorm3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class BNReLU2d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, batch_norm, relu): + assert ( + type_before_parametrizations(batch_norm) == BatchNorm2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(batch_norm, relu) + + +class BNReLU3d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, batch_norm, relu): + assert ( + type_before_parametrizations(batch_norm) == BatchNorm3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(batch_norm, relu) + + +class LinearBn1d(_FusedModule): + r"""This is a sequential container which calls the Linear and BatchNorm1d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, bn): + assert ( + type_before_parametrizations(linear) == Linear + and type_before_parametrizations(bn) == BatchNorm1d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(linear, bn) + + +class LinearLeakyReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and LeakyReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, leaky_relu): + assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, ( + f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}" + ) + super().__init__(linear, leaky_relu) + + +class LinearTanh(_FusedModule): + r"""This is a sequential container which calls the Linear and Tanh modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, tanh): + assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, ( + f"Incorrect types for input modules{type(linear)}{type(tanh)}" + ) + super().__init__(linear, tanh) + + +class ConvAdd2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d modules with extra Add. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, add): + super().__init__(conv) + self.add = add + + def forward(self, x1, x2): # type: ignore[override] + return self.add(self[0](x1), x2) + + +class ConvAddReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d, add, Relu. 
+ During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, add, relu): + super().__init__(conv) + self.add = add + self.relu = relu + + def forward(self, x1, x2): # type: ignore[override] + return self.relu(self.add(self[0](x1), x2)) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0624222dc5e0f4b2dc2b382f21670c8f64661e20 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a27323c72e46bca658700dc4628053dab7184fd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1,32 @@ +from .conv_fused import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) +from .linear_fused import LinearBn1d +from .linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "LinearBn1d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42fddcee78e1890ef19bb0310ed7021b8e0412de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6706b34c4a536b760cbd2f8b1cf5d9f4d4fbc9af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f54c96f8b9b108f8b6e1239f2c847a2d7f680d2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b7d46ff9e092c25deb66934205eaf216b7e064 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..c14bf92cfc2a7dac01e31f90819f180934cadaf9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -0,0 +1,1064 @@ +# mypy: allow-untyped-defs +import math +from typing import ClassVar, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.modules.utils import _pair, _single, _triple +from torch.nn.parameter import Parameter +from torch.nn.utils import fuse_conv_bn_weights + + +__all__ = [ + "ConvBn1d", + "ConvBnReLU1d", + "ConvReLU1d", + "ConvBn2d", + "ConvBnReLU2d", + "ConvReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "ConvReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] +_BN_CLASS_MAP = { + 1: nn.BatchNorm1d, + 2: nn.BatchNorm2d, + 3: nn.BatchNorm3d, +} + + +class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): + _version = 2 + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + # BatchNormNd args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + dim=2, + ): + nn.modules.conv._ConvNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + False, + padding_mode, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + self._enable_slow_path_for_better_numerical_stability = False + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + # note: below is actually for conv, not BN + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def reset_parameters(self): + super().reset_parameters() + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def _forward(self, input): + if 
self._enable_slow_path_for_better_numerical_stability: + return self._forward_slow(input) + return self._forward_approximate(input) + + def _forward_approximate(self, input): + """Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std + """ + assert self.bn.running_var is not None + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # using zero bias here since the bias for original conv + # will be added later + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias, dtype=input.dtype) + else: + zero_bias = torch.zeros( + self.out_channels, device=scaled_weight.device, dtype=input.dtype + ) + conv = self._conv_forward(input, scaled_weight, zero_bias) + conv_orig = conv / scale_factor.reshape(bias_shape) + if self.bias is not None: + conv_orig = conv_orig + self.bias.reshape(bias_shape) + conv = self.bn(conv_orig) + return conv + + def _forward_slow(self, input): + """ + A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf + It requires two forward passes but handles the case bn.weight == 0 + + Conv: Y = WX + B_c + Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c + + Batch statistics: + mean_Y = Y.mean() + = Y0.mean() + B_c + var_Y = (Y - mean_Y)^2.mean() + = (Y0 - Y0.mean())^2.mean() + BN (r: bn.weight, beta: bn.bias): + Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta + = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta + + Fused Conv BN training (std_Y = sqrt(var_Y + eps)): + Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta + = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta + + Fused Conv BN inference (running_std = sqrt(running_var + eps)): + Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta + + QAT with fused conv bn: + Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta + = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta + Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta + """ + + assert self.bn.running_var is not None + assert self.bn.running_mean is not None + + # using zero bias here since the bias for original conv + # will be added later + zero_bias = torch.zeros( + self.out_channels, device=self.weight.device, dtype=input.dtype + ) + + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + + if self.bn.training: + # needed to compute batch mean/std + conv_out = self._conv_forward(input, self.weight, zero_bias) + # update bn statistics + with torch.no_grad(): + conv_out_bias = ( + conv_out + if self.bias is None + else conv_out + self.bias.reshape(bias_shape) + ) + self.bn(conv_out_bias) + + # fused conv + bn without bias using bn running statistics + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # fused conv without bias for inference: (r * W / running_std) * X + conv_bn = self._conv_forward(input, 
scaled_weight, zero_bias) + + avg_dims = [0] + list(range(2, len(self.weight.shape))) + batch_mean = conv_out.mean(avg_dims) + batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean( + avg_dims + ) + batch_std = torch.sqrt(batch_var + self.bn.eps) + + # scale to use batch std in training mode + # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y) + unscale_factor = running_std / batch_std + conv_bn *= unscale_factor.reshape(bias_shape) + + fused_mean = batch_mean + fused_std = batch_std + else: + # fused conv + bn without bias using bn running statistics + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # fused conv without bias for inference: (r * W / running_std) * X + conv_bn = self._conv_forward(input, scaled_weight, zero_bias) + + fused_mean = self.bn.running_mean - ( + self.bias if self.bias is not None else 0 + ) + fused_std = running_std + + # fused bias = beta - r * mean / std + fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std + conv_bn += fused_bias.reshape(bias_shape) + + # HACK to let conv bias participate in loss to avoid DDP error (parameters + # were not used in producing loss) + if self.bias is not None: + conv_bn += (self.bias - self.bias).reshape(bias_shape) + + return conv_bn + + def extra_repr(self): + # TODO(jerryzh): extend + return super().extra_repr() + + def forward(self, input): + return self._forward(input) + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. 
+ """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + # ===== Serialization version history ===== + # + # Version 1/None + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- gamma : Tensor + # |--- beta : Tensor + # |--- running_mean : Tensor + # |--- running_var : Tensor + # |--- num_batches_tracked : Tensor + # + # Version 2 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- bn : Module + # |--- weight : Tensor (moved from v1.self.gamma) + # |--- bias : Tensor (moved from v1.self.beta) + # |--- running_mean : Tensor (moved from v1.self.running_mean) + # |--- running_var : Tensor (moved from v1.self.running_var) + # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version == 1: + # BN related parameters and buffers were moved into the BN module for v2 + v2_to_v1_names = { + "bn.weight": "gamma", + "bn.bias": "beta", + "bn.running_mean": "running_mean", + "bn.running_var": "running_var", + "bn.num_batches_tracked": "num_batches_tracked", + } + for v2_name, v1_name in v2_to_v1_names.items(): + if prefix + v1_name in state_dict: + state_dict[prefix + v2_name] = state_dict[prefix + v1_name] + state_dict.pop(prefix + v1_name) + elif prefix + v2_name in state_dict: + # there was a brief period where forward compatibility + # for this module was broken (between + # https://github.com/pytorch/pytorch/pull/38478 + # and https://github.com/pytorch/pytorch/pull/38820) + # and modules emitted the v2 state_dict format while + # specifying that version == 1. This patches the forward + # compatibility issue by allowing the v2 style entries to + # be used. + pass + elif strict: + missing_keys.append(prefix + v2_name) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound + # has no __name__ (code is fine though) + assert type(mod) == cls._FLOAT_MODULE, ( + "qat." 
+ + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + qconfig = mod.qconfig + conv, bn = mod[0], mod[1] # type: ignore[index] + qat_convbn = cls( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + bn.eps, + bn.momentum, + False, + qconfig, + ) + qat_convbn.weight = conv.weight + qat_convbn.bias = conv.bias + qat_convbn.bn.weight = bn.weight + qat_convbn.bn.bias = bn.bias + qat_convbn.bn.running_mean = bn.running_mean + qat_convbn.bn.running_var = bn.running_var + # mypy error: Cannot determine type of 'num_batches_tracked' + qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked + return qat_convbn + + def to_float(self): + cls = type(self) + conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined] + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode, + ) + conv.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + conv.bias = torch.nn.Parameter(self.bias.detach()) + + if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined] + # fuse bn into conv + assert self.bn.running_var is not None and self.bn.running_mean is not None + conv.weight, conv.bias = fuse_conv_bn_weights( + conv.weight, + conv.bias, + self.bn.running_mean, + self.bn.running_var, + self.bn.eps, + self.bn.weight, + self.bn.bias, + ) + + if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined] + modules = [] + modules.append(conv) + relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined] + modules.append(relu) + conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined] + conv_relu.train(self.training) + return conv_relu + else: + conv.train(self.training) + return conv + + +class ConvBn1d(_ConvBnNd, nn.Conv1d): + r""" + A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized + to default. 
+ + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + + def __init__( + self, + # Conv1d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _single(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=1, + ) + + +class ConvBnReLU1d(ConvBn1d): + r""" + A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + # base class defines _FLOAT_MODULE as "ConvBn1d" + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBnReLU1d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU1d + + def __init__( + self, + # Conv1d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + + def forward(self, input): + return F.relu(self._forward(input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + +class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): + r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv1d` and + :class:`~torch.nn.BatchNorm1d`. 
+ + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvBn2d(_ConvBnNd, nn.Conv2d): + r""" + A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv2d` and + :class:`torch.nn.BatchNorm2d`. + + Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm2d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=2, + ) + + +class ConvBnReLU2d(ConvBn2d): + r""" + A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv2d` and + :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. 
+ + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + # base class defines _FLOAT_MODULE as "ConvBn2d" + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm2d]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nni.ConvReLU2d]]] = nni.ConvReLU2d + + def __init__( + self, + # Conv2d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm2d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + + def forward(self, input): + return F.relu(self._forward(input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + +class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): + r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv2d` and + :class:`~torch.nn.BatchNorm2d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvBn3d(_ConvBnNd, nn.Conv3d): + r""" + A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv3d` and + :class:`torch.nn.BatchNorm3d`. + + Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized + to default. 
+ + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm3d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=3, + ) + + +class ConvBnReLU3d(ConvBn3d): + r""" + A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv3d` and + :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm3d]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.ReLU]]] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nni.ConvReLU3d]]] = nni.ConvReLU3d + + def __init__( + self, + # Conv3d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm3d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + + def forward(self, input): + return F.relu(ConvBn3d._forward(self, input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): + r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv3d` and + :class:`~torch.nn.BatchNorm3d`. 
+ + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +def update_bn_stats(mod): + if type(mod) in { + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + }: + mod.update_bn_stats() + + +def freeze_bn_stats(mod): + if type(mod) in { + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + }: + mod.freeze_bn_stats() diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ef6f71dff532c9d53071b0a26a8444125a118f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,193 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.fusion import fuse_linear_bn_weights + + +__all__ = [ + "LinearBn1d", +] + + +class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule): + r""" + A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached + with FakeQuantize modules for weight, used in quantization aware training. + + We combined the interface of :class:`torch.nn.Linear` and + :class:torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized + to default. 
+ + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + def __init__( + self, + # Linear args + in_features, + out_features, + bias=True, + # BatchNorm1d args + # num_features: out_features + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + nn.modules.linear.Linear.__init__(self, in_features, out_features, bias) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + + def reset_parameters(self): + super().reset_parameters() + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def forward(self, input): + assert self.bn.running_var is not None + + # Scale the linear weights by BN's running statistics to reduce + # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18 + # for motivation. + # + # Instead of + # + # x1 = F.linear(x0, fq(w), b) + # x2 = self.bn(x1) + # + # We have + # + # # scale the weight by previous batch's running statistics + # scale_factor = bn.w / bn.running_std_from_prev_batch + # # do the linear transformation without bias + # x1_scaled = F.linear(x0, fq(w * scale_factor), 0) + # # reverse the scaling and add original bias + # x1_orig = x1_scaled / scale_factor + b + # x2 = self.bn(x1_orig) + + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_features, device=scaled_weight.device) + linear_out = F.linear(input, scaled_weight, zero_bias) + linear_out_orig = linear_out / scale_factor.reshape(bias_shape) + if self.bias is not None: + linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape) + bn_out = self.bn(linear_out_orig) + return bn_out + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. 
+ """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + + Args: `mod' a float module, either produced by torch.ao.quantization + utilities or directly from user + """ + assert type(mod) == nni.LinearBn1d, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + nni.LinearBn1d.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid config" + qconfig = mod.qconfig + linear, bn = mod[0], mod[1] + qat_linearbn = cls( + linear.in_features, + linear.out_features, + linear.bias is not None, + bn.eps, + bn.momentum, + False, + qconfig, + ) + qat_linearbn.weight = linear.weight # type: ignore[assignment] + qat_linearbn.bias = linear.bias # type: ignore[assignment] + qat_linearbn.bn.weight = bn.weight # type: ignore[assignment] + qat_linearbn.bn.bias = bn.bias # type: ignore[assignment] + qat_linearbn.bn.running_mean = bn.running_mean # type: ignore[assignment] + qat_linearbn.bn.running_var = bn.running_var # type: ignore[assignment] + qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[assignment] + return qat_linearbn + + def to_float(self): + linear = torch.nn.Linear(self.in_features, self.out_features) + assert self.bn.running_var is not None and self.bn.running_mean is not None + linear.weight, linear.bias = fuse_linear_bn_weights( + self.weight, + self.bias, + self.bn.running_mean, + self.bn.running_var, + self.bn.eps, + self.bn.weight, + self.bn.bias, + ) + return linear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..c969fab5a54de95747d86ac8036b8ee37381c65b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -0,0 +1,52 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.nn.functional as F + + +class LinearReLU(nnqat.Linear, nni._FusedModule): + r""" + A LinearReLU module fused from Linear and ReLU modules, attached with + FakeQuantize modules for weight, used in + quantization aware training. + + We adopt the same interface as :class:`torch.nn.Linear`. + + Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to + default. 
+ + Attributes: + weight: fake quant module for weight + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.qat.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, qconfig=None): + super().__init__(in_features, out_features, bias, qconfig) + + def forward(self, input): + return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + def to_float(self): + linear = torch.nn.Linear( + self.in_features, self.out_features, self.bias is not None + ) + linear.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + linear.bias = torch.nn.Parameter(self.bias.detach()) + relu = torch.nn.ReLU() + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..904cf5dc8782c5630b06959c123f30ad5ff19a3e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py @@ -0,0 +1,15 @@ +from .modules import * # noqa: F403 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e731c8ef02449721c999364d6b5c54e5bee2dadd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9515b572db8e5c59e7501ad5b505a8111c3c598a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0210dd182c4cfc8075475fb01f46824d7ec7f100 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,6 @@ +from .linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git 
a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..297f547adf02f1dcb160542e3b85014181fe0239 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57cf71826bbe82a3437bfbaefe31fcf38b409bb5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..826609e8c5a2ba7d77c82120f57ca1754b41df52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized.dynamic as nnqd + + +__all__ = ["LinearReLU"] + + +class LinearReLU(nnqd.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules that can be used + for dynamic quantization. + Supports both FP16 and INT8 quantization. + + We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.
+ + Attributes: + Same as torch.ao.nn.quantized.dynamic.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._packed_params.dtype == torch.qint8: + # TODO check if we should set reduce_range = True by default here + Y = torch.ops.quantized.linear_relu_dynamic( + x, self._packed_params._packed_params, reduce_range=True + ) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_relu_dynamic_fp16( + x, self._packed_params._packed_params + ) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!") + return Y.to(x.dtype) + + def _get_name(self): + return "DynamicQuantizedLinearReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qlinear_relu): # type: ignore[override] + return super().from_reference(ref_qlinear_relu[0]) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..040e70c6162ba4256a840676d87c3af526603f17 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py @@ -0,0 +1,18 @@ +from .bn_relu import BNReLU2d, BNReLU3d +from .conv_add import ConvAdd2d, ConvAddReLU2d +from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d +from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh + + +__all__ = [ + "LinearReLU", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "BNReLU2d", + "BNReLU3d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5718f5a576794c67360c8045e8626f56a15829cc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f696f01b3b07ef4a0e8505a078567302680dfc0c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9579aacfffe14ce843ae9d8c449b737454b0b1b3 Binary files /dev/null and
b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7e0500f18fb7f7f4dcb075ac6392653e427ab5b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1acc33e9667b0429ccf3f62294e5299ef9d4e358 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..37a4249887c5f3bfed27ad44d348fb466988b73e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq + + +__all__ = ["BNReLU2d", "BNReLU3d"] + + +class BNReLU2d(nnq.BatchNorm2d): + r""" + A BNReLU2d module is a fused module of BatchNorm2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`. + + Attributes: + Same as torch.ao.nn.quantized.BatchNorm2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + super().__init__( + num_features, eps=eps, momentum=momentum, device=device, dtype=dtype + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return torch.ops.quantized.batch_norm2d_relu( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedBNReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + # TODO: Add qat support for BNReLU2d + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, bn_relu, output_scale, output_zero_point): + return super().from_reference(bn_relu[0], output_scale, output_zero_point) + + +class BNReLU3d(nnq.BatchNorm3d): + r""" + A BNReLU3d module is a fused module of BatchNorm3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`. 
+ + Attributes: + Same as torch.ao.nn.quantized.BatchNorm3d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + super().__init__( + num_features, eps=eps, momentum=momentum, device=device, dtype=dtype + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + return torch.ops.quantized.batch_norm3d_relu( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedBNReLU3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + # TODO: Add qat support for BNReLU3d + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, bn_relu, output_scale, output_zero_point): + return super().from_reference(bn_relu[0], output_scale, output_zero_point) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py new file mode 100644 index 0000000000000000000000000000000000000000..46c2cf471225a04d1548b145b6d80d57c2f7a73a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -0,0 +1,147 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq +import torch.nn.functional as F + + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + + +class ConvAdd2d(nnq.Conv2d): + r""" + A ConvAdd2d module is a fused module of Conv2d and Add + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. 
+ + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input, extra_input): # type: ignore[override] + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_add( + input, extra_input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvAdd2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvAddReLU2d(nnq.Conv2d): + r""" + A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input, extra_input): # type: ignore[override] + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_add_relu( + input, extra_input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvAddReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 
0000000000000000000000000000000000000000..b3f757a608685f0357d8b476aed65784e130771f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,266 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq +import torch.nn.functional as F +from torch.nn.utils import fuse_conv_bn_weights + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", +] + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + + +# TODO: factor out the common parts to ConvNd +class ConvReLU1d(nnq.Conv1d): + r""" + A ConvReLU1d module is a fused module of Conv1d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv1d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv1d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU1d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float(mod, use_precomputed_fake_quant) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, ( + "BatchNorm1d should be fused into Conv1d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU2d(nnq.Conv2d): + r""" + A ConvReLU2d module is a fused module of Conv2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. 
+ + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, ( + "BatchNorm2d should be fused into Conv2d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU3d(nnq.Conv3d): + r""" + A ConvReLU3d module is a fused module of Conv3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`. 
+ + Attributes: Same as torch.ao.nn.quantized.Conv3d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv3d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, ( + "BatchNorm3d should be fused into Conv3d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfbafd5ca8db8582b4dd54ce01af17b5fa87a7b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,190 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +from torch.ao.nn.quantized.modules.utils import _quantize_weight + + +__all__ = [ + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", +] + + +class LinearReLU(nnq.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. 
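A fused module like this LinearReLU is normally produced by the eager-mode quantization workflow rather than constructed by hand. A minimal sketch of that flow, in which the model definition and the "fbgemm" qconfig are illustrative choices:

>>> # xdoctest: +SKIP
>>> import torch
>>> from torch import nn
>>> import torch.ao.quantization as tq
>>> class M(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.quant = tq.QuantStub()
...         self.fc = nn.Linear(20, 30)
...         self.relu = nn.ReLU()
...         self.dequant = tq.DeQuantStub()
...     def forward(self, x):
...         return self.dequant(self.relu(self.fc(self.quant(x))))
>>> m = M().eval()
>>> m.qconfig = tq.get_default_qconfig("fbgemm")
>>> m = tq.fuse_modules(m, [["fc", "relu"]])  # nn.Linear + nn.ReLU -> intrinsic LinearReLU
>>> m = tq.prepare(m)
>>> _ = m(torch.randn(4, 20))                 # calibration pass
>>> m = tq.convert(m)                         # intrinsic LinearReLU -> quantized LinearReLU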
+ + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_relu( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLinearReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + @classmethod + def from_reference(cls, ref_linear_relu, output_scale, output_zero_point): + return super().from_reference( + ref_linear_relu[0], output_scale, output_zero_point + ) + + +class LinearLeakyReLU(nnq.Linear): + r""" + For onednn backend only + A LinearLeakyReLU module fused from Linear and LeakyReLU modules + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + Attributes: + Same as torch.ao.nn.quantized.Linear + + negative_slope + Examples:: + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment] + + def __init__( + self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8 + ): + super().__init__(in_features, out_features, bias, dtype) + self.negative_slope = negative_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_leaky_relu( + x, + self._packed_params._packed_params, + self.scale, + self.zero_point, + self.negative_slope, + ) + + def _get_name(self): + return "QuantizedLinearLeakyReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) == nni.LinearLeakyReLU, ( + "Input float module should be LinearLeakyReLU" + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + leaky_relu = mod[1] + mod = mod[0] + weight_post_process = mod.qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_leaky_relu = cls( + mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype + ) + qlinear_leaky_relu.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type] + qlinear_leaky_relu.scale = float(act_scale) + qlinear_leaky_relu.zero_point = int(act_zp) + return qlinear_leaky_relu + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + leaky_relu = ref_mod[1] + qlinear_leaky_relu = cls( + linear.in_features, linear.out_features, leaky_relu.negative_slope + ) + qweight = linear.get_quantized_weight() + qlinear_leaky_relu.set_weight_bias(qweight, linear.bias) + qlinear_leaky_relu.scale = float(output_scale) + qlinear_leaky_relu.zero_point = int(output_zero_point) + return 
qlinear_leaky_relu + + +class LinearTanh(nnq.Linear): + r""" + A LinearTanh module fused from Linear and Tanh modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearTanh(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_tanh( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLinearTanh" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh" + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + mod = mod[0] + weight_post_process = mod.qconfig.weight() # type: ignore[union-attr,operator] + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear_tanh.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type] + qlinear_tanh.scale = float(act_scale) + qlinear_tanh.zero_point = int(act_zp) + return qlinear_tanh + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + qlinear_tanh = cls(linear.in_features, linear.out_features) + qweight = linear.get_quantized_weight() + qlinear_tanh.set_weight_bias(qweight, linear.bias) + qlinear_tanh.scale = float(output_scale) + qlinear_tanh.zero_point = int(output_zero_point) + return qlinear_tanh diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eda204d51db6d368cf453c12a72b66f320241fbe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..51a315b7c4fbe74d148989b3c0a8f4eed2a10deb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f431e4d629815cba952a9f67b8774604f188ceb8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py @@ -0,0 +1,4 @@ +from .linear import Linear + + +__all__ = ["Linear"] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da9fac93504858514e3d031026ac1a8ad83de5b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeeb5c1fb92a7213486fe2b4681392a34ff0f8dc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..ddbd55c6defce4df89efad9212d43abb6cf142d3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,40 @@ +from typing import Optional, TYPE_CHECKING, Union + +import torch + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfig # noqa: TC004 + + +__all__ = ["Linear"] + + +class Linear(torch.ao.nn.qat.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for dynamic quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + qconfig: Optional["QConfig"] = None, + device: Optional[Union[int, str, torch.device]] = None, + dtype: Optional[str] = None, + ) -> None: + super().__init__(in_features, out_features, bias, qconfig, device, dtype) + if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig): # type: ignore[arg-type] + raise ValueError( + "Dynamic QAT requires a memoryless observer." 
+ + "This means a MovingAverage observer with averaging constant equal to 1" + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f46985c76c079764129e5543f0c40cce2b509ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__init__.py @@ -0,0 +1,13 @@ +from .conv import Conv1d, Conv2d, Conv3d +from .embedding_ops import Embedding, EmbeddingBag +from .linear import Linear + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a680fea699e068632488404c72eb6b5c5e9ec13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5084b8b0d8d5cbfff441532f7b558857376688af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e402aba82a02906e4f4d02dd0874e85669407f44 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea010c7c8f3381981ef7ab1af18a2a0a628d032f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/conv.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b8b365481f7541a0a9e0a06bd1d4fe5981aa5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/conv.py @@ -0,0 +1,311 @@ +# mypy: allow-untyped-defs +from typing import ClassVar, Union + +import torch +import torch.nn as nn +from torch.ao.nn.intrinsic import _FusedModule +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t +from torch.nn.modules.utils import _pair, _single, _triple + + +__all__ = ["Conv1d", "Conv2d", "Conv3d"] + + +class _ConvNd(nn.modules.conv._ConvNd): + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: Union[str, tuple[int, ...]], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + nn.modules.conv._ConvNd.__init__( + self, + 
in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: + `mod`: a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + if issubclass(type(mod), _FusedModule): + mod = mod[0] + qconfig = mod.qconfig + qat_conv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + stride=mod.stride, + padding=mod.padding, + dilation=mod.dilation, + groups=mod.groups, + bias=mod.bias is not None, + padding_mode=mod.padding_mode, + qconfig=qconfig, + ) + qat_conv.weight = mod.weight + qat_conv.bias = mod.bias + return qat_conv + + def to_float(self): + """This works for both single qat conv, and the qat conv - relu modules + to convert the qat module to a floating point module + """ + cls = type(self) + conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined] + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode, + ) + conv.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + conv.bias = torch.nn.Parameter(self.bias.detach()) + # conv relu + if issubclass(cls, _FusedModule): + modules = [conv] + assert hasattr(cls, "_FLOAT_RELU_MODULE") + relu = cls._FLOAT_RELU_MODULE() + modules.append(relu) + fused = cls._FLOAT_MODULE(*modules) + fused.train(self.training) + return fused + else: + return conv + + +class Conv1d(_ConvNd, nn.Conv1d): + r""" + A Conv1d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as :class:`~torch.nn.Conv1d` + + Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. 
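These QAT conv modules are normally created through `torch.ao.quantization.prepare_qat` (which calls `from_float`) rather than instantiated directly. A rough sketch for the 1d case, with the "fbgemm" qconfig and tensor shapes as illustrative assumptions; the 2d and 3d variants follow the same pattern:

>>> # xdoctest: +SKIP
>>> import torch
>>> from torch import nn
>>> import torch.ao.quantization as tq
>>> m = nn.Sequential(tq.QuantStub(), nn.Conv1d(3, 8, 3), tq.DeQuantStub()).train()
>>> m.qconfig = tq.get_default_qat_qconfig("fbgemm")
>>> m = tq.prepare_qat(m)            # nn.Conv1d is swapped for torch.ao.nn.qat.Conv1d
>>> _ = m(torch.randn(2, 3, 32))     # forward passes run the weight through fake-quant
>>> m = tq.convert(m.eval())         # qat.Conv1d -> torch.ao.nn.quantized.Conv1d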
+ + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_single(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv2d(_ConvNd, nn.Conv2d): + r""" + A Conv2d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Conv2d`, please see + https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d + for documentation. + + Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_pair(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv3d(_ConvNd, nn.Conv3d): + r""" + A Conv3d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Conv3d`, please see + https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d + for documentation. + + Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to + default. 
+ + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: Union[str, _size_3_t] = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_triple(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/embedding_ops.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..29df01437835f260b07c2235dd142406d084aac7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/embedding_ops.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(nn.Embedding): + r""" + An embedding bag module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Embedding`, please see + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding + for documentation. + + Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to + default. 
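A minimal sketch of swapping a float embedding for this QAT module by hand; it assumes `torch.ao.quantization.default_embedding_qat_qconfig` is available under that name, which supplies the required `per_channel_affine_float_qparams` weight qscheme:

>>> # xdoctest: +SKIP
>>> import torch
>>> from torch import nn
>>> import torch.ao.nn.qat as nnqat
>>> import torch.ao.quantization as tq
>>> emb = nn.Embedding(1000, 64)
>>> emb.qconfig = tq.default_embedding_qat_qconfig   # per-channel float-qparams weight observer
>>> qat_emb = nnqat.Embedding.from_float(emb)
>>> out = qat_emb(torch.randint(0, 1000, (8, 16)))   # float output, fake-quantized weight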
+ + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.Embedding + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + qconfig=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(qconfig.weight().qscheme) + ) + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input) -> Tensor: + return F.embedding( + input, + self.weight_fake_quant(self.weight), + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator] + assert weight_qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(weight_qscheme) + ) + + qconfig = mod.qconfig + qat_embedding_bag = cls( + mod.num_embeddings, + mod.embedding_dim, + mod.padding_idx, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.sparse, + mod.weight, + qconfig=qconfig, + ) + + return qat_embedding_bag + + def to_float(self): + embedding_bag = torch.nn.Embedding( + self.num_embeddings, + self.embedding_dim, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + None, + ) + embedding_bag.weight = torch.nn.Parameter(self.weight.detach()) + embedding_bag.train(self.training) + return embedding_bag + + +class EmbeddingBag(nn.EmbeddingBag): + r""" + An embedding bag module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.EmbeddingBag`, please see + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag + for documentation. + + Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to + default. 
+ + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.EmbeddingBag + + def __init__( + self, + num_embeddings, + embedding_dim, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + mode="mean", + sparse=False, + _weight=None, + include_last_offset=False, + padding_idx=None, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_embeddings, + embedding_dim, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + _weight, + include_last_offset, + padding_idx, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(qconfig.weight().qscheme) + ) + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor: + return F.embedding_bag( + input, + self.weight_fake_quant(self.weight), + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator] + assert weight_qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(weight_qscheme) + ) + + qconfig = mod.qconfig + qat_embedding_bag = cls( + mod.num_embeddings, + mod.embedding_dim, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.mode, + mod.sparse, + mod.weight, + mod.include_last_offset, + mod.padding_idx, + qconfig=qconfig, + ) + + return qat_embedding_bag + + def to_float(self): + embedding_bag = torch.nn.EmbeddingBag( + self.num_embeddings, + self.embedding_dim, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + None, + self.include_last_offset, + self.padding_idx, + ) + embedding_bag.weight = torch.nn.Parameter(self.weight.detach()) + embedding_bag.train(self.training) + return embedding_bag diff --git a/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..0fae508df1df58498b0a6f001cbc1c0493f60526 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/qat/modules/linear.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.nn.intrinsic import LinearReLU +from torch.nn.utils.parametrize import ( + is_parametrized, + transfer_parametrizations_and_params, + type_before_parametrizations, +) + + +__all__ = ["Linear"] + + +class Linear(nn.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used 
for quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.Linear + + def __init__( + self, + in_features, + out_features, + bias=True, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(in_features, out_features, bias, **factory_kwargs) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input): + return F.linear(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + if type_before_parametrizations(mod) == LinearReLU: + mod = mod[0] + + qconfig = mod.qconfig + qat_linear = cls( + mod.in_features, + mod.out_features, + bias=mod.bias is not None, + qconfig=qconfig, + ) + + if is_parametrized(mod, "weight"): + transfer_parametrizations_and_params(mod, qat_linear, "weight") + else: + qat_linear.weight = mod.weight + + if is_parametrized(mod, "bias"): + transfer_parametrizations_and_params(mod, qat_linear, "bias") + else: + qat_linear.bias = mod.bias + + return qat_linear + + def to_float(self): + linear = torch.nn.Linear( + self.in_features, self.out_features, self.bias is not None + ) + linear.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + linear.bias = torch.nn.Parameter(self.bias.detach()) + linear.train(self.training) + return linear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7d91cab26d7327602caaa04b42f3b57e0e6f931 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d83639d4e2323bc827a17b55acf3f9de6ee781d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from .activation import MultiheadAttention +from .rnn import LSTM, LSTMCell + + +__all__ = [ + "LSTM", + "LSTMCell", + 
"MultiheadAttention", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9e532db823fc91cdaaef72d78144f5fdd4a53e5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9451cd83278f05cddf1951331a3e800dff0cfcd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aa5dfd393ccaec3dd155a158b0bca1f8dfccf8e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/activation.py b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..654cd51e67a6039a10c4a29cb953e92062eeb372 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/activation.py @@ -0,0 +1,552 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Optional + +import torch +import torch.jit # this is needed to avoid a circular import +import torch.nn.functional as F +from torch import nn, Tensor + + +__all__ = ["MultiheadAttention"] + + +class MultiheadAttention(nn.MultiheadAttention): + _FLOAT_MODULE = nn.MultiheadAttention + + r"""Quantizable implementation of the MultiheadAttention. + + Note:: + Please, refer to :class:`~torch.nn.MultiheadAttention` for more + information + + Allows the model to jointly attend to information from different + representation subspaces. + See reference: Attention Is All You Need + + The original MHA module is not quantizable. + This reimplements it by explicitly instantiating the linear layers. + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set + to :attr:`embed_dim` such that query, key, and value have the same + number of features. 
+ + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + Note:: + Please, follow the quantization flow to convert the quantizable MHA. + """ + __constants__ = ["batch_first"] + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + batch_first: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + add_bias_kv, + add_zero_attn, + kdim, + vdim, + batch_first, + **factory_kwargs, + ) + self.linear_Q = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) + self.linear_K = nn.Linear( + self.kdim, self.embed_dim, bias=bias, **factory_kwargs + ) + self.linear_V = nn.Linear( + self.vdim, self.embed_dim, bias=bias, **factory_kwargs + ) + # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969 + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) # type: ignore[assignment] + + # Functionals + self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional() + # note: importing torch.ao.nn.quantized at top creates a circular import + + # Quant/Dequant + self.quant_attn_output = torch.ao.quantization.QuantStub() + self.quant_attn_output_weights = torch.ao.quantization.QuantStub() + self.dequant_q = torch.ao.quantization.DeQuantStub() + self.dequant_k = torch.ao.quantization.DeQuantStub() + self.dequant_v = torch.ao.quantization.DeQuantStub() + + def _get_name(self): + return "QuantizableMultiheadAttention" + + @classmethod + def from_float(cls, other): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" + # Setting the dropout to 0.0! 
+ observed = cls( + other.embed_dim, + other.num_heads, + other.dropout, + (other.in_proj_bias is not None), + (other.bias_k is not None), + other.add_zero_attn, + other.kdim, + other.vdim, + other.batch_first, + ) + observed.bias_k = other.bias_k + observed.bias_v = other.bias_v + observed.qconfig = other.qconfig + + # Set the linear weights + # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969 + observed.out_proj.weight = other.out_proj.weight + observed.out_proj.bias = other.out_proj.bias + if other._qkv_same_embed_dim: + # Use separate params + bias = other.in_proj_bias + _start = 0 + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_Q.bias = bias + + bias = other.in_proj_bias + _start = _end + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_K.bias = bias + + bias = other.in_proj_bias + _start = _end + weight = other.in_proj_weight[_start:, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:], bias.requires_grad) + observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_V.bias = bias + else: + observed.linear_Q.weight = nn.Parameter(other.q_proj_weight) + observed.linear_K.weight = nn.Parameter(other.k_proj_weight) + observed.linear_V.weight = nn.Parameter(other.v_proj_weight) + if other.in_proj_bias is None: + observed.linear_Q.bias = None + observed.linear_K.bias = None + observed.linear_V.bias = None + else: + observed.linear_Q.bias = nn.Parameter( + other.in_proj_bias[0 : other.embed_dim] + ) + observed.linear_K.bias = nn.Parameter( + other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)] + ) + observed.linear_V.bias = nn.Parameter( + other.in_proj_bias[(other.embed_dim * 2) :] + ) + observed.eval() + # Explicit prepare + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @torch.jit.unused + def dequantize(self): + r"""Utility to convert the quantized MHA back to float. + + The motivation for this is that it is not trivial to convert the weights + from the format that is used in the quantized version back to the + float. + """ + fp = self._FLOAT_MODULE( + self.embed_dim, + self.num_heads, + self.dropout, + (self.linear_Q._weight_bias()[1] is not None), # type: ignore[operator] + (self.bias_k is not None), + self.add_zero_attn, + self.kdim, + self.vdim, + self.batch_first, + ) + assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim + if self.bias_k is not None: + fp.bias_k = nn.Parameter(self.bias_k.dequantize()) + if self.bias_v is not None: + fp.bias_v = nn.Parameter(self.bias_v.dequantize()) + + # Set the linear weights + # Note: Because the linear layers are quantized, mypy does not nkow how + # to deal with them -- might need to ignore the typing checks. 
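+ # The blocks below unpack each quantized Linear (out_proj first, then Q/K/V), dequantize the weights, and copy them back into the float module's out_proj and packed in_proj parameters.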
+ # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type] + fp.out_proj.weight = nn.Parameter(w.dequantize()) + if b is not None: + fp.out_proj.bias = nn.Parameter(b) + + wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator] + wQ = wQ.dequantize() + wK, bK = self.linear_K._weight_bias() # type: ignore[operator] + wK = wK.dequantize() + wV, bV = self.linear_V._weight_bias() # type: ignore[operator] + wV = wV.dequantize() + if fp._qkv_same_embed_dim: + # Use separate params + _start = 0 + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wQ + if fp.in_proj_bias is not None: + assert all(bQ == 0) + fp.in_proj_bias[_start:_end] = bQ + + _start = _end + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wK + if fp.in_proj_bias is not None: + assert all(bK == 0) + fp.in_proj_bias[_start:_end] = bK + + _start = _end + fp.in_proj_weight[_start:, :] = wV + if fp.in_proj_bias is not None: + assert all(bV == 0) + fp.in_proj_bias[_start:] = bV + else: + fp.q_proj_weight = nn.Parameter(wQ) + fp.k_proj_weight = nn.Parameter(wK) + fp.v_proj_weight = nn.Parameter(wV) + if fp.in_proj_bias is None: + self.linear_Q.bias = None + self.linear_K.bias = None + self.linear_V.bias = None + else: + fp.in_proj_bias[0 : fp.embed_dim] = bQ + fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK + fp.in_proj_bias[(fp.embed_dim * 2) :] = bV + + return fp + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + # See nn.quantized.MultiheadAttention + raise NotImplementedError( + "It looks like you are trying to prepare an " + "MHA module. Please, see " + "the examples on quantizable MHAs." + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Optional[Tensor]]: + r""" + Note:: + Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more + information + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
+ If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask. + Default: ``False``. + - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True.``. Default: True (i.e. average weights across heads) + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged + across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length, + S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(N, num_heads, L, S)`. + """ + return self._forward_impl( + query, + key, + value, + key_padding_mask, + need_weights, + attn_mask, + average_attn_weights, + is_causal, + ) + + def _forward_impl( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Optional[Tensor]]: + # This version will not deal with the static key/value pairs. + # Keeping it here for future changes. + # + # TODO: This method has some duplicate lines with the + # `torch.nn.functional.multi_head_attention`. Will need to refactor. + static_k = None + static_v = None + + if attn_mask is not None and is_causal: + raise AssertionError("Only allow causal mask or attn_mask") + + if is_causal: + raise AssertionError("causal mask not supported by AO MHA module") + + if self.batch_first: + query, key, value = (x.transpose(0, 1) for x in (query, key, value)) + + tgt_len, bsz, embed_dim_to_check = query.size() + assert self.embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = self.embed_dim // self.num_heads + assert head_dim * self.num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) + scaling = float(head_dim) ** -0.5 + + q = self.linear_Q(query) + k = self.linear_K(key) + v = self.linear_V(value) + + q = self.q_scaling_product.mul_scalar(q, scaling) + + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. 
" + "Use bool tensor instead.", + stacklevel=3, + ) + attn_mask = attn_mask.to(torch.bool) + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, ( + f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}" + ) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * self.num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) + key_padding_mask = key_padding_mask.to(torch.bool) + if self.bias_k is not None and self.bias_v is not None: + if static_k is None and static_v is None: + # Explicitly assert that bias_k and bias_v are not None + # in a way that TorchScript can understand. + bias_k = self.bias_k + assert bias_k is not None + bias_v = self.bias_v + assert bias_v is not None + + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ else: + assert self.bias_k is None + assert self.bias_v is None + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * self.num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * self.num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:]) + if k.is_quantized: + k_zeros = torch.quantize_per_tensor( + k_zeros, k.q_scale(), k.q_zero_point(), k.dtype + ) + k = torch.cat([k, k_zeros], dim=1) + v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:]) + if v.is_quantized: + v_zeros = torch.quantize_per_tensor( + v_zeros, v.q_scale(), v.q_zero_point(), v.dtype + ) + v = torch.cat([v, v_zeros], dim=1) + + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # Leaving the quantized zone here + q = self.dequant_q(q) + k = self.dequant_k(k) + v = self.dequant_v(v) + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [ + bsz * self.num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * self.num_heads, tgt_len, src_len + ) + + attn_output_weights = F.softmax(attn_output_weights, dim=-1) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim] + if self.batch_first: + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + else: + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, self.embed_dim) + ) + + # Reentering the quantized zone + attn_output = self.quant_attn_output(attn_output) + # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + attn_output = self.out_proj(attn_output) # type: ignore[has-type] + attn_output_weights = self.quant_attn_output_weights(attn_output_weights) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + return attn_output, attn_output_weights + else: + return attn_output, None diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/rnn.py b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2477c49e7e0557a987b8e611852baa4588dce4 --- /dev/null 
+++ b/phivenv/Lib/site-packages/torch/ao/nn/quantizable/modules/rnn.py @@ -0,0 +1,599 @@ +""" +We will recreate all the RNN modules as we require the modules to be decomposed +into its building blocks to be able to observe. +""" + +# mypy: allow-untyped-defs + +import numbers +import warnings +from typing import Optional + +import torch +from torch import Tensor + + +__all__ = ["LSTMCell", "LSTM"] + + +class LSTMCell(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM) cell. + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` + + `split_gates`: specify True to compute the input/forget/cell/output gates separately + to avoid an intermediate tensor which is subsequently chunk'd. This optimization can + be beneficial for on-device inference latency. This flag is cascaded down from the + parent classes. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> rnn = nnqa.LSTMCell(10, 20) + >>> input = torch.randn(6, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + """ + + _FLOAT_MODULE = torch.nn.LSTMCell + __constants__ = ["split_gates"] # for jit.script + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_dim + self.hidden_size = hidden_dim + self.bias = bias + self.split_gates = split_gates + + if not split_gates: + self.igates: torch.nn.Module = torch.nn.Linear( + input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs + ) + self.hgates: torch.nn.Module = torch.nn.Linear( + hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs + ) + self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional() + else: + # keep separate Linear layers for each gate + self.igates = torch.nn.ModuleDict() + self.hgates = torch.nn.ModuleDict() + self.gates = torch.nn.ModuleDict() + for g in ["input", "forget", "cell", "output"]: + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.igates[g] = torch.nn.Linear( + input_dim, hidden_dim, bias=bias, **factory_kwargs + ) + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.hgates[g] = torch.nn.Linear( + hidden_dim, hidden_dim, bias=bias, **factory_kwargs + ) + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.gates[g] = torch.ao.nn.quantized.FloatFunctional() + + self.input_gate = torch.nn.Sigmoid() + self.forget_gate = torch.nn.Sigmoid() + self.cell_gate = torch.nn.Tanh() + self.output_gate = torch.nn.Sigmoid() + + self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() + self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() + self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() + + self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() + + self.initial_hidden_state_qparams: tuple[float, int] = (1.0, 0) + self.initial_cell_state_qparams: tuple[float, int] = (1.0, 0) + self.hidden_state_dtype: torch.dtype = torch.quint8 + self.cell_state_dtype: torch.dtype = torch.quint8 + + def forward( + self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + if hidden is None or hidden[0] is None or hidden[1] is None: + hidden = self.initialize_hidden(x.shape[0], x.is_quantized) + hx, cx = hidden + + if not self.split_gates: 
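+ # Fused path: the single igates/hgates Linears each produce all four gate pre-activations (4 * hidden_dim);
+ # the FloatFunctional add keeps the sum observable for quantization before it is chunked into the individual gates.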
+ igates = self.igates(x) + hgates = self.hgates(hx) + gates = self.gates.add(igates, hgates) # type: ignore[operator] + + input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) + + input_gate = self.input_gate(input_gate) + forget_gate = self.forget_gate(forget_gate) + cell_gate = self.cell_gate(cell_gate) + out_gate = self.output_gate(out_gate) + else: + # apply each input + hidden projection and add together + gate = {} + for (key, gates), igates, hgates in zip( + self.gates.items(), # type: ignore[operator] + self.igates.values(), # type: ignore[operator] + self.hgates.values(), # type: ignore[operator] + ): + gate[key] = gates.add(igates(x), hgates(hx)) + + input_gate = self.input_gate(gate["input"]) + forget_gate = self.forget_gate(gate["forget"]) + cell_gate = self.cell_gate(gate["cell"]) + out_gate = self.output_gate(gate["output"]) + + fgate_cx = self.fgate_cx.mul(forget_gate, cx) + igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) + fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) + cy = fgate_cx_igate_cgate + + # TODO: make this tanh a member of the module so its qparams can be configured + tanh_cy = torch.tanh(cy) + hy = self.ogate_cy.mul(out_gate, tanh_cy) + return hy, cy + + def initialize_hidden( + self, batch_size: int, is_quantized: bool = False + ) -> tuple[Tensor, Tensor]: + h, c = ( + torch.zeros((batch_size, self.hidden_size)), + torch.zeros((batch_size, self.hidden_size)), + ) + if is_quantized: + (h_scale, h_zp) = self.initial_hidden_state_qparams + (c_scale, c_zp) = self.initial_cell_state_qparams + h = torch.quantize_per_tensor( + h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype + ) + c = torch.quantize_per_tensor( + c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype + ) + return h, c + + def _get_name(self): + return "QuantizableLSTMCell" + + @classmethod + def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False): + """Uses the weights and biases to create a new LSTM cell. 
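+ The weights are expected in the standard ``torch.nn.LSTMCell`` layout: ``wi`` of shape ``(4 * hidden_size, input_size)`` and ``wh`` of shape ``(4 * hidden_size, hidden_size)``.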
+ + Args: + wi, wh: Weights for the input and hidden layers + bi, bh: Biases for the input and hidden layers + """ + assert (bi is None) == (bh is None) # Either both None or both have values + input_size = wi.shape[1] + hidden_size = wh.shape[1] + cell = cls( + input_dim=input_size, + hidden_dim=hidden_size, + bias=(bi is not None), + split_gates=split_gates, + ) + + if not split_gates: + cell.igates.weight = torch.nn.Parameter(wi) + if bi is not None: + cell.igates.bias = torch.nn.Parameter(bi) + cell.hgates.weight = torch.nn.Parameter(wh) + if bh is not None: + cell.hgates.bias = torch.nn.Parameter(bh) + else: + # split weight/bias + for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]): + for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()): # type: ignore[operator] + gate.weight = torch.nn.Parameter(w_chunk) + + if b is not None: + for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()): # type: ignore[operator] + gate.bias = torch.nn.Parameter(b_chunk) + + return cell + + @classmethod + def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" + observed = cls.from_params( + other.weight_ih, + other.weight_hh, + other.bias_ih, + other.bias_hh, + split_gates=split_gates, + ) + observed.qconfig = other.qconfig + observed.igates.qconfig = other.qconfig + observed.hgates.qconfig = other.qconfig + if split_gates: + # also apply qconfig directly to Linear modules + for g in observed.igates.values(): + g.qconfig = other.qconfig + for g in observed.hgates.values(): + g.qconfig = other.qconfig + return observed + + +class _LSTMSingleLayer(torch.nn.Module): + r"""A single one-directional LSTM layer. + + The difference between a layer and a cell is that the layer can process a + sequence, while the cell only expects an instantaneous value. 
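+ Internally the layer iterates an :class:`LSTMCell` over the first (time) dimension of the input and stacks the per-step hidden states to form the output tensor.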
+ """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.cell = LSTMCell( + input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs + ) + + def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): + result = [] + seq_len = x.shape[0] + for i in range(seq_len): + hidden = self.cell(x[i], hidden) + result.append(hidden[0]) # type: ignore[index] + result_tensor = torch.stack(result, 0) + return result_tensor, hidden + + @classmethod + def from_params(cls, *args, **kwargs): + cell = LSTMCell.from_params(*args, **kwargs) + layer = cls( + cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates + ) + layer.cell = cell + return layer + + +class _LSTMLayer(torch.nn.Module): + r"""A single bi-directional LSTM layer.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + batch_first: bool = False, + bidirectional: bool = False, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.batch_first = batch_first + self.bidirectional = bidirectional + self.layer_fw = _LSTMSingleLayer( + input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs + ) + if self.bidirectional: + self.layer_bw = _LSTMSingleLayer( + input_dim, + hidden_dim, + bias=bias, + split_gates=split_gates, + **factory_kwargs, + ) + + def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + if hidden is None: + hx_fw, cx_fw = (None, None) + else: + hx_fw, cx_fw = hidden + hidden_bw: Optional[tuple[Tensor, Tensor]] = None + if self.bidirectional: + if hx_fw is None: + hx_bw = None + else: + hx_bw = hx_fw[1] + hx_fw = hx_fw[0] + if cx_fw is None: + cx_bw = None + else: + cx_bw = cx_fw[1] + cx_fw = cx_fw[0] + if hx_bw is not None and cx_bw is not None: + hidden_bw = hx_bw, cx_bw + if hx_fw is None and cx_fw is None: + hidden_fw = None + else: + hidden_fw = ( + torch.jit._unwrap_optional(hx_fw), + torch.jit._unwrap_optional(cx_fw), + ) + result_fw, hidden_fw = self.layer_fw(x, hidden_fw) + + if hasattr(self, "layer_bw") and self.bidirectional: + x_reversed = x.flip(0) + result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) + result_bw = result_bw.flip(0) + + result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) + if hidden_fw is None and hidden_bw is None: + h = None + c = None + elif hidden_fw is None: + (h, c) = torch.jit._unwrap_optional(hidden_bw) + elif hidden_bw is None: + (h, c) = torch.jit._unwrap_optional(hidden_fw) + else: + h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] + c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] + else: + result = result_fw + h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment] + + if self.batch_first: + result.transpose_(0, 1) + + return result, (h, c) + + @classmethod + def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): + r""" + There is no FP equivalent of this class. This function is here just to + mimic the behavior of the `prepare` within the `torch.ao.quantization` + flow. 
+ """ + assert hasattr(other, "qconfig") or (qconfig is not None) + + input_size = kwargs.get("input_size", other.input_size) + hidden_size = kwargs.get("hidden_size", other.hidden_size) + bias = kwargs.get("bias", other.bias) + batch_first = kwargs.get("batch_first", other.batch_first) + bidirectional = kwargs.get("bidirectional", other.bidirectional) + split_gates = kwargs.get("split_gates", False) + + layer = cls( + input_size, + hidden_size, + bias, + batch_first, + bidirectional, + split_gates=split_gates, + ) + layer.qconfig = getattr(other, "qconfig", qconfig) + wi = getattr(other, f"weight_ih_l{layer_idx}") + wh = getattr(other, f"weight_hh_l{layer_idx}") + bi = getattr(other, f"bias_ih_l{layer_idx}", None) + bh = getattr(other, f"bias_hh_l{layer_idx}", None) + + layer.layer_fw = _LSTMSingleLayer.from_params( + wi, wh, bi, bh, split_gates=split_gates + ) + + if other.bidirectional: + wi = getattr(other, f"weight_ih_l{layer_idx}_reverse") + wh = getattr(other, f"weight_hh_l{layer_idx}_reverse") + bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None) + bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None) + layer.layer_bw = _LSTMSingleLayer.from_params( + wi, wh, bi, bh, split_gates=split_gates + ) + return layer + + +class LSTM(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples below. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> rnn = nnqa.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + >>> # To get the weights: + >>> # xdoctest: +SKIP + >>> print(rnn.layers[0].weight_ih) + tensor([[...]]) + >>> print(rnn.layers[0].weight_hh) + AssertionError: There is no reverse path in the non-bidirectional layer + """ + + _FLOAT_MODULE = torch.nn.LSTM + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + *, + split_gates: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.training = False # Default to eval mode. If we want to train, we will explicitly set to training. + + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0: + warnings.warn( + "dropout option for quantizable LSTM is ignored. " + "If you are training, please, use nn.LSTM version " + "followed by `prepare` step." 
+ ) + if num_layers == 1: + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} " + f"and num_layers={num_layers}" + ) + + layers = [ + _LSTMLayer( + self.input_size, + self.hidden_size, + self.bias, + batch_first=False, + bidirectional=self.bidirectional, + split_gates=split_gates, + **factory_kwargs, + ) + ] + layers.extend( + _LSTMLayer( + self.hidden_size, + self.hidden_size, + self.bias, + batch_first=False, + bidirectional=self.bidirectional, + split_gates=split_gates, + **factory_kwargs, + ) + for _ in range(1, num_layers) + ) + self.layers = torch.nn.ModuleList(layers) + + def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + + max_batch_size = x.size(1) + num_directions = 2 if self.bidirectional else 1 + if hidden is None: + zeros = torch.zeros( + num_directions, + max_batch_size, + self.hidden_size, + dtype=torch.float, + device=x.device, + ) + zeros.squeeze_(0) + if x.is_quantized: + zeros = torch.quantize_per_tensor( + zeros, scale=1.0, zero_point=0, dtype=x.dtype + ) + hxcx = [(zeros, zeros) for _ in range(self.num_layers)] + else: + hidden_non_opt = torch.jit._unwrap_optional(hidden) + if isinstance(hidden_non_opt[0], Tensor): + hx = hidden_non_opt[0].reshape( + self.num_layers, num_directions, max_batch_size, self.hidden_size + ) + cx = hidden_non_opt[1].reshape( + self.num_layers, num_directions, max_batch_size, self.hidden_size + ) + hxcx = [ + (hx[idx].squeeze(0), cx[idx].squeeze(0)) + for idx in range(self.num_layers) + ] + else: + hxcx = hidden_non_opt + + hx_list = [] + cx_list = [] + for idx, layer in enumerate(self.layers): + x, (h, c) = layer(x, hxcx[idx]) + hx_list.append(torch.jit._unwrap_optional(h)) + cx_list.append(torch.jit._unwrap_optional(c)) + hx_tensor = torch.stack(hx_list) + cx_tensor = torch.stack(cx_list) + + # We are creating another dimension for bidirectional case + # need to collapse it + hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1]) + cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1]) + + if self.batch_first: + x = x.transpose(0, 1) + + return x, (hx_tensor, cx_tensor) + + def _get_name(self): + return "QuantizableLSTM" + + @classmethod + def from_float(cls, other, qconfig=None, split_gates=False): + assert isinstance(other, cls._FLOAT_MODULE) + assert hasattr(other, "qconfig") or qconfig + observed = cls( + other.input_size, + other.hidden_size, + other.num_layers, + other.bias, + other.batch_first, + other.dropout, + other.bidirectional, + split_gates=split_gates, + ) + observed.qconfig = getattr(other, "qconfig", qconfig) + for idx in range(other.num_layers): + observed.layers[idx] = _LSTMLayer.from_float( + other, idx, qconfig, batch_first=False, split_gates=split_gates + ) + + # Prepare the model + if other.training: + observed.train() + observed = torch.ao.quantization.prepare_qat(observed, inplace=True) + else: + observed.eval() + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-quantizable LSTM module. Please, see " + "the examples on quantizable LSTMs." 
+ ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca0e51ae283b8e739be7de93aac5cf0ba4a699b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/__init__.py @@ -0,0 +1,39 @@ +from . import functional +from .modules import * # noqa: F403 +from .modules import MaxPool2d + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633b0a9198496f0393086a49774907da0b17e6b3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f5722c6665e782224ccfb56f8d77cc016b84924 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a199431cf4626d78ba40224137c19519ed185bef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e87787461a5eaa73759b33153f27af36ba027ce5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py @@ -0,0 +1,26 @@ +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .linear import Linear +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell + + +__all__ = [ + "Linear", + "LSTM", + "GRU", + "LSTMCell", + "RNNCell", + "GRUCell", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02cb6d4997c083417b4281929bcd074f085b43a5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ccd505534ff2aa84ab549540fd24ab75b496b41 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31a2ba142c624dfaa1bfa5a580bcc002d3e1d589 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b42a7ad83b2a3fca1c6a2bdeacd8d092ffdd547 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9c720f1ed3755bb2f25722e94c27ab4d139c31fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -0,0 +1,523 @@ +# mypy: allow-untyped-defs +r"""Dynamically quantized convolution modules.""" + +import warnings +from typing import ClassVar, Optional + +import torch +import torch.ao.nn.quantized as nnq +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch._ops import ops +from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _pair, _single, _triple + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class Conv1d(nnq.Conv1d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. 
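+ The packed weight is quantized ahead of time, while the input activations are handled dynamically inside the operator on every call, so inputs and outputs remain regular float tensors.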
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + reduce_range=True, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv1d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range) + + +class Conv2d(nnq.Conv2d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module " + "has poor numerical accuracy and its use is not recommended" + ) + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv2d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv2d_dynamic(input, self._packed_params, reduce_range) + + +class Conv3d(nnq.Conv3d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv3d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv3d_dynamic(input, self._packed_params, reduce_range) + + +class ConvTranspose1d(nnq.ConvTranspose1d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose1d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose1d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d_dynamic( + input, self._packed_params, reduce_range + ) + + +class ConvTranspose2d(nnq.ConvTranspose2d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose2d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d_dynamic( + input, self._packed_params, reduce_range + ) + + +class ConvTranspose3d(nnq.ConvTranspose3d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose3d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d_dynamic( + input, self._packed_params, reduce_range + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..a7981f315390882d5175c1e8152f3b8717553ce8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +from torch.ao.nn.quantized.modules.utils import _quantize_weight + + +__all__ = [ + "Linear", +] + + +class Linear(nnq.Linear): + r""" + A dynamic quantized linear module with floating point tensor as inputs and outputs. + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module which are of + shape :math:`(\text{out\_features}, \text{in\_features})`. + bias (Tensor): the non-learnable floating point bias of the module of shape + :math:`(\text{out\_features})`. If :attr:`bias` is ``True``, + the values are initialized to zero. 
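+ The weight is quantized to ``qint8`` (or stored as ``float16``) and pre-packed ahead of time; activations stay in floating point and any activation quantization happens dynamically inside the operator.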
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.quantized.dynamic.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + # version used in this class is different from the parent class nnq.Linear + _version = 4 + + def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias_, dtype=dtype) + # We don't muck around with buffers or attributes or anything here + # to keep the module simple. *everything* is simply a Python attribute. + # Serialization logic is explicitly handled in the below serialization and + # deserialization modules + self.version = 4 + + def forward(self, x): + # Note that we can handle self.bias == None case. + if self._packed_params.dtype == torch.qint8: + if self.version is None or self.version < 4: + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params._packed_params + ) + else: + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params._packed_params, reduce_range=True + ) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_dynamic_fp16( + x, self._packed_params._packed_params + ) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + return Y.to(x.dtype) + + def _get_name(self): + return "DynamicQuantizedLinear" + + def extra_repr(self): + extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}" + if self._packed_params.dtype == torch.qint8: + extra_repr_str += f", qscheme={self.weight().qscheme()}" + return extra_repr_str + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + self.version = version + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a dynamic quantized module from a float module or qparams_dict + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + float_modules = [ + torch.nn.Linear, + torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + torch.ao.nn.intrinsic.modules.fused.LinearReLU, + torch.ao.nn.qat.dynamic.Linear, + ] + + assert type(mod) in float_modules, ( + "nn.quantized.dynamic.Linear.from_float only works for one of" + + str([float_mod.__name__ for float_mod in float_modules]) + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + if type(mod) == nni.LinearReLU: + mod = mod[0] + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. 
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer = default_dynamic_qconfig.weight() + dtype = weight_observer.dtype + assert dtype in [torch.qint8, torch.float16], ( + "The only supported dtypes for " + f"dynamic quantized linear are qint8 and float16 got: {dtype}" + ) + weight_observer(mod.weight) + if dtype == torch.qint8: + qweight = _quantize_weight(mod.weight.float(), weight_observer) + elif dtype == torch.float16: + qweight = mod.weight.float() + else: + raise RuntimeError( + "Unsupported dtype specified for dynamic quantized Linear!" + ) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear.set_weight_bias(qweight, mod.bias) + return qlinear + + @classmethod + def from_reference(cls, ref_qlinear): # type: ignore[override] + """Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized + module + Args: + ref_qlinear (Module): a reference quantized module, either produced by + torch.ao.quantization functions or provided by the user + """ + qlinear = cls( + ref_qlinear.in_features, + ref_qlinear.out_features, + dtype=ref_qlinear.weight_dtype, + ) + qweight = ref_qlinear.get_quantized_weight() + bias = ref_qlinear.bias + qlinear.set_weight_bias(qweight, bias) + return qlinear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7a45d836bf9cc5fc92a899f5acb5f33bf604bee2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -0,0 +1,1363 @@ +# mypy: allow-untyped-defs +import numbers +import warnings +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch import Tensor # noqa: F401 +from torch._jit_internal import Dict, List, Optional, Tuple, Union # noqa: F401 +from torch.ao.nn.quantized.modules.utils import _quantize_weight +from torch.nn.utils.rnn import PackedSequence + + +__all__ = [ + "pack_weight_bias", + "PackedParameter", + "RNNBase", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "apply_permutation", +] + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return _apply_permutation(tensor, permutation, dim) + + +def pack_weight_bias(qweight, bias, dtype): + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight = torch.ops.quantized.linear_prepack(qweight, bias) + + return packed_weight + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias) + + return packed_weight + + +class PackedParameter(torch.nn.Module): + def __init__(self, param): + super().__init__() + self.param = param + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "param"] = self.param + + def _load_from_state_dict( + self, + state_dict, + prefix, + 
local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.param = state_dict[prefix + "param"] + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class RNNBase(torch.nn.Module): + _FLOAT_MODULE = nn.RNNBase + + _version = 2 + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + dtype=torch.qint8, + ): + super().__init__() + + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.dtype = dtype + self.version = 2 + self.training = False + num_directions = 2 if bidirectional else 1 + + # "type: ignore" is required since ints and Numbers are not fully comparable + # https://github.com/python/mypy/issues/8566 + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 # type: ignore[operator] + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0 and num_layers == 1: # type: ignore[operator] + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}" + ) + + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + _all_weight_values = [] + for layer in range(num_layers): + for _ in range(num_directions): + layer_input_size = ( + input_size if layer == 0 else hidden_size * num_directions + ) + + w_ih = torch.randn(gate_size, layer_input_size).to(torch.float) + w_hh = torch.randn(gate_size, hidden_size).to(torch.float) + b_ih = torch.randn(gate_size).to(torch.float) + b_hh = torch.randn(gate_size).to(torch.float) + if dtype == torch.qint8: + w_ih = torch.quantize_per_tensor( + w_ih, scale=0.1, zero_point=0, dtype=torch.qint8 + ) + w_hh = torch.quantize_per_tensor( + w_hh, scale=0.1, zero_point=0, dtype=torch.qint8 + ) + packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh) + if self.version is None or self.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh, True + ) + ) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh) + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + + _all_weight_values.append(PackedParameter(cell_params)) + self._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + def _get_name(self): + return "DynamicQuantizedRNN" + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if self.num_layers != 1: + s += ", num_layers={num_layers}" + if self.bias is not True: + s += ", bias={bias}" + if self.batch_first is not False: + s += ", batch_first={batch_first}" + if self.dropout != 0: + s += ", dropout={dropout}" + if self.bidirectional is not False: + s += ", 
bidirectional={bidirectional}" + return s.format(**self.__dict__) + + def __repr__(self): + # We don't want to show `ModuleList` children, hence custom + # `__repr__`. This is the same as nn.Module.__repr__, except the check + # for the `PackedParameter` and `nn.ModuleList`. + # You should still override `extra_repr` to add more info. + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + if isinstance(module, (PackedParameter, nn.ModuleList)): + continue + mod_str = repr(module) + mod_str = nn.modules.module._addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) + if self.input_size != input.size(-1): + raise RuntimeError( + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) + + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + self.check_hidden_size( + hidden, expected_hidden_size, msg="Expected hidden size {}, got {}" + ) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + self.version = version + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def set_weight_bias(self, weight_bias_dict): + def weight_bias_name(ihhh, layer, suffix): + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" + return weight_name, bias_name + + num_directions = 2 if self.bidirectional else 1 + # TODO: dedup with __init__ of RNNBase + _all_weight_values = [] + for layer in range(self.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + w_ih_name, 
b_ih_name = weight_bias_name("ih", layer, suffix) + w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix) + w_ih = weight_bias_dict[w_ih_name] + b_ih = weight_bias_dict[b_ih_name] + w_hh = weight_bias_dict[w_hh_name] + b_hh = weight_bias_dict[b_hh_name] + if w_ih.dtype == torch.qint8: + packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh) + if self.version is None or self.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh, True + ) + ) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh) + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + + _all_weight_values.append(PackedParameter(cell_params)) + self._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) in { + torch.nn.LSTM, + torch.nn.GRU, + }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU" + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer_method = mod.qconfig.weight + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer_method = default_dynamic_qconfig.weight + + dtype = weight_observer_method().dtype + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError( + f"Unsupported dtype for dynamic RNN quantization: {dtype}" + ) + # RNNBase can be either LSTM or GRU + qRNNBase: Union[LSTM, GRU] + if mod.mode == "LSTM": + qRNNBase = LSTM( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + dtype, + ) + elif mod.mode == "GRU": + qRNNBase = GRU( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + dtype, + ) + else: + raise NotImplementedError( + "Only LSTM/GRU is supported for QuantizedRNN for now" + ) + + num_directions = 2 if mod.bidirectional else 1 + + assert mod.bias + + _all_weight_values = [] + for layer in range(qRNNBase.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + + def retrieve_weight_bias(ihhh): + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" + weight = getattr(mod, weight_name) + bias = getattr(mod, bias_name) + return weight, bias + + weight_ih, bias_ih = retrieve_weight_bias("ih") + weight_hh, bias_hh = retrieve_weight_bias("hh") + + if dtype == torch.qint8: + + def quantize_and_pack(w, b): + weight_observer = weight_observer_method() + weight_observer(w) + qweight = _quantize_weight(w.float(), weight_observer) + packed_weight = torch.ops.quantized.linear_prepack(qweight, b) + return packed_weight + + packed_ih = quantize_and_pack(weight_ih, bias_ih) + packed_hh = quantize_and_pack(weight_hh, bias_hh) + if qRNNBase.version is 
None or qRNNBase.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, bias_ih, bias_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, bias_ih, bias_hh, True + ) + ) + + elif dtype == torch.float16: + packed_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih.float(), bias_ih + ) + packed_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh.float(), bias_hh + ) + + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + else: + raise RuntimeError( + "Unsupported dtype specified for dynamic quantized LSTM!" + ) + + _all_weight_values.append(PackedParameter(cell_params)) + qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + return qRNNBase + + def _weight_bias(self): + # Returns a dict of weights and biases + weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}} + count = 0 + num_directions = 2 if self.bidirectional else 1 + for layer in range(self.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + key_name1 = f"weight_ih_l{layer}{suffix}" + key_name2 = f"weight_hh_l{layer}{suffix}" + # packed weights are part of torchbind class, CellParamsSerializationType + # Within the packed weight class, the weight and bias are accessible as Tensors + packed_weight_bias = self._all_weight_values[ # type: ignore[index] + count + ].param.__getstate__()[0][4] + weight_bias_dict["weight"][key_name1] = packed_weight_bias[ + 0 + ].__getstate__()[0][0] + weight_bias_dict["weight"][key_name2] = packed_weight_bias[ + 1 + ].__getstate__()[0][0] + key_name1 = f"bias_ih_l{layer}{suffix}" + key_name2 = f"bias_hh_l{layer}{suffix}" + weight_bias_dict["bias"][key_name1] = packed_weight_bias[ + 0 + ].__getstate__()[0][1] + weight_bias_dict["bias"][key_name2] = packed_weight_bias[ + 1 + ].__getstate__()[0][1] + count = count + 1 + return weight_bias_dict + + def get_weight(self): + return self._weight_bias()["weight"] + + def get_bias(self): + return self._weight_bias()["bias"] + + +class LSTM(RNNBase): + r""" + A dynamic quantized LSTM module with floating point tensor as inputs and outputs. + We adopt the same interface as `torch.nn.LSTM`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + _FLOAT_MODULE = nn.LSTM + + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + def _get_name(self): + return "DynamicQuantizedLSTM" + + def forward_impl( + self, + input: Tensor, + hx: Optional[tuple[Tensor, Tensor]], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (zeros, zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
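# NOTE: a minimal usage sketch of how this dynamic quantized LSTM is typically
# obtained, assuming the standard `torch.ao.quantization.quantize_dynamic`
# entry point (which swaps float `nn.LSTM` submodules for this class through
# `from_float` above); the small wrapper module is only an illustrative example:
#
#     import torch
#     from torch import nn
#
#     class WrapperWithLSTM(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lstm = nn.LSTM(10, 20, 2)
#         def forward(self, x, hidden=None):
#             return self.lstm(x, hidden)
#
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         WrapperWithLSTM(), {nn.LSTM}, dtype=torch.qint8
#     )
#     out, (hn, cn) = qmodel(torch.randn(5, 3, 10))  # float tensors in, float tensors out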
+ hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = [m.param for m in self._all_weight_values] + if batch_sizes is None: + result = torch.quantized_lstm( + input, + hx, + _all_params, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + self.batch_first, + dtype=self.dtype, + use_dynamic=True, + ) + else: + result = torch.quantized_lstm( + input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + dtype=self.dtype, + use_dynamic=True, + ) + output = result[0] + hidden = result[1:] + + return output, hidden + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices + ) + + output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + # "type: ignore" is required due to issue #43072 + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + # "type: ignore" is required due to issue #43072 + def check_forward_args( # type: ignore[override] + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size( + hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}" + ) + self.check_hidden_size( + hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}" + ) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 " + "exists in LSTM, may need to relax the assumption to support the use case" + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.num_layers, + ref_mod.bias, + ref_mod.batch_first, + ref_mod.dropout, + ref_mod.bidirectional, + # assuming there is layer 0, which should be OK + ref_mod.weight_ih_l0_dtype, + ) + qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict()) + return qmod + + +class GRU(RNNBase): + r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. 
+ + + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features + of the input sequence. The input can also be a packed variable length + sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` + for details. + - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the initial hidden state for each element in the batch. + Defaults to zero if not provided. If the RNN is bidirectional, + num_directions should be 2, else it should be 1. + + Outputs: output, h_n + - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor + containing the output features h_t from the last layer of the GRU, + for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been + given as the input, the output will also be a packed sequence. + For the unpacked case, the directions can be separated + using ``output.view(seq_len, batch, num_directions, hidden_size)``, + with forward and backward being direction `0` and `1` respectively. + + Similarly, the directions can be separated in the packed case. + - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the hidden state for `t = seq_len` + + Like *output*, the layers can be separated using + ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + + Shape: + - Input1: :math:`(L, N, H_{in})` tensor containing input features where + :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. 
+ - Input2: :math:`(S, N, H_{out})` tensor + containing the initial hidden state for each element in the batch. + :math:`H_{out}=\text{hidden\_size}` + Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` + If the RNN is bidirectional, num_directions should be 2, else it should be 1. + - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` + - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state + for each element in the batch + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + _FLOAT_MODULE = nn.GRU + + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} + + def __init__(self, *args, **kwargs): + super().__init__("GRU", *args, **kwargs) + + def _get_name(self): + return "DynamicQuantizedGRU" + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size( + hidden, expected_hidden_size, "Expected hidden size {}, got {}" + ) + + def forward_impl( + self, + input: Tensor, + hx: Optional[Tensor], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = zeros + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
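# Shape sketch (illustrative only, following the docstring conventions above):
#
#     import torch
#     from torch import nn
#
#     gru = nn.GRU(10, 20, 2, bidirectional=True)   # float module, same interface
#     out, h_n = gru(torch.randn(5, 3, 10))
#     # out.shape == torch.Size([5, 3, 40])   -> (L, N, num_directions * hidden_size)
#     # h_n.shape == torch.Size([4, 3, 20])   -> (num_layers * num_directions, N, hidden_size)
#
# The dynamically quantized GRU produced from such a float module keeps the same
# input and output shapes.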
+ hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = [m.param for m in self._all_weight_values] + if batch_sizes is None: + result = torch.quantized_gru( + input, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = torch.quantized_gru( + input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + return output, hidden + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> tuple[PackedSequence, Tensor]: + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices + ) + + output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 " + "exists in LSTM, may need to relax the assumption to support the use case" + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.num_layers, + ref_mod.bias, + ref_mod.batch_first, + ref_mod.dropout, + ref_mod.bidirectional, + # assuming there is layer 0, which should be OK + ref_mod.weight_ih_l0_dtype, + ) + qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict()) + return qmod + + +class RNNCellBase(torch.nn.Module): + # _FLOAT_MODULE = nn.CellRNNBase + __constants__ = ["input_size", "hidden_size", "bias"] + + def __init__( + self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8 + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_dtype = dtype + if bias: + self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float) + self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float) + weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float) + if dtype == torch.qint8: + weight_ih = torch.quantize_per_tensor( + weight_ih, scale=1, zero_point=0, dtype=torch.qint8 + ) + weight_hh = torch.quantize_per_tensor( + 
weight_hh, scale=1, zero_point=0, dtype=torch.qint8 + ) + + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight_ih = torch.ops.quantized.linear_prepack( + weight_ih, self.bias_ih + ) + packed_weight_hh = torch.ops.quantized.linear_prepack( + weight_hh, self.bias_hh + ) + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih, self.bias_ih + ) + packed_weight_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh, self.bias_hh + ) + + self._packed_weight_ih = packed_weight_ih + self._packed_weight_hh = packed_weight_hh + + def _get_name(self): + return "DynamicQuantizedRNNBase" + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + def check_forward_input(self, input): + if input.size(1) != self.input_size: + raise RuntimeError( + f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}" + ) + + def check_forward_hidden( + self, input: Tensor, hx: Tensor, hidden_label: str = "" + ) -> None: + if input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) in { + torch.nn.LSTMCell, + torch.nn.GRUCell, + torch.nn.RNNCell, + }, ( + "nn.quantized.dynamic.RNNCellBase.from_float \ + only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell" + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer_method = mod.qconfig.weight + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. 
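# A minimal conversion sketch for the cell variants (the container below is only
# an illustrative example; `torch.ao.quantization.quantize_dynamic` reaches this
# `from_float` for cells just as it does for the full RNN modules):
#
#     import torch
#     from torch import nn
#
#     model = nn.Sequential()
#     model.add_module("cell", nn.LSTMCell(10, 20))
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         model, {nn.LSTMCell}, dtype=torch.qint8
#     )
#     h, c = qmodel.cell(torch.randn(3, 10))  # same call signature as nn.LSTMCell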
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer_method = default_dynamic_qconfig.weight + + dtype = weight_observer_method().dtype + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError( + f"Unsupported dtype for dynamic RNN quantization: {dtype}" + ) + + qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell] + + if type(mod) == torch.nn.LSTMCell: + qRNNCellBase = LSTMCell( + mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype + ) + elif type(mod) == torch.nn.GRUCell: + qRNNCellBase = GRUCell( + mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype + ) + elif type(mod) == torch.nn.RNNCell: + qRNNCellBase = RNNCell( + mod.input_size, + mod.hidden_size, + bias=mod.bias, + nonlinearity=mod.nonlinearity, + dtype=dtype, + ) + else: + raise NotImplementedError( + "Only LSTMCell, GRUCell and RNNCell \ + are supported for QuantizedRNN for now" + ) + + assert mod.bias + + def _observe_and_quantize_weight(weight): + if dtype == torch.qint8: + weight_observer = weight_observer_method() + weight_observer(weight) + qweight = _quantize_weight(weight.float(), weight_observer) + return qweight + else: + return weight.float() + + qRNNCellBase._packed_weight_ih = pack_weight_bias( + _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype + ) + qRNNCellBase._packed_weight_hh = pack_weight_bias( + _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype + ) + return qRNNCellBase + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_dtype"), "We are assuming weight_ih " + "exists in reference module, may need to relax the assumption to support the use case" + if hasattr(ref_mod, "nonlinearity"): + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.bias, + ref_mod.nonlinearity, + dtype=ref_mod.weight_ih_dtype, + ) + else: + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.bias, + dtype=ref_mod.weight_ih_dtype, + ) + weight_bias_dict = { + "weight": { + "weight_ih": ref_mod.get_quantized_weight_ih(), + "weight_hh": ref_mod.get_quantized_weight_hh(), + }, + "bias": { + "bias_ih": ref_mod.bias_ih, + "bias_hh": ref_mod.bias_hh, + }, + } + qmod.set_weight_bias(weight_bias_dict) + return qmod + + def _weight_bias(self): + # Returns a dict of weights and biases + weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}} + w1, b1 = self._packed_weight_ih.__getstate__()[0] + w2, b2 = self._packed_weight_hh.__getstate__()[0] + # TODO: these can be simplified to one level? e.g. using weight_ih as key + # directly + weight_bias_dict["weight"]["weight_ih"] = w1 + weight_bias_dict["weight"]["weight_hh"] = w2 + weight_bias_dict["bias"]["bias_ih"] = b1 + weight_bias_dict["bias"]["bias_hh"] = b2 + return weight_bias_dict + + def get_weight(self): + return self._weight_bias()["weight"] + + def get_bias(self): + return self._weight_bias()["bias"] + + def set_weight_bias(self, weight_bias_dict): + # TODO: these can be simplified to one level? e.g. 
using weight_ih as key + # directly + self._packed_weight_ih = pack_weight_bias( + weight_bias_dict["weight"]["weight_ih"], + weight_bias_dict["bias"]["bias_ih"], + self.weight_dtype, + ) + self._packed_weight_hh = pack_weight_bias( + weight_bias_dict["weight"]["weight_hh"], + weight_bias_dict["bias"]["bias_hh"], + self.weight_dtype, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih + destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih") + self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh") + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + A dynamic quantized RNNCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] + + def __init__( + self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8 + ): + super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "DynamicQuantizedRNNCell" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + if self.nonlinearity == "tanh": + ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + return ret + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.LSTMCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... 
hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, num_chunks=4, **kwargs) # type: ignore[misc] + + def _get_name(self): + return "DynamicQuantizedLSTMCell" + + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + self.check_forward_input(input) + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + self.check_forward_hidden(input, hx[0], "[0]") + self.check_forward_hidden(input, hx[1], "[1]") + return torch.ops.quantized.quantized_lstm_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell + + A dynamic quantized GRUCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8): + super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype) + + def _get_name(self): + return "DynamicQuantizedGRUCell" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + return torch.ops.quantized.quantized_gru_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/functional.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d99bf0ff245faf0dd67bee9d2b8ea858ae78fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/functional.py @@ -0,0 +1,779 @@ +# mypy: allow-untyped-defs +r"""Functional interface (quantized).""" + +import warnings +from typing import Optional + +import torch +from torch import Tensor +from torch.jit.annotations import BroadcastingList2 +from torch.nn.modules.utils import _pair, _triple + +from .modules.utils import _pair_from_first + + +# Although some of the functions and docstrings are mirrored from the torch.nn, +# we want to have them here for future changes. 
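# A small end-to-end sketch (illustrative; the scale/zero_point values below are
# arbitrary): the functions in this module operate on already-quantized tensors,
# and the pooling/resizing ops propagate the input's quantization parameters to
# the output.
#
#     import torch
#     from torch.ao.nn.quantized import functional as qF
#
#     x = torch.randn(1, 3, 8, 8)
#     qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
#     qF.avg_pool2d(qx, kernel_size=2)                      # output stays quantized
#     qF.adaptive_avg_pool2d(qx, (4, 4))
#     qF.interpolate(qx, scale_factor=2.0, mode="nearest")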
+ +__all__ = [ + "avg_pool2d", + "avg_pool3d", + "adaptive_avg_pool2d", + "adaptive_avg_pool3d", + "conv1d", + "conv2d", + "conv3d", + "interpolate", + "linear", + "max_pool1d", + "max_pool2d", + "celu", + "leaky_relu", + "hardtanh", + "hardswish", + "threshold", + "elu", + "hardsigmoid", + "clamp", + "upsample", + "upsample_bilinear", + "upsample_nearest", +] + + +def avg_pool2d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + r""" + Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size + :math:`sH \times sW` steps. The number of output features is equal to the number of + input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!") + return torch.nn.functional.avg_pool2d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + +def avg_pool3d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + r""" + Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size + :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of + input planes. + + .. note:: The input quantization parameters propagate to the output. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kD, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sD, sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padD, padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used.
Default: None + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!") + return torch.nn.functional.avg_pool3d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r""" + Applies a 2D adaptive average pooling over a quantized input signal composed + of several quantized input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if not input.is_quantized: + raise ValueError( + "Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!" + ) + return torch.nn.functional.adaptive_avg_pool2d(input, output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r""" + Applies a 3D adaptive average pooling over a quantized input signal composed + of several quantized input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if not input.is_quantized: + raise ValueError( + "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!" + ) + return torch.nn.functional.adaptive_avg_pool3d(input, output_size) + + +def conv1d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 1D convolution over a quantized 1D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape. + + Args: + input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)` + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sW,)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dW,)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. 
Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(33, 16, 3, dtype=torch.float) + >>> inputs = torch.randn(20, 16, 50, dtype=torch.float) + >>> bias = torch.randn(33, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + stride = _pair_from_first(stride) + padding = _pair_from_first(padding) + dilation = _pair_from_first(dilation) + + packed_params = torch.ops.quantized.conv1d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point) + + +def conv2d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 2D convolution over a quantized 2D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape. + + Args: + input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. 
Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float) + >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float) + >>> bias = torch.randn(8, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + packed_params = torch.ops.quantized.conv2d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point) + + +def conv3d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 3D convolution over a quantized 3D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape. + + Args: + input: quantized input tensor of shape + :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)` + weight: quantized filters of shape + :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)` + bias: **non-quantized** bias tensor of shape + :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sD, sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padD, padH, padW)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dD, dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be + divisible by the number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for + quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. 
Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float) + >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float) + >>> bias = torch.randn(8, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + packed_params = torch.ops.quantized.conv3d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point) + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + r"""Down/up samples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D/3D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.interpolate' must be quantized!") + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + +def linear( + input: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + scale: Optional[float] = None, + zero_point: Optional[int] = None, +) -> Tensor: + r""" + Applies a linear transformation to the incoming quantized data: + :math:`y = xA^T + b`. 
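    For illustration (a sketch only; the scales and zero points below are chosen
    arbitrarily)::

        >>> # xdoctest: +SKIP
        >>> from torch.ao.nn.quantized import functional as qF
        >>> q_input = torch.quantize_per_tensor(torch.randn(2, 8), 0.1, 0, torch.quint8)
        >>> q_weight = torch.quantize_per_tensor(torch.randn(4, 8), 0.2, 0, torch.qint8)
        >>> bias = torch.randn(4, dtype=torch.float)
        >>> qF.linear(q_input, q_weight, bias, scale=0.3, zero_point=0).shape
        torch.Size([2, 4])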
+ See :class:`~torch.ao.nn.quantized.Linear` + + .. note:: + + Current implementation packs weights on every call, which has penalty on performance. + If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`. + + Args: + input (Tensor): Quantized input of type `torch.quint8` + weight (Tensor): Quantized weight of type `torch.qint8` + bias (Tensor): None or fp32 bias of type `torch.float` + scale (double): output scale. If None, derived from the input scale + zero_point (long): output zero point. If None, derived from the input zero_point + + Shape: + - Input: :math:`(N, *, in\_features)` where `*` means any number of + additional dimensions + - Weight: :math:`(out\_features, in\_features)` + - Bias: :math:`(out\_features)` + - Output: :math:`(N, *, out\_features)` + """ + if scale is None: + scale = input.q_scale() + if zero_point is None: + zero_point = input.q_zero_point() + _packed_params = torch.ops.quantized.linear_prepack(weight, bias) + return torch.ops.quantized.linear(input, _packed_params, scale, zero_point) + + +def max_pool1d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + r"""Applies a 1D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.ao.nn.quantized.MaxPool1d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.nn.functional.max_pool1d( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + + +def max_pool2d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + r"""Applies a 2D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.ao.nn.quantized.MaxPool2d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.nn.functional.max_pool2d( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + + +def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor: + r"""celu(input, scale, zero_point, alpha=1.) -> Tensor + + Applies the quantized CELU function element-wise. + + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1)) + + Args: + input: quantized input + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.celu' must be quantized!") + return torch.ops.quantized.celu(input, scale, zero_point, alpha) + + +def leaky_relu( + input: Tensor, + negative_slope: float = 0.01, + inplace: bool = False, + scale: Optional[float] = None, + zero_point: Optional[int] = None, +): + r""" + Quantized version of the. 
+ leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor + + Applies element-wise, + :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` + + Args: + input: Quantized input + negative_slope: The slope of the negative input + inplace: Inplace modification of the input tensor + scale, zero_point: Scale and zero point of the output tensor. + + See :class:`~torch.nn.LeakyReLU` for more details. + """ + if scale is not None and zero_point is not None: + assert not inplace, "Cannot rescale with `inplace`" + output = torch._empty_affine_quantized( + input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype + ) + torch._C._nn.leaky_relu(input, negative_slope, out=output) + return output + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result + + +def hardtanh( + input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False +) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`.""" + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardtanh' must be quantized!") + if inplace: + return torch._C._nn.hardtanh_(input, min_val, max_val) + return torch._C._nn.hardtanh(input, min_val, max_val) + + +def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`. + + Args: + input: quantized input + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardswish' must be quantized!") + return torch._ops.ops.quantized.hardswish(input, scale, zero_point) + + +def threshold(input: Tensor, threshold: float, value: float) -> Tensor: + r"""Applies the quantized version of the threshold function element-wise: + + .. math:: + x = \begin{cases} + x & \text{if~} x > \text{threshold} \\ + \text{value} & \text{otherwise} + \end{cases} + + See :class:`~torch.nn.Threshold` for more details. + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.threshold' must be quantized!") + if threshold is None: + raise ValueError("Input to 'threshold' must be specified!") + if value is None: + raise ValueError("Input to 'value' must be specified!") + return torch._ops.ops.quantized.threshold(input, threshold, value) + + +def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.elu`. + + Args: + input: quantized input + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.elu' must be quantized!") + return torch.ops.quantized.elu(input, scale, zero_point, alpha) + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`.""" + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!") + if inplace: + return torch._C._nn.hardsigmoid_(input) # type: ignore[attr-defined] + return torch._C._nn.hardsigmoid(input) + + +def clamp(input: Tensor, min_: float, max_: float) -> Tensor: + r"""clamp(input, min\_, max\_) -> Tensor + + Applies the clamp function element-wise.
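    For illustration (a sketch only; the quantization parameters are arbitrary)::

        >>> # xdoctest: +SKIP
        >>> qx = torch.quantize_per_tensor(torch.randn(2, 3), 0.1, 128, torch.quint8)
        >>> clamp(qx, -0.5, 0.5)  # min_ and max_ are given as floats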
+ See :class:`~torch.ao.nn.quantized.clamp` for more details. + + Args: + input: quantized input + min_: minimum value for clamping + max_: maximum value for clamping + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.clamp' must be quantized!") + return torch.clamp(input, min_, max_) + + +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + r"""Upsamples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(...)``. + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): quantized input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`bilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + """ + warnings.warn( + "nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode, align_corners) + + +def upsample_bilinear(input, size=None, scale_factor=None): + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with + ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + """ + # DeprecationWarning is ignored by default + warnings.warn( + "nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead." 
+ ) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +def upsample_nearest(input, size=None, scale_factor=None): + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(..., mode='nearest')``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + """ + # DeprecationWarning is ignored by default + warnings.warn( + "nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode="nearest") diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d870e8260d7f4aeefac680312e4ae5cdfbd12294 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__init__.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +import torch + +# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable` +# packages. However, the `quantizable` package uses "lazy imports" +# to avoid circular dependency. +# Hence we need to include it here to make sure it is resolved before +# they are used in the modules. +import torch.ao.nn.quantizable +from torch.nn.modules.pooling import MaxPool2d + +from .activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) +from .batchnorm import BatchNorm2d, BatchNorm3d +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .dropout import Dropout +from .embedding_ops import Embedding, EmbeddingBag +from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional +from .linear import Linear +from .normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) +from .rnn import LSTM + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] + + +class Quantize(torch.nn.Module): + r"""Quantizes an incoming tensor + + Args: + `scale`: scale of the output Quantized Tensor + `zero_point`: zero_point of output Quantized Tensor + `dtype`: data type of output Quantized Tensor + `factory_kwargs`: Dictionary of kwargs used for configuring initialization + of internal buffers. Currently, `device` and `dtype` are supported. + Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}` + will initialize internal buffers as type `torch.float64` on the current CUDA device. + Note that `dtype` only applies to floating-point buffers. 
+ + Examples:: + >>> t = torch.tensor([[1., -1.], [1., -1.]]) + >>> scale, zero_point, dtype = 1.0, 2, torch.qint8 + >>> qm = Quantize(scale, zero_point, dtype) + >>> # xdoctest: +SKIP + >>> qt = qm(t) + >>> print(qt) + tensor([[ 1., -1.], + [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2) + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__(self, scale, zero_point, dtype, factory_kwargs=None): + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__() + self.register_buffer("scale", torch.tensor([scale], **factory_kwargs)) + self.register_buffer( + "zero_point", + torch.tensor( + [zero_point], + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) + self.dtype = dtype + + def forward(self, X): + return torch.quantize_per_tensor( + X, float(self.scale), int(self.zero_point), self.dtype + ) + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + assert hasattr(mod, "activation_post_process") + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Quantize( + scale.float().item(), + zero_point.long().item(), + mod.activation_post_process.dtype, + ) + + def extra_repr(self): + return f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}" + + +class DeQuantize(torch.nn.Module): + r"""Dequantizes an incoming tensor + + Examples:: + >>> input = torch.tensor([[1., -1.], [1., -1.]]) + >>> scale, zero_point, dtype = 1.0, 2, torch.qint8 + >>> qm = Quantize(scale, zero_point, dtype) + >>> # xdoctest: +SKIP + >>> quantized_input = qm(input) + >>> dqm = DeQuantize() + >>> dequantized = dqm(quantized_input) + >>> print(dequantized) + tensor([[ 1., -1.], + [ 1., -1.]], dtype=torch.float32) + """ + + def forward(self, Xq): + return Xq.dequantize() + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + return DeQuantize() diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60f7a8c5a3e50f5ca45b38d06a604100888e1363 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8407da3891c786f4920a25c2004e6384735d2ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61679652a7238c1935ab601741d7d34d6dbcf106 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61235e23de0ce171fcdeab451d650b87b0f1f570 Binary files 
/dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc520846278fe316447b79666569f9bd0654ec86 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da8f06187787975e5993f1453c6007f5929c7e9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4c3a343c5d222d60de3564f4a90fe264520c31b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e7e1cfd978e92b1c5e18125f30eb2249a8148ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de351cc8775a5a5357b273cc5af096552fa99dfd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85f8f1c14b31c7cb91ccda5d99cad57a115d4f22 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab5c0c93ed7a43e879b9bc366a7bae1310b1cc35 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/activation.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e39951070e91b7571517ef11816c45f195d31ef5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/activation.py @@ 
-0,0 +1,344 @@ +# mypy: allow-untyped-defs +from warnings import warn + +import torch + + +__all__ = [ + "ReLU6", + "Hardswish", + "ELU", + "LeakyReLU", + "Sigmoid", + "Softmax", + "MultiheadAttention", + "PReLU", +] + + +class ReLU6(torch.nn.ReLU): + r"""Applies the element-wise function: + + :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the + zero_point, and :math:`q(6)` is the quantized representation of number 6. + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.quantized.ReLU6() + >>> input = torch.randn(2) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32) + >>> output = m(input) + """ + + def __init__(self, inplace=False): + super().__init__(inplace) + self.inplace = inplace + + def forward(self, input): + return torch.ops.quantized.relu6(input, self.inplace) + + def _get_name(self): + return "QuantizedReLU6" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + return ReLU6(mod.inplace) + + +class Hardswish(torch.nn.Hardswish): + r"""This is the quantized version of :class:`~torch.nn.Hardswish`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, scale, zero_point, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.hardswish(input, self.scale, self.zero_point) + + def _get_name(self): + return "QuantizedHardswish" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Hardswish(float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point)) + + +class ELU(torch.nn.ELU): + r"""This is the quantized equivalent of :class:`~torch.nn.ELU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + + def __init__(self, scale, zero_point, alpha=1.0): + super().__init__(alpha) + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + return torch.ao.nn.quantized.functional.elu( + input, self.scale, self.zero_point, self.alpha + ) + + def _get_name(self): + return "QuantizedELU" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return ELU(float(scale), int(zero_point), mod.alpha) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.alpha) + + +class LeakyReLU(torch.nn.LeakyReLU): + r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + negative_slope: Controls the angle of the negative slope. 
Default: 1e-2 + """ + + def __init__( + self, + scale: float, + zero_point: int, + negative_slope: float = 1e-2, + inplace: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(negative_slope, inplace) + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.leaky_relu( + input, self.negative_slope, self.inplace, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLeakyReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + +class Sigmoid(torch.nn.Sigmoid): + r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, output_scale: float, output_zero_point: int): + super().__init__() + self.output_scale = output_scale + self.output_zero_point = output_zero_point + + def forward(self, input): + return torch.ops.quantized.sigmoid( + input, self.output_scale, self.output_zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + ( + output_scale, + output_zero_point, + ) = mod.activation_post_process.calculate_qparams() + return cls(float(output_scale), int(output_zero_point)) + + +class Softmax(torch.nn.Softmax): + r"""This is the quantized version of :class:`~torch.nn.Softmax`. + + Args: + dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1). + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, dim=None, scale=1.0, zero_point=0): + super().__init__() + self.dim = dim + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + dim = self.dim + if dim is None: + stacklevel = 3 + # Note: adding the mypy ignore on _get_softmax_dim seems less bad + # than making `_get_softmax_dim` an official API. + dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined] + "softmax", input.dim(), stacklevel + ) + return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point) + + def _get_name(self): + return "QuantizedSoftmax" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Softmax(mod.dim, float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.dim, float(scale), int(zero_point)) + + +class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention): + _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention + + def _get_name(self): + return "QuantizedMultiheadAttention" + + @classmethod + def from_float(cls, other): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-observed MHA module. Please, see " + "the examples on quantizable MHAs." 
+ ) + + @classmethod + def from_observed(cls, other): + converted = torch.ao.quantization.convert( + other, + mapping=None, + inplace=False, + remove_qconfig=True, + convert_custom_config_dict=None, + ) + converted.__class__ = cls + # Remove the parameters for the bias_k and bias_v to quantize them + # TODO: This is a potential source of accuracy drop. + # quantized cat takes the scale and zp of the first + # element, which might lose the precision in the bias_k + # and the bias_v (which are cat'ed with k/v being first). + if converted.bias_k is not None: + bias_k = converted._parameters.pop("bias_k") + sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False) + bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8) + setattr(converted, "bias_k", bias_k) # noqa: B010 + + if converted.bias_v is not None: + bias_v = converted._parameters.pop("bias_v") + sc, zp = torch._choose_qparams_per_tensor( + bias_k, # type: ignore[possibly-undefined] + reduce_range=False, + ) + bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) + setattr(converted, "bias_v", bias_v) # noqa: B010 + + del converted.in_proj_weight + del converted.in_proj_bias + + return converted + + +class PReLU(torch.nn.Module): + r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + num_parameters: number of parameters: 1, or the number of channels at input. Default: 1 + """ + + def __init__( + self, output_scale: float, output_zero_point: int, num_parameters: int = 1 + ) -> None: + super().__init__() + self.num_parameters = num_parameters + self.scale = output_scale + self.zero_point = output_zero_point + w = torch.randn(num_parameters, dtype=torch.float) + qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8) + self.set_weight(qw) + + def set_weight(self, w: torch.Tensor) -> None: + self.weight = w + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.prelu( + input, self.weight, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedPReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}" + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8 + ) + qprelu.set_weight(qweight) + return qprelu + + @classmethod + def from_reference(cls, mod, scale, zero_point): + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}" + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8 + ) + qprelu.set_weight(qweight) + return qprelu diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/batchnorm.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/batchnorm.py new file mode 100644 index 
0000000000000000000000000000000000000000..695e16362fb7afe3f163f7d84ebca831f71451cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/batchnorm.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni + + +__all__ = ["BatchNorm2d", "BatchNorm3d"] + + +class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm): + def __init__( + self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, True, True, **factory_kwargs) + self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs)) + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + activation_post_process = mod.activation_post_process + if type(mod) == cls._NNI_BN_RELU_MODULE: + mod = mod[0] + scale, zero_point = activation_post_process.calculate_qparams() + new_mod = cls(mod.num_features, mod.eps) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + new_mod.running_mean = mod.running_mean + new_mod.running_var = mod.running_var + new_mod.scale = scale + new_mod.zero_point = zero_point + return new_mod + + @classmethod + def from_reference(cls, bn, output_scale, output_zero_point): + qbn = cls( + bn.num_features, + bn.eps, + bn.momentum, + device=bn.weight.device, + dtype=bn.weight.dtype, + ) + qbn.weight = bn.weight + qbn.bias = bn.bias + qbn.running_mean = bn.running_mean + qbn.running_var = bn.running_var + qbn.scale = output_scale + qbn.zero_point = output_zero_point + return qbn + + +class BatchNorm2d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.""" + + _NNI_BN_RELU_MODULE = nni.BNReLU2d + + def __init__( + self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return "QuantizedBatchNorm2d" + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm2d( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return _BatchNorm.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class BatchNorm3d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.""" + + _NNI_BN_RELU_MODULE = nni.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return "QuantizedBatchNorm3d" + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> 
torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm3d( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return _BatchNorm.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/conv.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..787a54a2d8b0de5329494c9696183c8cf3c9b099 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/conv.py @@ -0,0 +1,1243 @@ +# mypy: allow-untyped-defs +r"""Quantized convolution modules.""" + +from typing import ClassVar, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.nn as nn +import torch.nn.functional as F +from torch._ops import ops +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _pair, _single, _triple +from torch.nn.utils import fuse_conv_bn_weights + +from .utils import _quantize_weight, WeightedQuantizedModule + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + +_SUPPORTED_PADDING = {"zeros", "reflect"} + + +def _reverse_repeat_padding(padding: list[int]) -> list[int]: + _reversed_padding_repeated_twice: list[int] = [] + N = len(padding) + for idx in range(N): + _reversed_padding_repeated_twice.extend(padding[N - idx - 1] for _ in range(2)) + return _reversed_padding_repeated_twice + + +class _ConvNd(WeightedQuantizedModule): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode="zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError( + f"'padding_mode' {padding_mode} is not supported by quantized convolution" + ) + self.padding_mode = padding_mode + # Initialize as NCHW. set_weight will internally transpose to NHWC. 
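+        # For illustration: a non-transposed Conv2d(in_channels=16, out_channels=33, kernel_size=(3, 3), groups=1)
+        # starts from a qint8 placeholder weight of shape [33, 16, 3, 3], i.e.
+        # [out_channels, in_channels // groups] + list(kernel_size); set_weight_bias() then packs it.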
+ if self.transposed: + weight_shape = [in_channels, out_channels // self.groups] + else: + weight_shape = [out_channels, in_channels // self.groups] + qweight = torch._empty_affine_quantized( + weight_shape + list(kernel_size), + scale=1, + zero_point=0, + dtype=torch.qint8, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + bias_float = ( + torch.zeros( + out_channels, + dtype=torch.float, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + if bias + else None + ) + + self.set_weight_bias(qweight, bias_float) + self.scale = 1.0 + self.zero_point = 0 + + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + + def extra_repr(self): + s = ( + "{in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}, scale={scale}, zero_point={zero_point}" + ) + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias() is None: + s += ", bias=False" + return s.format(**self.__dict__) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into + # their regular QTensor form for serialization. Packed weights should not + # live outside the process in which they were created, rather they should be + # derived from the QTensor weight. + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed + # self + # |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + (w, b) = self._weight_bias() + destination[prefix + "weight"] = w + destination[prefix + "bias"] = b + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + @torch.jit.export + def __getstate__(self): + (w, b) = self._weight_bias() + return ( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.transposed, + self.output_padding, + self.groups, + self.padding_mode, + w, + b, + self.scale, + self.zero_point, + self.training, + ) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized + # QTensor weight into its packed format for use by the FBGEMM ops. 
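+    # For illustration: a quantized conv registered as "conv" therefore serializes
+    # "conv.weight" (unpacked quantized tensor), "conv.bias", "conv.scale" and "conv.zero_point";
+    # _load_from_state_dict below pops those keys and re-packs the weight via set_weight_bias().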
+ def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.set_weight_bias(state_dict[prefix + "weight"], state_dict[prefix + "bias"]) + state_dict.pop(prefix + "weight") + state_dict.pop(prefix + "bias") + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def __setstate__(self, state): + self.in_channels = state[0] + self.out_channels = state[1] + self.kernel_size = state[2] + self.stride = state[3] + self.padding = state[4] + self.dilation = state[5] + self.transposed = state[6] + self.output_padding = state[7] + self.groups = state[8] + self.padding_mode = state[9] + self.set_weight_bias(state[10], state[11]) + self.scale = state[12] + self.zero_point = state[13] + self.training = state[14] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + torch.nn.Module.__init__(new_instance) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def __copy__(self): + return self.__deepcopy__({}) + + @classmethod + def get_qconv(cls, mod, activation_post_process, weight_post_process=None): + r"""Creates a qconv object and returns it.""" + if weight_post_process is None: + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + mod.bias is not None, + mod.padding_mode, + ) + qconv.set_weight_bias(qweight, mod.bias) + if ( + activation_post_process is None + or activation_post_process.dtype == torch.float + ): + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = activation_post_process.calculate_qparams() + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + if hasattr(mod, "weight_fake_quant"): + # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ + # ".from_float only works for " + cls.__QAT_MODULE.__name__ + if type(mod) == cls._NNIQAT_CONV_BN_MODULE: + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + assert hasattr(mod, "activation_post_process"), ( + "Input QAT module must have observer attached" + ) + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + assert type(mod) == cls._FLOAT_MODULE, ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + + " but got:" + + str(type(mod)) + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined." 
+ ) + activation_post_process = ( + None + if not hasattr(mod, "activation_post_process") + else mod.activation_post_process + ) + if type(mod) in [ + cls._NNI_CONV_RELU_MODULE, + cls._NNI_CONV_ADD_MODULE, + cls._NNI_CONV_ADD_RELU_MODULE, + ]: + mod = mod[0] + weight_post_process = mod.qconfig.weight() + return cls.get_qconv(mod, activation_post_process, weight_post_process) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconv.in_channels, + ref_qconv.out_channels, + ref_qconv.kernel_size, # type: ignore[arg-type] + ref_qconv.stride, # type: ignore[arg-type] + ref_qconv.padding, # type: ignore[arg-type] + ref_qconv.dilation, # type: ignore[arg-type] + ref_qconv.groups, + ref_qconv.bias is not None, # type: ignore[arg-type] + ref_qconv.padding_mode, + device=ref_qconv.weight.device, + dtype=ref_qconv.weight.dtype, + ) + qweight = ref_qconv.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconv.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class Conv1d(_ConvNd): + r"""Applies a 1D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, + ... dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + # Subclasses of _ConvNd needs to call _init rather than __init__. 
See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv1d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) + return w, b + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv1d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv2d(_ConvNd): + r"""Applies a 2D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE: ClassVar[type[nni.ConvAdd2d]] = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE: ClassVar[type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv2d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups + ) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv2d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv3d(_ConvNd): + r"""Applies a 3D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. 
note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv3d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups + ) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv3d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. 
+ + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +# === Transposed Convolutions === + + +class _ConvTransposeNd(_ConvNd): + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ): + if padding_mode != "zeros": + raise ValueError( + f'Only "zeros" padding mode is supported for {self.__class__.__name__}' + ) + factory_kwargs = {"device": device, "dtype": dtype} + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _input_padding( + self, kernel_size: list[int], dilation: list[int], padding: list[int] + ) -> list[int]: + res = torch.jit.annotate(list[int], []) + for kdx in range(len(kernel_size)): + pad = dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx] + res.append(pad) + return res + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + # derived classes override cls._FLOAT_MODULE attribute + msg = ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined] + ) + assert type(mod) == cls._FLOAT_MODULE, msg + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined." 
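+        # Note: the constructor call below uses the transposed-conv argument order,
+        # with output_padding before groups and dilation after bias, mirroring
+        # torch.nn.ConvTranspose{1,2,3}d rather than the plain Conv order in get_qconv above.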
+ weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, # type: ignore[call-arg] + mod.stride, + mod.padding, + mod.output_padding, + mod.groups, + mod.bias is not None, + mod.dilation, + mod.padding_mode, + ) + qconv.set_weight_bias(qweight, mod.bias) + if ( + not hasattr(mod, "activation_post_process") + or mod.activation_post_process.dtype == torch.float + ): + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = mod.activation_post_process.calculate_qparams() # type: ignore[operator, union-attr] + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconvt.in_channels, + ref_qconvt.out_channels, + ref_qconvt.kernel_size, # type: ignore[arg-type] + ref_qconvt.stride, # type: ignore[arg-type] + ref_qconvt.padding, # type: ignore[arg-type] + ref_qconvt.output_padding, # type: ignore[arg-type] + ref_qconvt.groups, + ref_qconvt.bias is not None, # type: ignore[arg-type] + ref_qconvt.dilation, # type: ignore[arg-type] + ref_qconvt.padding_mode, + device=ref_qconvt.weight.device, + dtype=ref_qconvt.weight.dtype, + ) + qweight = ref_qconvt.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconvt.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class ConvTranspose1d(_ConvTransposeNd): + r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + .. note:: Currently only the QNNPACK engine is implemented. + Please, set the `torch.backends.quantized.engine = 'qnnpack'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'qnnpack' + >>> from torch.ao.nn import quantized as nnq + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose1d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) + + +class ConvTranspose2d(_ConvTransposeNd): + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # QNNPACK or FBGEMM as backend + >>> torch.backends.quantized.engine = 'qnnpack' + >>> # With square kernels and equal stride + >>> import torch.ao.nn.quantized as nnq + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose2d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv2d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) + + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. 
+ scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'fbgemm' + >>> from torch.ao.nn import quantized as nnq + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose3d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/dropout.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..18de8eb2e829bf12c436fa432bd949c0d1af2275 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/dropout.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = ["Dropout"] 
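+# Illustrative behaviour of the pass-through module defined below (hypothetical values):
+#   m = Dropout(p=0.5)
+#   q = torch.quantize_per_tensor(torch.randn(4), scale=1.0, zero_point=0, dtype=torch.quint8)
+#   m(q)  # returns q unchanged, in train and eval mode alike, since forward() is an identity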
+ + +class Dropout(torch.nn.Dropout): + r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`. + And this is a placeholder to enable models where fp32 tensors + had dropout to work with quantized tensors in train and eval mode. + + Args: + p: probability of an element to be zeroed + inplace: can optionally do the operation in-place. Default: ``False`` + """ + + def forward(self, input): + return input + + def _get_name(self): + return "QuantizedDropout" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return cls(mod.p, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.p, mod.inplace) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a5cc346e38d563e01c957046efc0e45a632907 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py @@ -0,0 +1,413 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch import Tensor # noqa: F401 +from torch._jit_internal import List, Optional # noqa: F401 + +from .utils import _hide_packed_params_repr, _quantize_weight + + +__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"] + + +class EmbeddingPackedParams(torch.nn.Module): + _version = 1 + + def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8): + super().__init__() + self.dtype = dtype + if self.dtype in [torch.quint8, torch.quint4x2]: + scales = torch.ones(num_embeddings, dtype=torch.float) + zero_points = torch.zeros(num_embeddings, dtype=torch.float) + wq = torch._empty_per_channel_affine_quantized( + [num_embeddings, embedding_dim], + scales=scales, + zero_points=zero_points, + axis=0, + dtype=self.dtype, + ) + self.set_weight(wq) + else: + raise NotImplementedError( + f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}" + ) + + @torch.jit.export + def set_weight(self, weight: torch.Tensor) -> None: + if self.dtype in [torch.quint8, torch.quint4x2]: + self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight) + else: + raise NotImplementedError( + "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2." + ) + + @torch.jit.export + def _weight(self): + if self.dtype in [torch.quint8, torch.quint4x2]: + return torch.ops.quantized.embedding_bag_unpack(self._packed_weight) + else: + raise NotImplementedError( + "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2." 
+ ) + + def forward(self, x): + return x + + # Version 1 + # self + # |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase + # |--- dtype : torch.dtype + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_weight"] = self._weight() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.dtype = state_dict[prefix + "dtype"] + state_dict.pop(prefix + "dtype") + + weight = state_dict[prefix + "_packed_weight"] + state_dict.pop(prefix + "_packed_weight") + self.set_weight(weight) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return self._weight().__repr__() + + +class Embedding(torch.nn.Module): + r""" + A quantized Embedding module with quantized packed weights as inputs. + We adopt the same interface as `torch.nn.Embedding`, please see + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation. + + Similar to :class:`~torch.nn.Embedding`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`. + + Examples:: + >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12) + >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8]) + >>> output = m(indices) + >>> print(output.size()) + torch.Size([9, 12]) + + """ + + _version = 1 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + dtype=torch.quint8, + ) -> None: + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.dtype = dtype + + if _weight is None: + scales = torch.ones(num_embeddings, dtype=torch.float) + zero_points = torch.zeros(num_embeddings, dtype=torch.float) + qweight = torch._empty_per_channel_affine_quantized( + [num_embeddings, embedding_dim], + scales=scales, + zero_points=zero_points, + axis=0, + dtype=torch.quint8, + ) + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + qweight = _weight + + self._packed_params = EmbeddingPackedParams( + num_embeddings, embedding_dim, dtype + ) + self._packed_params.set_weight(qweight) + + def forward(self, indices: Tensor) -> Tensor: + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_4bit( + self._packed_params._packed_weight, indices + ) + else: + return torch.ops.quantized.embedding_byte( + self._packed_params._packed_weight, indices + ) + + def _get_name(self): + return "QuantizedEmbedding" + + def __repr__(self): + return _hide_packed_params_repr(self, EmbeddingPackedParams) + + def extra_repr(self): + extra_repr_str = ( + f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, " + f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}" + ) + + return extra_repr_str + + def set_weight(self, w: torch.Tensor) -> None: + self._packed_params.set_weight(w) + + def 
weight(self): + return self._packed_params._weight() + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized embedding module from a float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by user + """ + if hasattr(mod, "weight_fake_quant"): + assert type(mod) == torch.ao.nn.qat.Embedding, ( + "nnq." + + cls.__name__ + + ".from_float " + + "with fake quant only works for " + + torch.ao.nn.qat.Embedding.__name__ + ) + weight_observer = mod.weight_fake_quant + else: + assert type(mod) == nn.Embedding, ( + "nnq." + + cls.__name__ + + ".from_float only works for " + + nn.Embedding.__name__ + ) + assert hasattr(mod, "qconfig"), ( + "Embedding input float module must have qconfig defined" + ) + from torch.ao.quantization import float_qparams_weight_only_qconfig + + if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] + weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator] + else: + weight_observer = float_qparams_weight_only_qconfig.weight() + + dtype = weight_observer.dtype + is_float_qparams_qconfig = ( + weight_observer.qscheme == torch.per_channel_affine_float_qparams + ) + assert is_float_qparams_qconfig, ( + "Embedding quantization is only supported with float_qparams_weight_only_qconfig." + ) + + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}" + ) + + # Run the observer to calculate qparams. + weight_observer(mod.weight) + qweight = _quantize_weight(mod.weight.float(), weight_observer) + + # Create quantized Embedding module and pass in the quantized weight + qembedding = Embedding(mod.num_embeddings, mod.embedding_dim) + qembedding.set_weight(qweight) + return qembedding + + @classmethod + def from_reference(cls, ref_embedding): + qembedding = cls( + ref_embedding.num_embeddings, + ref_embedding.embedding_dim, + ref_embedding.padding_idx, + ref_embedding.max_norm, + ref_embedding.norm_type, + ref_embedding.scale_grad_by_freq, + ref_embedding.sparse, + ref_embedding.get_quantized_weight(), + ref_embedding.weight_dtype, + ) + return qembedding + + +class EmbeddingBag(Embedding): + r""" + A quantized EmbeddingBag module with quantized packed weights as inputs. + We adopt the same interface as `torch.nn.EmbeddingBag`, please see + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation. + + Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`. 
+ + Examples:: + >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum') + >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32]) + >>> output = m(indices, offsets) + >>> print(output.size()) + torch.Size([5, 12]) + + """ + + _version = 1 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "sum", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + dtype=torch.quint8, + ) -> None: + super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype) + + self.mode = mode + self.pruned_weights = False + self.include_last_offset = include_last_offset + self.dtype = dtype + + def forward( + self, + indices: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + compressed_indices_mapping: Optional[Tensor] = None, + ) -> Tensor: + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_bag_4bit( + self._packed_params._packed_weight, + indices, + offsets, + False, + 0, + self.pruned_weights, + per_sample_weights, + compressed_indices_mapping, + self.include_last_offset, + ) + else: + return torch.ops.quantized.embedding_bag_byte( + self._packed_params._packed_weight, + indices, + offsets, + False, + 0, + self.pruned_weights, + per_sample_weights, + compressed_indices_mapping, + self.include_last_offset, + ) + + def _get_name(self): + return "QuantizedEmbeddingBag" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized embedding_bag module from a float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by user + """ + if hasattr(mod, "weight_fake_quant"): + weight_observer = mod.weight_fake_quant + else: + assert type(mod) == nn.EmbeddingBag, ( + "nnq." + + cls.__name__ + + ".from_float only works for " + + nn.EmbeddingBag.__name__ + ) + assert hasattr(mod, "qconfig"), ( + "EmbeddingBag input float module must have qconfig defined" + ) + from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig + + if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] + weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator] + else: + weight_observer = float_qparams_weight_only_qconfig.weight() + + dtype = weight_observer.dtype + is_float_qparams_qconfig = ( + weight_observer.qscheme == torch.per_channel_affine_float_qparams + ) + assert is_float_qparams_qconfig, ( + "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig." + ) + + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}" + ) + + # Run the observer to calculate qparams. 
+ weight_observer(mod.weight) + qweight = _quantize_weight(mod.weight.float(), weight_observer) + + # Create quantized EmbeddingBag module and pass in the quantized weight + qembedding_bag = EmbeddingBag( + mod.num_embeddings, + mod.embedding_dim, + max_norm=mod.max_norm, + norm_type=mod.norm_type, + scale_grad_by_freq=mod.scale_grad_by_freq, + mode=mod.mode, + sparse=mod.sparse, + include_last_offset=mod.include_last_offset, + dtype=dtype, + ) + qembedding_bag.set_weight(qweight) + return qembedding_bag + + @classmethod + def from_reference(cls, ref_embedding_bag): + qembedding_bag = cls( + ref_embedding_bag.num_embeddings, + ref_embedding_bag.embedding_dim, + ref_embedding_bag.max_norm, + ref_embedding_bag.norm_type, + ref_embedding_bag.scale_grad_by_freq, + ref_embedding_bag.mode, + ref_embedding_bag.sparse, + ref_embedding_bag.get_quantized_weight(), + ref_embedding_bag.include_last_offset, + ref_embedding_bag.weight_dtype, + ) + return qembedding_bag diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/functional_modules.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..0f16190232c37687a126dcb5d5b03430dfbbad53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/functional_modules.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs + +import torch +from torch import Tensor +from torch._ops import ops + + +__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"] + + +class FloatFunctional(torch.nn.Module): + r"""State collector class for float operations. + + The instance of this class can be used instead of the ``torch.`` prefix for + some operations. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). + + Examples:: + + >>> f_add = FloatFunctional() + >>> a = torch.tensor(3.0) + >>> b = torch.tensor(4.0) + >>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def __init__(self) -> None: + super().__init__() + self.activation_post_process = torch.nn.Identity() + + def forward(self, x): + raise RuntimeError( + "FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. 
+ return r + + r"""Operation equivalent to ``torch.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + r = self.activation_post_process(r) + return r + + +class FXFloatFunctional(torch.nn.Module): + r"""module to replace FloatFunctional module before FX graph mode quantization, + since activation_post_process will be inserted in top level module directly + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def forward(self, x): + raise RuntimeError( + "FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + return r + + +class QFunctional(torch.nn.Module): + r"""Wrapper class for quantized operations. + + The instance of this class can be used instead of the + ``torch.ops.quantized`` prefix. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). 
+ + Examples:: + + >>> q_add = QFunctional() + >>> # xdoctest: +SKIP + >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32) + >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32) + >>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def __init__(self) -> None: + super().__init__() + self.scale = 1.0 + self.zero_point = 0 + self.activation_post_process = torch.nn.Identity() + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict.pop(prefix + "scale")) + self.zero_point = int(state_dict.pop(prefix + "zero_point")) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _get_name(self): + return "QFunctional" + + def extra_repr(self): + return f"scale={self.scale}, zero_point={self.zero_point}" + + def forward(self, x): + raise RuntimeError( + "Functional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.ops.quantized.add``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.add_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.mul_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add_relu``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. 
+ return r + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) == FloatFunctional, ( + "QFunctional.from_float expects an instance of FloatFunctional" + ) + scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] + new_mod = QFunctional() + new_mod.scale = float(scale) + new_mod.zero_point = int(zero_point) + return new_mod diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..b25ad4b93c88cf2ad324bacb6445342aee26b0f5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/linear.py @@ -0,0 +1,363 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.nn as nn +from torch.nn.utils.fusion import fuse_linear_bn_weights +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import _hide_packed_params_repr, _quantize_weight, WeightedQuantizedModule + + +__all__ = ["LinearPackedParams", "Linear"] + + +class LinearPackedParams(torch.nn.Module): + _version = 3 + + def __init__(self, dtype=torch.qint8): + super().__init__() + self.dtype = dtype + if self.dtype == torch.qint8: + wq = torch._empty_affine_quantized( + [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8 + ) + elif self.dtype == torch.float16: + wq = torch.zeros([1, 1], dtype=torch.float) + self.set_weight_bias(wq, None) # type: ignore[possibly-undefined] + + @torch.jit.export + def set_weight_bias( + self, weight: torch.Tensor, bias: Optional[torch.Tensor] + ) -> None: + if self.dtype == torch.qint8: + self._packed_params = torch.ops.quantized.linear_prepack(weight, bias) + elif self.dtype == torch.float16: + self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + + @torch.jit.export + def _weight_bias(self): + if self.dtype == torch.qint8: + return torch.ops.quantized.linear_unpack(self._packed_params) + elif self.dtype == torch.float16: + return torch.ops.quantized.linear_unpack_fp16(self._packed_params) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + + def forward(self, x): + return x + + # Version 1 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 2 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- dtype : torch.dtype + # + # Version 3 + # self + # |--- _packed_params : (Tensor, Tensor) representing (weight, bias) + # of LinearPackedParams + # |--- dtype : torch.dtype + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_params"] = self._weight_bias() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + self.dtype = torch.qint8 + else: + self.dtype = state_dict[prefix + "dtype"] + state_dict.pop(prefix + "dtype") + + if version is None or version < 3: + self.set_weight_bias( + state_dict[prefix + "weight"], state_dict[prefix + "bias"] + ) + state_dict.pop(prefix + 
"weight") + state_dict.pop(prefix + "bias") + + if version == 3: + weight, bias = state_dict[prefix + "_packed_params"] + state_dict.pop(prefix + "_packed_params") + self.set_weight_bias(weight, bias) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return self._weight_bias().__repr__() + + +class Linear(WeightedQuantizedModule): + r""" + A quantized linear module with quantized tensor as inputs and outputs. + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`~torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{out\_features}, \text{in\_features})`. + bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized to zero. + scale: `scale` parameter of output Quantized Tensor, type: double + zero_point: `zero_point` parameter for output Quantized Tensor, type: long + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _version = 3 + _FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear) + + def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): + super().__init__() + # We don't muck around with buffers or attributes or anything here + # to keep the module simple. *everything* is simply a Python attribute. + # Serialization logic is explicitly handled in the below serialization and + # deserialization modules + self.in_features = in_features + self.out_features = out_features + bias = None + if bias_: + bias = torch.zeros(out_features, dtype=torch.float) + + if dtype == torch.qint8: + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + elif dtype == torch.float16: + qweight = torch.zeros([out_features, in_features], dtype=torch.float) + else: + raise RuntimeError("Unsupported dtype specified for quantized Linear!") + + self._packed_params = LinearPackedParams(dtype) + self._packed_params.set_weight_bias(qweight, bias) + self.scale = 1.0 + self.zero_point = 0 + + def _get_name(self): + return "QuantizedLinear" + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, " + f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}" + ) + + def __repr__(self): + return _hide_packed_params_repr(self, LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into their + # regular QTensor form for serialization. Packed weights should not live + # outside the process in which they were created, rather they should be derived + # from the QTensor weight. 
+ # + # Version 1 + # self + # |--- scale : float + # |--- zero_point : int + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 2 + # self + # |--- scale : float + # |--- zero_point : int + # |--- _packed_params : Module + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 3 + # self + # |--- scale : float + # |--- zero_point : int + # |--- _packed_params : Module + # |--- _packed_params : (Tensor, Tensor) representing weight, bias + # of LinearPackedParams C++ struct + # + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized QTensor + # weight into its packed format for use by the FBGEMM ops. + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + + version = local_metadata.get("version", None) + + if version is None or version == 1: + # We moved the parameters into a LinearPackedParameters submodule + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias") + state_dict.update( + { + prefix + "_packed_params.weight": weight, + prefix + "_packed_params.bias": bias, + } + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + # Function rather than property to make sure that JIT serialization doesn't + # register this as an attribute + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params.set_weight_bias(w, b) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized module from an observed float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + use_precomputed_fake_quant (bool): if True, the module will reuse min/max + values from the precomputed fake quant module. + """ + if hasattr(mod, "weight_fake_quant"): + if type_before_parametrizations(mod) == nniqat.LinearBn1d: + mod.weight, mod.bias = fuse_linear_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + # This function does not participate in JIT, so it is OK to ignore + # the type mismatch in assignment. Also, mypy has an issue with + # iterables not being implemented, so we are ignoring those too. 
+ if not isinstance(cls._FLOAT_MODULE, Iterable): + cls._FLOAT_MODULE = [cls._FLOAT_MODULE] + supported_modules = ", ".join( + [float_mod.__name__ for float_mod in cls._FLOAT_MODULE] + ) + error_msg = f"nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}" + assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, ( + error_msg.format() + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined" + ) + activation_post_process = mod.activation_post_process + if type_before_parametrizations(mod) == nni.LinearReLU: + mod = mod[0] + weight_post_process = ( + mod.qconfig.weight() + if not hasattr(mod, "weight_fake_quant") + else mod.weight_fake_quant + ) + + if not use_precomputed_fake_quant: + # Observer may not have been called yet + # Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear.set_weight_bias(qweight, mod.bias) + qlinear.scale = float(act_scale) + qlinear.zero_point = int(act_zp) + return qlinear + + @classmethod + def from_reference(cls, ref_qlinear, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + + Args: + ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features) + qweight = ref_qlinear.get_quantized_weight() + qlinear.set_weight_bias(qweight, ref_qlinear.bias) + + qlinear.scale = float(output_scale) + qlinear.zero_point = int(output_zero_point) + return qlinear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/normalization.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a366043809fc53e0d1a33ca898a610d67ccb1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/normalization.py @@ -0,0 +1,347 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = [ + "LayerNorm", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", +] + + +class LayerNorm(torch.nn.LayerNorm): + r"""This is the quantized version of :class:`~torch.nn.LayerNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. 
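+
+    A minimal construction sketch (the qparams below are placeholders; in a
+    real flow ``scale`` and ``zero_point`` come from the observer, as in
+    ``from_float``)::
+
+        >>> # xdoctest: +SKIP
+        >>> ln = LayerNorm([4], torch.nn.Parameter(torch.ones(4)),
+        ...                torch.nn.Parameter(torch.zeros(4)), scale=1.0, zero_point=0)
+        >>> qx = torch.quantize_per_tensor(torch.randn(2, 4), 1.0, 0, torch.quint8)
+        >>> qy = ln(qx)  # quantized output with qparams (1.0, 0)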
+ + """ + + def __init__( + self, + normalized_shape, + weight, + bias, + scale, + zero_point, + eps=1e-5, + elementwise_affine=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + **factory_kwargs, + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.layer_norm( + input, + self.normalized_shape, + weight=self.weight, + bias=self.bias, + eps=self.eps, + output_scale=self.scale, + output_zero_point=self.zero_point, + ) + + def _get_name(self): + return "QuantizedLayerNorm" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.normalized_shape, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.elementwise_affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.normalized_shape, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.elementwise_affine, + ) + + +class GroupNorm(torch.nn.GroupNorm): + r"""This is the quantized version of :class:`~torch.nn.GroupNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + __constants__ = ["num_groups", "num_channels", "eps", "affine"] + + def __init__( + self, + num_groups, + num_channels, + weight, + bias, + scale, + zero_point, + eps=1e-5, + affine=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.group_norm( + input, + self.num_groups, + self.weight, + self.bias, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedGroupNorm" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_groups, + mod.num_channels, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + +class InstanceNorm1d(torch.nn.InstanceNorm1d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. 
+ + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm1d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + + +class InstanceNorm2d(torch.nn.InstanceNorm2d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + + +class InstanceNorm3d(torch.nn.InstanceNorm3d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. 
+ + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/rnn.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d0eeb855dba456ad5baabbca34256c0ac13406 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/rnn.py @@ -0,0 +1,58 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = [ + "LSTM", +] + + +class LSTM(torch.ao.nn.quantizable.LSTM): + r"""A quantized long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples in :class:`~torch.ao.nn.quantizable.LSTM` + + Examples:: + >>> # xdoctest: +SKIP + >>> custom_module_config = { + ... 'float_to_observed_custom_module_class': { + ... nn.LSTM: nn.quantizable.LSTM, + ... }, + ... 'observed_to_quantized_custom_module_class': { + ... nn.quantizable.LSTM: nn.quantized.LSTM, + ... } + ... } + >>> tq.prepare(model, prepare_custom_module_class=custom_module_config) + >>> tq.convert(model, convert_custom_module_class=custom_module_config) + """ + + _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment] + + def _get_name(self): + return "QuantizedLSTM" + + @classmethod + def from_float(cls, *args, **kwargs): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-observed LSTM module. Please, see " + "the examples on quantizable LSTMs." 
+ ) + + @classmethod + def from_observed(cls, other): + assert isinstance(other, cls._FLOAT_MODULE) # type: ignore[has-type] + converted = torch.ao.quantization.convert( + other, inplace=False, remove_qconfig=True + ) + converted.__class__ = cls + return converted diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/utils.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..508cf386c0d118477e2636f2d8d7d2296e41f477 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/modules/utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +import abc +import collections +import itertools + +import torch +from torch.nn.modules.module import _addindent + + +__all__ = [ + "WeightedQuantizedModule", +] + + +class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta): + """Wrapper for quantized modules than can be lowered from reference modules.""" + + @classmethod + @abc.abstractmethod + def from_reference(cls, ref_module, output_scale, output_zero_point): + raise NotImplementedError + + +def _get_weight_observer(observer): + # FakeQuantize observer + if hasattr(observer, "activation_post_process"): + observer = observer.activation_post_process + # UniformQuantizationObserverBase observer + return observer + + +def _needs_weight_clamping(observer, dtype): + observer = _get_weight_observer(observer) + if dtype in [torch.qint8, torch.quint8, torch.qint32]: + info = torch.iinfo(dtype) + return observer.quant_min > info.min or observer.quant_max < info.max + return False + + +def _clamp_weights(qweight, observer, scale, zp): + if not _needs_weight_clamping(observer, qweight.dtype): + return qweight + + observer = _get_weight_observer(observer) + min_, max_ = observer.quant_min, observer.quant_max + + # Doing this because can't use torch.ops.quantized.clamp() with per_channel qscheme yet. 
+ qw_int_max = torch.clone(qweight.int_repr()).fill_(max_) + qw_int_min = torch.clone(qweight.int_repr()).fill_(min_) + qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max) + + if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: + qweight = torch._make_per_tensor_quantized_tensor( + qw_int, scale.item(), zp.item() + ) + elif observer.qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + qweight = torch._make_per_channel_quantized_tensor( + qw_int, scale, zp, axis=observer.ch_axis + ) + else: + raise ValueError("Unexpected qscheme " + observer.qscheme) + return qweight + + +def _quantize_weight(float_wt, observer): + wt_scale, wt_zp = observer.calculate_qparams() + if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.qint8 + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]: + wt_axis = observer.ch_axis + qweight = torch.quantize_per_channel( + float_wt, + wt_scale.to(torch.double), + wt_zp.to(torch.int64), + wt_axis, + torch.qint8, + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + elif observer.qscheme in [torch.per_channel_affine_float_qparams]: + qweight = torch.quantize_per_channel( + float_wt, + wt_scale.to(torch.float), + wt_zp.to(torch.float), + observer.ch_axis, + observer.dtype, + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + else: + raise ValueError("Unexpected qscheme " + observer.qscheme) + return qweight + + +def _ntuple_from_first(n): + """Converts the argument to a tuple of size n + with the first element repeated.""" + + def parse(x): + while isinstance(x, collections.abc.Sequence): + if len(x) == n: + break + x = x[0] + return tuple(itertools.repeat(x, n)) + + return parse + + +def _hide_packed_params_repr(self, params): + # We don't want to show `PackedParams` children, hence custom + # `__repr__`. This is the same as nn.Module.__repr__, except the check + # for the `params module`. 
+ extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + if isinstance(module, params): + continue + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + +_pair_from_first = _ntuple_from_first(2) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0619cb1c81ad897fa4c966506c4243ccb174b557 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/__init__.py @@ -0,0 +1,19 @@ +from .modules import * # noqa: F403 + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "GRU", + "Embedding", + "EmbeddingBag", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d2ba638f4b9a47563ff8bc2fa176f2e6ee083ae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51233ab56d31857868d0667af3f428eadcefe7b7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py @@ -0,0 +1,29 @@ +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .linear import Linear +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell +from .sparse import Embedding, EmbeddingBag + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "GRU", + "Embedding", + "EmbeddingBag", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13002ee8ee98756bfa82c2f877f150a0eeffc371 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f8741b123a21ebef2307efff728c54dbf8ccf2f Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61c55f3caf7486121a724bfa02362b35f8f105a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aaed04ca20a83a4f709d63970eea8ea62b553e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c8cc69116ebfd565d9ef1bb3a50d5f5fe2aa652 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f316fb972684c8a6fc2c1a6771d349a80ee60e2a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/conv.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b31aa1fc5cb892deef5f383c4b1434f92a8f73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/conv.py @@ -0,0 +1,511 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.common_types import _size_1_t + +from .utils import ReferenceQuantizedModule + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule): + """A reference version of nn.quantized.Conv2d + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. 
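+
+    A minimal ``from_float`` sketch (the ``weight_qparams`` values are
+    placeholders; in a real flow they are derived from the weight observer)::
+
+        >>> # xdoctest: +SKIP
+        >>> float_conv = torch.nn.Conv2d(3, 8, 3)
+        >>> weight_qparams = {
+        ...     "qscheme": torch.per_tensor_affine,
+        ...     "dtype": torch.qint8,
+        ...     "scale": 1.0,
+        ...     "zero_point": 0,
+        ... }
+        >>> ref_conv = Conv2d.from_float(float_conv, weight_qparams)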
+ """ + + __annotations__ = {"bias": Optional[torch.Tensor]} + _IS_REFERENCE = True + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + + +class Conv1d(_ConvNd, nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.Conv1d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv1d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + weight_quant_dequant = self.get_weight() + result = F.conv1d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class Conv2d(_ConvNd, nn.Conv2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.Conv2d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv2d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + weight_quant_dequant = self.get_weight() + result = F.conv2d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv2d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class Conv3d(_ConvNd, nn.Conv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + 
padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.Conv3d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv3d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + weight_quant_dequant = self.get_weight() + result = F.conv3d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd): + """A reference version of nn.quantized.ConvTranspose2d + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. + """ + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.output_padding, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + + +class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.ConvTranspose1d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: Optional[list[int]] = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose1d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support 
`Sequence[T]` or `Tuple[T, ...]`. + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose1d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + return result + + def _get_name(self): + return "QuantizedConvTranspose1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) + + +class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.ConvTranspose2d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: Optional[list[int]] = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose2d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 
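+        # `_output_padding` (inherited from nn.modules.conv._ConvTransposeNd) derives the
+        # extra padding needed so that, when `output_size` is passed, the transposed
+        # convolution produces exactly that spatial size; otherwise it falls back to
+        # `self.output_padding`.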
+ + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose2d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + return result + + def _get_name(self): + return "QuantizedConvTranspose2d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) + + +class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.ConvTranspose3d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: Optional[list[int]] = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose3d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 
+ output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose3d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + return result + + def _get_name(self): + return "QuantizedConvTranspose3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..d760b8ada3f4e8299b9f985acd68171b0e2e7432 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/linear.py @@ -0,0 +1,69 @@ +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import ReferenceQuantizedModule + + +__all__ = ["Linear"] + + +class Linear(nn.Linear, ReferenceQuantizedModule): + """A reference quantized linear module that fits into the FX + Graph Mode Quantization workflow + activation will be floating point Tensor, we will store floating + point weight as well in the module, but in forward we'll quantize + and dequantize the weight before running the floating point functional + linear operator. + """ + + _IS_REFERENCE = True + + def __init__( + self, + in_features: int, + out_features: int, + bias_: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_qparams: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__(in_features, out_features, bias_, device, dtype) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self) -> str: + return "QuantizedLinear(Reference)" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.linear --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.linear --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized linear + """ + weight_quant_dequant = self.get_weight() + result = F.linear(x, weight_quant_dequant, self.bias) + return result + + @classmethod + def from_float( + cls, float_linear: nn.Linear, weight_qparams: dict[str, Any] + ) -> "Linear": + qref_linear = Linear( + float_linear.in_features, + float_linear.out_features, + float_linear.bias is not None, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach()) + if float_linear.bias is not None: + qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach()) + return qref_linear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a4427589fe5c3e4d099e26c0be9739192ab5cbfd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py @@ -0,0 +1,853 @@ +# mypy: allow-untyped-defs +from typing 
import Any, Optional + +import torch +import torch.nn as nn +from torch import _VF, Tensor +from torch.nn.utils.rnn import PackedSequence + +from .utils import _quantize_and_dequantize_weight, _quantize_weight + + +__all__ = [ + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "RNNBase", + "LSTM", + "GRU", + "get_quantized_weight", +] + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +def _get_weight_and_quantization_params(module, wn): + weight = getattr(module, wn) + params = [weight] + for param_name in [ + wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"] + ]: + if hasattr(module, param_name): + param = getattr(module, param_name) + else: + param = None + params.append(param) + return params + + +def get_quantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_weight(*params) + return weight + + +def _get_quantize_and_dequantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_and_dequantize_weight(*params) + return weight + + +class RNNCellBase(nn.RNNCellBase): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + weight_qparams_dict=None, + ) -> None: + super().__init__( + input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + weight_qparams_dict = { + "weight_ih": weight_qparams, + "weight_hh": weight_qparams, + "is_decomposed": False, + } + assert len(weight_qparams_dict) == 3, ( + "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" + ) + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + assert weight_qparams_dict is not None + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + # TODO: refactor the duplicated code to utils.py + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + ], Exception( + f"qscheme: {weight_qscheme} is not support in {self._get_name()}" + ) + if weight_qscheme is not None: + scale = weight_qparams["scale"] + scale_tensor = ( + scale.detach().clone() + if isinstance(scale, torch.Tensor) + else torch.tensor(scale, dtype=torch.float, device=device) + ) + self.register_buffer(key + "_scale", scale_tensor) + zp = weight_qparams["zero_point"] + zp_tensor = ( + zp.detach().clone() + if isinstance(zp, torch.Tensor) + else torch.tensor(zp, dtype=torch.int, device=device) + ) + self.register_buffer(key + "_zero_point", zp_tensor) + if weight_qscheme == torch.per_channel_affine: + axis = weight_qparams["axis"] + axis_tensor = ( + axis.detach().clone() + if isinstance(axis, torch.Tensor) + else torch.tensor(axis, dtype=torch.int, device=device) + ) + self.register_buffer(key + "_axis", axis_tensor) + else: + # added for 
TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + def _get_name(self): + return "QuantizedRNNCellBase(Reference)" + + def get_quantized_weight_ih(self): + return get_quantized_weight(self, "weight_ih") + + def get_quantized_weight_hh(self): + return get_quantized_weight(self, "weight_hh") + + def get_weight_ih(self): + return _get_quantize_and_dequantized_weight(self, "weight_ih") + + def get_weight_hh(self): + return _get_quantize_and_dequantized_weight(self, "weight_hh") + + +class RNNCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "QuantizedRNNCell(Reference)" + + # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input + # and remove duplicated code, same for the other two Cell modules + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + assert input.dim() in ( + 1, + 2, + ), ( + f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.nonlinearity, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class LSTMCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. 
weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def _get_name(self): + return "QuantizedLSTMCell(Reference)" + + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + assert input.dim() in ( + 1, + 2, + ), ( + f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class GRUCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. 
weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def _get_name(self): + return "QuantizedGRUCell(Reference)" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + assert input.dim() in ( + 1, + 2, + ), ( + f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class RNNBase(nn.RNNBase): + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + mode, + input_size, + hidden_size, + num_layers, + bias, + batch_first, + dropout, + bidirectional, + proj_size, + device, + dtype, + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item] + for wn in self._flat_weights_names: + if wn.startswith("weight"): + weight_qparams_dict[wn] = weight_qparams + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + ], Exception( + f"qscheme: {weight_qscheme} is not support in {self._get_name()}" + ) + if weight_qscheme is not None: + self.register_buffer( + key + "_scale", + torch.tensor( + weight_qparams["scale"], dtype=torch.float, device=device + ), + ) + self.register_buffer( + key + "_zero_point", + torch.tensor( + weight_qparams["zero_point"], dtype=torch.int, device=device + ), + ) + if weight_qscheme == torch.per_channel_affine: + self.register_buffer( + key + "_axis", + torch.tensor( + weight_qparams["axis"], dtype=torch.int, device=device + ), + ) + else: + # added for 
TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + +class LSTM(RNNBase): + """Reference Quantized LSTM Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args( # type: ignore[override] + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ): + self.check_input(input, batch_sizes) + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) + + def get_quantized_weight_bias_dict(self): + """dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... 
+ } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = ( + self.proj_size if self.proj_size > 0 else self.hidden_size + ) + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + if batch_sizes is None: # If not PackedSequence input. + if is_batched: # type: ignore[possibly-undefined] + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
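+            # Note: further down, `self.get_flat_weights()` supplies quantize-dequantized
+            # (fake-quantized) copies of the flat weights, so the `_VF.lstm` call runs in
+            # float while reproducing the numerics of the quantized weights.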
+ hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.lstm( + input, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.lstm( + input, + batch_sizes, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedLSTM(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict, + ) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod + + +class GRU(RNNBase): + """Reference Quantized GRU Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) + + def get_quantized_weight_bias_dict(self): + """dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... + } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py + # only changed self._flat_weights to self.get_flat_weights() + # TODO: maybe we can try inheriting from that class and define get_flat_weights + # as a @property? 
this might interfere with TorchScript, if we remove that + # requirement in the future we should be able to do this + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + assert input.dim() in ( + 2, + 3, + ), ( + f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru( + input, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.gru( + input, + batch_sizes, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedGRU(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict, + ) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..567a737c503215332d949624da4a47e8dba2cb7b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from .utils import ReferenceQuantizedModule + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(nn.Embedding, 
ReferenceQuantizedModule): + """A reference quantized Embedding module that fits into the + FX Graph Mode Quantization workflow, activation will be floating point Tensor, + we will store floating point weight as well in the module, but in forward we'll + quantize and dequantize the weight before running the floating point functional + embedding operator. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedEmbedding(Reference)" + + def forward(self, input: Tensor) -> Tensor: + weight_quant_dequant = self.get_weight() + return F.embedding( + input, + weight_quant_dequant, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + @classmethod + def from_float(cls, mod, weight_qparams): + return cls( + mod.num_embeddings, + mod.embedding_dim, + mod.padding_idx, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.sparse, + mod.weight, + mod.weight.device, + mod.weight.dtype, + weight_qparams, + ) + + +class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule): + """A reference quantized EmbeddingBag module that fits into the + FX Graph Mode Quantization workflow, activation will be floating point Tensor, + we will store floating point weight as well in the module, but in forward we'll + quantize and dequantize the weight before running the floating point functional + embedding operator. 
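+
+    A minimal usage sketch (hypothetical qparams; in practice `weight_qparams` comes from
+    the embedding weight observer in the FX workflow)::
+
+        float_eb = torch.nn.EmbeddingBag(10, 4, mode="mean")
+        weight_qparams = {
+            "qscheme": torch.per_tensor_affine,
+            "dtype": torch.quint8,
+            "scale": 0.02,
+            "zero_point": 128,
+        }
+        ref_eb = EmbeddingBag.from_float(float_eb, weight_qparams)
+        out = ref_eb(torch.tensor([1, 2, 4, 5]), offsets=torch.tensor([0, 2]))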
+ """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + _weight, + include_last_offset, + padding_idx, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedEmbedding(Reference)" + + def forward( + self, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + weight_quant_dequant = self.get_weight() + return F.embedding_bag( + input, + weight_quant_dequant, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + @classmethod + def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False): + return cls( + mod.num_embeddings, + mod.embedding_dim, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.mode, + mod.sparse, + mod.weight, + mod.include_last_offset, + mod.padding_idx, + mod.weight.device, + mod.weight.dtype, + weight_qparams, + ) diff --git a/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/utils.py b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..481ec29fd96649065c0f79c73e402cf7995dc263 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/utils.py @@ -0,0 +1,434 @@ +# mypy: allow-untyped-defs +import typing + +import torch + + +__all__ = [ + "ReferenceQuantizedModule", +] + + +class ReferenceQuantizedModule(torch.nn.Module): + def _init_weight_qparams(self, weight_qparams, device): + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ], ( + f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}" + ) + if self.weight_dtype in [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + ]: + zero_point_dtype = ( + weight_qparams["zero_point"].dtype + if isinstance(weight_qparams["zero_point"], torch.Tensor) + else torch.int + ) + w_scale = weight_qparams["scale"] + w_scale_tensor = ( + w_scale.detach().clone() + if isinstance(w_scale, torch.Tensor) + else torch.tensor(w_scale, dtype=torch.float, device=device) + ) + self.register_buffer("weight_scale", w_scale_tensor) + w_zp = weight_qparams["zero_point"] + w_zp_tensor = ( + w_zp.detach().clone() + if isinstance(w_zp, torch.Tensor) + else torch.tensor(w_zp, dtype=zero_point_dtype, device=device) + ) + self.register_buffer("weight_zero_point", w_zp_tensor) + if self.weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + w_axis = weight_qparams["axis"] + w_axis_tensor = ( + w_axis.detach().clone() + if 
isinstance(w_axis, torch.Tensor) + else torch.tensor(w_axis, dtype=torch.int, device=device) + ) + self.register_buffer("weight_axis", w_axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + else: + # added for TorchScriptability, and for torch.float + self.register_buffer( + "weight_scale", torch.tensor(1.0, dtype=torch.float, device=device) + ) + self.register_buffer( + "weight_zero_point", torch.tensor(0, dtype=torch.int, device=device) + ) + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + self.is_decomposed: bool = weight_qparams.get("is_decomposed", False) + # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export + # for capturing `.item` operations + self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment] + self.weight_quant_min: typing.Optional[int] = weight_qparams.get( + "quant_min", None + ) + self.weight_quant_max: typing.Optional[int] = weight_qparams.get( + "quant_max", None + ) + + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + if self.is_decomposed: + return _quantize_and_dequantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max, + ) + else: + return _quantize_and_dequantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + ) + + def get_quantized_weight(self): + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + # assert isinstance(self.weight_axis, torch.Tensor) + if self.is_decomposed: + return _quantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max, + ) + else: + return _quantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + _save_weight_qparams( + destination, + prefix, + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +def _quantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + 
weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (int(-(2**31)), int(2**31 - 1)), + } + + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[ + weight_dtype_ + ] + weight = torch.ops.quantized_decomposed.quantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[ + weight_dtype_ + ] + weight = torch.ops.quantized_decomposed.quantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + # TODO: get the quant_min and quant_max from activation_post_process + _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (int(-(2**31)), int(2**31 - 1)), + } + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_] + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + 
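+# Rough illustration of how the non-decomposed helpers below are used by
+# `ReferenceQuantizedModule.get_weight` (hypothetical per-tensor affine qparams):
+#
+#     w = torch.randn(8, 4)
+#     scale = torch.tensor(0.1)
+#     zero_point = torch.tensor(0, dtype=torch.int)
+#     w_qdq = _quantize_and_dequantize_weight(
+#         w, torch.per_tensor_affine, torch.qint8, scale, zero_point, 0
+#     )
+#     # w_qdq stays a float tensor but carries the rounding/clamping error of int8 quantization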
+def _quantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, +) -> torch.Tensor: + if weight_dtype == torch.float16: + weight = weight.to(weight_dtype) + return weight + + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.quantize_per_tensor( + weight, weight_scale, weight_zero_point, weight_dtype + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]: + weight = torch.quantize_per_channel( + weight, weight_scale, weight_zero_point, weight_axis_int, weight_dtype + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _quantize_and_dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + """Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + weight_quant = _quantize_weight_decomposed( + weight, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + weight_quant_min, + weight_quant_max, + ) + weight_dequant = _dequantize_weight_decomposed( + weight_quant, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + weight_quant_min, + weight_quant_max, + ) + else: + weight_dequant = weight + return weight_dequant + + +def _quantize_and_dequantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, +) -> torch.Tensor: + """Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + weight_quant = _quantize_weight( + weight, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + ) + weight_dequant = weight_quant.dequantize() + else: + weight_dequant = weight + return weight_dequant + + +def _save_weight_qparams( + destination, + prefix, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis, +): + destination[prefix + "weight_qscheme"] = weight_qscheme + destination[prefix + "weight_dtype"] = weight_dtype + if weight_qscheme is not None: + destination[prefix + "weight_scale"] = weight_scale + destination[prefix + "weight_zero_point"] = weight_zero_point + if weight_qscheme == torch.per_channel_affine: + destination[prefix + "weight_axis"] = weight_axis + + +def _get_weight_qparam_keys(state_dict: dict[str, typing.Any], prefix: str): + keys = ["weight_qscheme", "weight_dtype"] + weight_qscheme = state_dict[prefix + "weight_qscheme"] + if weight_qscheme is not None: + keys.append("weight_scale") + keys.append("weight_zero_point") + if weight_qscheme == torch.quantize_per_channel: + keys.append("weight_axis") + return keys diff --git 
a/phivenv/Lib/site-packages/torch/ao/nn/sparse/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..308f4d5b55c44b6749a5ede08d5ccc40a09b7bca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/sparse/__init__.py @@ -0,0 +1 @@ +from . import quantized diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da85ad2d41c7402446c62e352654673d3f2a7006 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2243c81bdfd91480acc208701abb52c2cdf31b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py @@ -0,0 +1,10 @@ +from torch.ao.nn.sparse.quantized import dynamic + +from .linear import Linear, LinearPackedParams + + +__all__ = [ + "dynamic", + "Linear", + "LinearPackedParams", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ac2eb98e4db2dfdcbe913f38db74c3b64adfa37 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7bbd2160bc279801385009cde8f2b79b128b83f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7549626d5bd7ccc49421b7868102e0ec2f9976c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7399b5b90e3074a9a1de49ad15215b576035fdb5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py @@ -0,0 +1,6 @@ +from .linear import Linear + + +__all__ = [ + "Linear", +] diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..998f735998c75408b7138fd6b9d3fb69228a71ab Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fd034bf5960bf0a2e13112bb502b4a6546b8b50 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0fa06690afc035f8433db969198531bf5d42c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -0,0 +1,189 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.ao.nn.intrinsic as nni +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _quantize_weight, +) +from torch.ao.nn.sparse.quantized import linear +from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern + + +__all__ = ["Linear"] + + +class Linear(torch.nn.Module): + r""" + A dynamically quantized sparse linear module with float tensor as inputs and outputs. + """ + + _version = 1 + _op_type = "sparse_dynamic" + _FLOAT_MODULE = torch.nn.Linear + + def __init__( + self, + in_features, + out_features, + row_block_size, + col_block_size, + bias=True, + dtype=torch.qint8, + ): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError( + "Only QINT8 is supported for Sparse Quantized Linear Dynamic" + ) + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + self._packed_params = linear.LinearPackedParams( + row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype + ) + self._packed_params.set_weight_bias( + qweight, bias, row_block_size, col_block_size + ) + + def _get_name(self): + return "SparseQuantizedDynamicLinear" + + def extra_repr(self): + return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}" + + def __repr__(self): + return _hide_packed_params_repr(self, linear.LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "op_type"] = self._op_type + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + op_type = int(state_dict[prefix + "op_type"]) + assert op_type == "sparse", ( + f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]" + ) + state_dict.pop(prefix + "op_type") + + version = local_metadata.get("version", None) + assert version <= self._version + + # Is this code valid? 
In old quantization it seemed to be used to load + # older model + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias") + state_dict.update( + { + prefix + "_packed_params.weight": weight, + prefix + "_packed_params.bias": bias, + } + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias( + self, + w: torch.Tensor, + b: Optional[torch.Tensor], + row_block_size: Optional[int], + col_block_size: Optional[int], + ) -> None: + assert row_block_size is not None and col_block_size is not None + self.out_features = w.shape[0] + self.in_features = w.shape[1] + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized sparse dynamic module from a float module. + + We only care about the convert at this stage, no need for observers just yet. + """ + assert type(mod) == cls._FLOAT_MODULE, ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + if type(mod) == nni.LinearReLU: + mod = mod[0] + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. 
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer = default_dynamic_qconfig.weight() + + # It is important to multiply by the mask BEFORE calling the `weight_observer` + # TODO (zaf): Mask might not be part of the qconfig (T83295194) + weight = mod.weight + if getattr(mod.qconfig, "mask", False): + weight = mod.qconfig.mask * mod.weight + + weight_observer(weight) + dtype = weight_observer.dtype + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + _w_sc, w_zp = weight_observer.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, "Weight zero point must map to 0" + qweight = _quantize_weight(weight.float(), weight_observer) + + row_block_size, col_block_size = LinearBlockSparsePattern.block_size() + qlinear = cls( + mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype, + ) + qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size) + return qlinear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f7683e7f3cb847ff4ab32eae33822514f5ed9671 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py @@ -0,0 +1,275 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _quantize_weight, +) + + +__all__ = ["LinearPackedParams", "Linear"] + + +# TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430) +class LinearPackedParams(torch.nn.Module): + _version = 1 + + def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError("Linear prepacking only supports QINT8") + self.dtype = dtype + wq = torch._empty_affine_quantized( + [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8 + ) + self.set_weight_bias(wq, None, row_block_size, col_block_size) + + def _get_name(self): + return "SparseQuantizedLinearPackedParams" + + @torch.jit.export + def set_weight_bias( + self, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + row_block_size: Optional[int], + col_block_size: Optional[int], + ) -> None: + assert row_block_size is not None and col_block_size is not None + self._packed_params = torch.ops.sparse.qlinear_prepack( + weight, bias, row_block_size, col_block_size + ) + + @torch.jit.export + def _weight_bias(self): + (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack( + self._packed_params + ) + return (weight, bias, block_sizes[0], block_sizes[1]) + + def forward(self, x): + return x + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_params"] = self._weight_bias() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + assert version <= self._version + + self.dtype = state_dict.pop(prefix + "dtype") + weight, bias, row_block_size, col_block_size = state_dict.pop( + prefix + "_packed_params" + ) + self.set_weight_bias(weight, bias, row_block_size, col_block_size) + + 
super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def __getstate__(self): + return self._packed_params, self.training, self.dtype + + @torch.jit.export + def __setstate__(self, state): + (self._packed_params, self.training, self.dtype) = state + + def __repr__(self): + return self._weight_bias().__repr__() + + +# TODO (zaf): Inherit from `quantized.Linear` (T83294430) +class Linear(torch.nn.Module): + r""" + A quantized sparse linear module with quantized tensor as inputs and outputs. + """ + + _version = 1 + _FLOAT_MODULE = torch.nn.Linear + + def __init__( + self, + in_features, + out_features, + row_block_size, + col_block_size, + bias=True, + dtype=torch.qint8, + ): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError( + "Only QINT8 is supported for Sparse Quantized Linear" + ) + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + self._packed_params = LinearPackedParams( + row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype + ) + self._packed_params.set_weight_bias( + qweight, bias, row_block_size, col_block_size + ) + self.scale = 1.0 + self.zero_point = 0 + + @classmethod + def _get_name(cls): + return "SparseQuantizedLinear" + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, " + f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}" + ) + + def __repr__(self): + return _hide_packed_params_repr(self, LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + + state_dict.pop(prefix + "op_type") + + version = local_metadata.get("version", None) + assert version <= self._version + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias( + self, + w: torch.Tensor, + b: Optional[torch.Tensor], + row_block_size: Optional[int], + col_block_size: Optional[int], + ) -> None: + assert row_block_size is not None and col_block_size is not None + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized sparse module from a float module. + + We only care about the convert at this stage, no need for observers just yet. 
+ + TODO(zaf): Need to add the sparse params to the qconfig + """ + assert type(mod) == cls._FLOAT_MODULE, ( + cls._get_name() + ".from_float only works for " + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "sparse_params"), ( + "Expecting the Linear to have `sparse_params`. Make sure you have provided arguments " + 'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.' + ) + sparse_block_shape = mod.sparse_params.get("sparse_block_shape", None) # type: ignore[operator, union-attr] + assert isinstance(sparse_block_shape, (tuple, list)) + assert len(sparse_block_shape) == 2 + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] + + # Assumption is that the weight is already sparsified by the + # `sparsifier.convert` + weight = mod.weight + + weight_post_process(weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[operator, union-attr] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + w_sc, w_zp = weight_post_process.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, "Weight zero point must map to 0" + qweight = _quantize_weight(weight.float(), weight_post_process) + + row_block_size = mod.sparse_params["sparse_block_shape"][0] # type: ignore[index] + col_block_size = mod.sparse_params["sparse_block_shape"][1] # type: ignore[index] + qlinear = cls( + mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype, + ) + qlinear.set_weight_bias( + qweight, + mod.bias, + row_block_size, # type: ignore[arg-type] + col_block_size, # type: ignore[arg-type] + ) + qlinear.scale = float(act_scale) + qlinear.zero_point = int(act_zp) + return qlinear diff --git a/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c984cb2a9587f46c877cb9c3bbafa3563615ed81 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py @@ -0,0 +1,63 @@ +import threading +from typing import Optional + + +__all__ = ["LinearBlockSparsePattern"] + + +def _is_valid_linear_block_sparse_pattern( + row_block_size: int, col_block_size: int +) -> bool: + return (row_block_size == 1 and col_block_size == 4) or ( + row_block_size == 8 and col_block_size == 1 + ) + + +# This is a stop-gap measure as current flow does not allow module +# specific block sparse pattern. +# Infact there is no way to convey sparse pattern via module config +# of quantization flow. Thus using the global context to convey +# sparsity pattern. +# Once the flow supports it, this should be removed. 
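# A minimal usage sketch (assumes an already-sparsified float model and the
# standard torch.ao.quantization.convert entry point; the model names below are
# illustrative, not part of this file):
#
#     from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
#
#     # only 1x4 and 8x1 block shapes are accepted (see the check above)
#     with LinearBlockSparsePattern(row_block_size=8, col_block_size=1):
#         # Linear.from_float picks the pattern up via LinearBlockSparsePattern.block_size()
#         sparse_quantized_model = torch.ao.quantization.convert(prepared_float_model)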
+class LinearBlockSparsePattern: + rlock = threading.RLock() + row_block_size: int = 1 + col_block_size: int = 4 + prev_row_block_size: int = 1 + prev_col_block_size: int = 4 + + def __init__(self, row_block_size: int = 1, col_block_size: int = 4): + assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size) + LinearBlockSparsePattern.rlock.acquire() + LinearBlockSparsePattern.prev_row_block_size = ( + LinearBlockSparsePattern.row_block_size + ) + LinearBlockSparsePattern.prev_col_block_size = ( + LinearBlockSparsePattern.col_block_size + ) + LinearBlockSparsePattern.row_block_size = row_block_size + LinearBlockSparsePattern.col_block_size = col_block_size + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + backtrace: Optional[object], + ) -> None: + LinearBlockSparsePattern.row_block_size = ( + LinearBlockSparsePattern.prev_row_block_size + ) + LinearBlockSparsePattern.col_block_size = ( + LinearBlockSparsePattern.prev_col_block_size + ) + LinearBlockSparsePattern.rlock.release() + + @staticmethod + def block_size() -> tuple[int, int]: + return ( + LinearBlockSparsePattern.row_block_size, + LinearBlockSparsePattern.col_block_size, + ) diff --git a/phivenv/Lib/site-packages/torch/ao/ns/__init__.py b/phivenv/Lib/site-packages/torch/ao/ns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60a68b9f1c23586a9f385f1abfddcd2d21f483b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cef5392d01c4a444a9fd6e573b73a7dedec314a6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..774eadd5b98cc0caa0aa0278568ffc3d28473a41 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/_numeric_suite.py b/phivenv/Lib/site-packages/torch/ao/ns/_numeric_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6376e5be9d2e04e5f79d957510fca2a3f204c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/_numeric_suite.py @@ -0,0 +1,567 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, Optional, Union + +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.nn as nn +from torch.ao.quantization import prepare +from torch.ao.quantization.quantization_mappings import ( + get_default_compare_output_module_list, +) + + +NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { + nnqd.Linear, + nnq.Linear, + nnqd.LSTM, + nn.LSTM, +} + + +def _find_match( + str_list: Union[dict[str, Any], list[str]], + key_str: 
str, + postfix: str, +) -> Optional[str]: + split_str = key_str.split(".") + if split_str[-1] == postfix: + match_string = "".join(key_str.split(".")[0:-1]) + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + + # For matching "fc.weight" and "fc._packed_params._packed_params" + if postfix == "_packed_params": + match_string = "".join(key_str.split(".")[0:-2]) + if len(match_string) == 0: + return None + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + return None + else: + return None + + +def compare_weights( + float_dict: dict[str, Any], quantized_dict: dict[str, Any] +) -> dict[str, dict[str, torch.Tensor]]: + r"""Compare the weights of the float module with its corresponding quantized + module. Return a dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights. This dict can be used to compare and compute the quantization + error of the weights of float and quantized models. + + Example usage:: + + wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict()) + for key in wt_compare_dict: + print( + key, + compute_error( + wt_compare_dict[key]["float"], + wt_compare_dict[key]["quantized"].dequantize(), + ), + ) + + Args: + float_dict: state dict of the float model + quantized_dict: state dict of the quantized model + + Return: + weight_dict: dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights") + weight_dict: dict[str, dict] = {} + for key in quantized_dict: + match_key = _find_match(float_dict, key, "weight") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key] + continue + + # For matching "fc.weight" and "fc._packed_params._packed_params" + match_key = _find_match(float_dict, key, "_packed_params") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key][0] + + # For LSTM + split_str = key.split(".") + if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": + layer = split_str[-2] + module_name = ".".join(split_str[:-3]) + float_weight_ih_key = module_name + ".weight_ih_l" + layer + float_weight_hh_key = module_name + ".weight_hh_l" + layer + if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[float_weight_ih_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] + ) + weight_dict[key]["float"] = float_dict[float_weight_hh_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] + ) + + return weight_dict + + +def _get_logger_dict_helper( + mod: nn.Module, + target_dict: dict[str, Any], + prefix: str = "", +) -> None: + r"""This is the helper function for get_logger_dict + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + target_dict: 
the dictionary used to save all logger stats + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." + + for name, child in mod.named_children(): + if isinstance(child, Logger): + target_dict[get_prefix(prefix) + "stats"] = child.stats + break + + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_logger_dict_helper(child, target_dict, module_prefix) + + +def get_logger_dict(mod: nn.Module, prefix: str = "") -> dict[str, dict]: + r"""Traverse the modules and save all logger stats into target dict. + This is mainly used for quantization accuracy debug. + + Type of loggers supported: + ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module, + OutputLogger: used to log the outputs of the modules + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + + Return: + target_dict: the dictionary used to save all logger stats + + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict") + + target_dict: dict[str, dict] = {} + _get_logger_dict_helper(mod, target_dict, prefix) + return target_dict + + +class Logger(nn.Module): + r"""Base class for stats logging""" + + def __init__(self): + super().__init__() + self.stats = {} + # We only insert observer if the op is quantized with static quantization, + # which is identified by activation_observer.dtype == quint8. This is needed + # when attaching Logger as observer for FX mode + self.dtype = torch.quint8 + + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + + +class ShadowLogger(Logger): + r"""Class used in Shadow module to record the outputs of the original and + shadow modules. + """ + + def __init__(self): + super().__init__() + self.stats["float"] = [] + self.stats["quantized"] = [] + + def forward(self, x, y): # type: ignore[override] + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if len(x) > 1: + x = x[0] + if len(y) > 1: + y = y[0] + self.stats["quantized"].append(x.detach()) + self.stats["float"].append(y.detach()) + + +class OutputLogger(Logger): + r"""Class used to log the outputs of the module""" + + def __init__(self): + super().__init__() + self.stats["tensor_val"] = [] + + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + self.stats["tensor_val"].append(x) + return x + + +def _convert_tuple_to_list(t: Any) -> Any: + return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t + + +def _dequantize_tensor_list(t: Any) -> Any: + return ( + [_dequantize_tensor_list(x) for x in t] + if type(t) is list + else t.dequantize() + if t.is_quantized + else t + ) + + +class Shadow(nn.Module): + r"""Shadow module attaches the float module to its matching quantized module + as the shadow. Then it uses Logger module to process the outputs of both + modules. + + Args: + q_module: module quantized from float_module that we want to shadow + float_module: float module used to shadow q_module + logger_cls: type of logger used to process the outputs of q_module and + float_module. ShadowLogger or custom loggers can be used. 
+ """ + + def __init__(self, q_module, float_module, logger_cls): + super().__init__() + self.orig_module = q_module + self.shadow_module = float_module + self.dequant = nnq.DeQuantize() + self.logger = logger_cls() + + def forward(self, *x) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + xl = _convert_tuple_to_list(x) + output = self.orig_module(*xl) + xl_float = _dequantize_tensor_list(xl) + shadow_output = self.shadow_module(*xl_float) + self.logger(output, shadow_output) + return output + + def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add(x, y) + self.logger(output, shadow_output) + return output + + def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.add_scalar(x, y) + self.logger(output, shadow_output) + return output + + def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.mul(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.mul(x, y) + self.logger(output, shadow_output) + return output + + def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.mul_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.mul_scalar(x, y) + self.logger(output, shadow_output) + return output + + def cat(self, x: list[torch.Tensor], dim: int = 0) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.cat(x, dim) + x = [y.dequantize() for y in x] + shadow_output = self.shadow_module.cat(x, dim) + self.logger(output, shadow_output) + return output + + def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add_relu(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add_relu(x, y) + self.logger(output, shadow_output) + return output + + +def prepare_model_with_stubs( + float_module: nn.Module, + q_module: nn.Module, + module_swap_list: set[type], + logger_cls: Callable, +) -> None: + r"""Prepare the model by attaching the float module to its matching quantized + module as the shadow if the float module type is in module_swap_list. 
+ + Example usage:: + + prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) + q_model(data) + ob_dict = get_logger_dict(q_model) + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + module_swap_list: list of float module types to attach the shadow + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.prepare_model_with_stubs" + ) + + float_module_children = dict(float_module.named_children()) + + reassign = {} + for name, mod in q_module.named_children(): + if name not in float_module_children: + continue + + float_mod = float_module_children[name] + + if type(float_mod) not in module_swap_list: + prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls) + + # Insert shadow module only if the module is not of the same type as + # the floating point module + if type(float_mod) in module_swap_list and not _is_identical_module_type( + mod, float_mod + ): + reassign[name] = Shadow(mod, float_mod, logger_cls) + + for key, value in reassign.items(): + q_module._modules[key] = value + + +def _is_identical_module_type(mod1, mod2): + # Compare if two modules have the same dtype + mod1_module_types = [type(mod) for mod in mod1.modules()] + mod2_module_types = [type(mod) for mod in mod2.modules()] + return mod1_module_types == mod2_module_types + + +def compare_model_stub( + float_model: nn.Module, + q_model: nn.Module, + module_swap_list: set[type], + *data, + logger_cls=ShadowLogger, +) -> dict[str, dict]: + r"""Compare quantized module in a model with its floating point counterpart, + feeding both of them the same input. Return a dict with key corresponding to + module names and each entry being a dictionary with two keys 'float' and + 'quantized', containing the output tensors of quantized and its matching + float shadow module. This dict can be used to compare and compute the module + level quantization error. + + This function first call prepare_model_with_stubs() to swap the quantized + module that we want to compare with the Shadow module, which takes quantized + module, corresponding float module and logger as input, and creates a forward + path inside to make the float module to shadow quantized module sharing the + same input. The logger can be customizable, default logger is ShadowLogger + and it will save the outputs of the quantized module and float module that + can be used to compute the module level quantization error. + + Example usage:: + + module_swap_list = [ + torchvision.models.quantization.resnet.QuantizableBasicBlock + ] + ob_dict = compare_model_stub(float_model, qmodel, module_swap_list, data) + for key in ob_dict: + print( + key, + compute_error( + ob_dict[key]["float"], ob_dict[key]["quantized"].dequantize() + ), + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + module_swap_list: list of float module types at which shadow modules will + be attached. 
+ data: input data used to run the prepared q_model + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub") + prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls) + q_model(*data) + ob_dict = get_logger_dict(q_model) + return ob_dict + + +def get_matching_activations( + float_module: nn.Module, + q_module: nn.Module, +) -> dict[str, dict[str, torch.Tensor]]: + r"""Find the matching activation between float and quantized modules. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + + Return: + act_dict: dict with key corresponding to quantized module names and each + entry being a dictionary with two keys 'float' and 'quantized', containing + the matching float and quantized activations + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.get_matching_activations" + ) + float_dict = get_logger_dict(float_module) + quantized_dict = get_logger_dict(q_module) + act_dict: dict[str, dict] = {} + for key in quantized_dict: + if len(quantized_dict[key]["tensor_val"]) == 0: + continue + match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") + if match_key is not None: + act_dict[key] = {} + act_dict[key]["float"] = float_dict[match_key]["tensor_val"] + act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"] + return act_dict + + +def prepare_model_outputs( + float_module: nn.Module, + q_module: nn.Module, + logger_cls=OutputLogger, + allow_list=None, +) -> None: + r"""Prepare the model by attaching the logger to both float module + and quantized module if they are in the allow_list. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.prepare_model_outputs" + ) + if allow_list is None: + allow_list = get_default_compare_output_module_list() + + qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None) + float_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={} + ) + q_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + q_module, + inplace=True, + allow_list=allow_list, + observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + prepare_custom_config_dict={}, + ) + + +def compare_model_outputs( + float_model: nn.Module, + q_model: nn.Module, + *data, + logger_cls=OutputLogger, + allow_list=None, +) -> dict[str, dict[str, torch.Tensor]]: + r"""Compare output activations between float and quantized models at + corresponding locations for the same input. Return a dict with key corresponding + to quantized module names and each entry being a dictionary with two keys + 'float' and 'quantized', containing the activations of quantized model and + float model at matching locations. This dict can be used to compare and + compute the propagation quantization error. 
+ + Example usage:: + + act_compare_dict = compare_model_outputs(float_model, qmodel, data) + for key in act_compare_dict: + print( + key, + compute_error( + act_compare_dict[key]["float"], + act_compare_dict[key]["quantized"].dequantize(), + ), + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + data: input data used to run the prepared float_model and q_model + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + + Return: + act_compare_dict: dict with key corresponding to quantized module names + and each entry being a dictionary with two keys 'float' and 'quantized', + containing the matching float and quantized activations + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.compare_model_outputs" + ) + if allow_list is None: + allow_list = get_default_compare_output_module_list() + prepare_model_outputs(float_model, q_model, logger_cls, allow_list) + float_model(*data) + q_model(*data) + act_compare_dict = get_matching_activations(float_model, q_model) + return act_compare_dict diff --git a/phivenv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py b/phivenv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..2d30cbe68b600ca9f6b86ad158babfd30d188d04 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py @@ -0,0 +1,1124 @@ +# mypy: allow-untyped-defs +""" +This module contains tooling to compare weights and activations +across models. Example usage:: + + import copy + import torch + import torch.ao.quantization.quantize_fx as quantize_fx + import torch.ao.ns._numeric_suite_fx as ns + + m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval() + mp = quantize_fx.prepare_fx(m, {"": torch.ao.quantization.default_qconfig}) + # We convert a copy because we need the original prepared model + # to be available for comparisons, and `quantize_fx.convert_fx` is inplace. + mq = quantize_fx.convert_fx(copy.deepcopy(mp)) + + # + # Comparing weights + # + + # extract weight pairs + weight_comparison = ns.extract_weights("a", mp, "b", mq) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + weight_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # weight_comparison contains the weights from `mp` and `mq` stored + # in pairs, and can be used for further analysis. + + + # + # Comparing activations, with error propagation + # + + # add loggers + mp_ns, mq_ns = ns.add_loggers( + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_ns(datum) + mq_ns(datum) + + # extract intermediate activations + act_comparison = ns.extract_logger_info(mp_ns, mq_ns, ns.OutputLogger, "b") + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. 
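    # A sketch of consuming the results (assuming the NSResultsType layout
    # ref_name -> results_type -> model_name -> list of per-node dicts produced
    # by `extract_logger_info`, with "sqnr" filled in by the call above):
    for ref_name, results_type_to_results in act_comparison.items():
        for results_type, model_name_to_results in results_type_to_results.items():
            for model_name, node_results in model_name_to_results.items():
                for r in node_results:
                    # the comparison key is only attached to model "b" entries
                    print(ref_name, model_name, r.get("sqnr"))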
+ + # + # Comparing activations, without error propagation + # + + # create shadow model + mp_shadows_mq = ns.add_shadow_loggers( + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_shadows_mq(datum) + + # extract intermediate activations + shadow_act_comparison = ns.extract_shadow_logger_info( + mp_shadows_mq, ns.OutputLogger, "b" + ) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + shadow_act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. + +""" + +import collections +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.ao.quantization.quantize_fx as quantize_fx +import torch.nn as nn +from torch.ao.ns.fx.graph_matcher import get_matching_subgraph_pairs +from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops +from torch.ao.ns.fx.n_shadows_utils import ( + _get_dedup_subgraphs, + create_add_loggers_graph, + create_n_transformed_and_logged_copies_of_subgraph, + create_results_comparison, + extract_weight_comparison, + group_results_by_subgraph, + OutputProp, + print_n_shadows_summary, + SHADOW_WRAPPER_NODE_NAME_PREFIX, +) +from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.backend_config.utils import ( + get_fusion_pattern_to_root_node_getter, +) +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.fx.match_utils import _find_matches +from torch.ao.quantization.fx.qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, +) +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b +from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType +from .fx.utils import ( + get_target_type_str, + maybe_add_missing_fqns, + rekey_logger_info_on_node_name_of_model, +) +from .fx.weight_utils import extract_weight_from_node + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfigAny + +RNNReturnType = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] + + +class OutputLogger(nn.Module): + """ + Base class for capturing intermediate values. + """ + + stats: list[torch.Tensor] + stats_rnn: list[RNNReturnType] + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + ref_node_name: str, + prev_node_name: str, + model_name: str, + ref_name: str, + prev_node_target_type: str, + ref_node_target_type: str, + results_type: str, + index_within_arg: int, + index_of_arg: int, + fqn: Optional[str], + qconfig_str: Optional[str] = "", + ): + super().__init__() + self.stats: list[torch.Tensor] = [] + self.stats_rnn: list[RNNReturnType] = [] + + # name of the node which was responsible for adding this logger + # Note: + # - if we are logging node outputs, this is the same as prev_node_name + # - if we are logging node inputs, this is the name of the node + # whose input this logger is logging. 
+ # + # example, where logger1 is logging input of op1 and logger2 is logging + # the output of op1: + # + # x1 -> logger1 -> op1 -> logger2 -> x2 + # + # in this example, + # - logger1's prev_node_name is x1 and ref_node_name is op1 + # - logger2's prev_node_name is op1 and ref_node_name is op1 + self.ref_node_name = ref_node_name + # name of the node whose output this Logger is capturing + self.prev_node_name = prev_node_name + + # name of the model from which the node originated from + self.model_name = model_name + # reference name, used to match loggers from separate models + # to each other + self.ref_name = ref_name + # type of the target of the node whose output this logger is logging + self.prev_node_target_type = prev_node_target_type + # type of the target of the node which was responsible for adding this + # logger + self.ref_node_target_type = ref_node_target_type + # what kind of values are inside of stats + self.results_type = results_type + # index of this node within the arg of the input/output node + # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 + self.index_within_arg = index_within_arg + # index of this node within the args of the input/output node + # for example, in add(x1, x2), x2 would have index_of_arg == 1 + self.index_of_arg = index_of_arg + # fully qualified name + self.fqn = fqn + # if loggers are added before prepare_fx, but we do not want + # collect results of calibration, only results after convert_fx + # so, we add a flag to control whether this logger collects data + self.enabled = True + # string representation of qconfig + self.qconfig_str = qconfig_str + # this can be turned off to reduce memory usage during calibration + self.save_activations = True + + # Note: cannot annotate the type of x because TorchScript does not support + # the Union type. + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + # TODO(future PR): consider designing this better, as the difference + # between these two flags is subtle and not obvious. 
+ if not self.enabled: + return x + if not self.save_activations: + return x + # TODO(future PR): consider refactoring this to better reuse the parent + # class + if isinstance(x, torch.Tensor): + self.stats.append(x.detach()) + elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2: + new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach())) + self.stats_rnn.append(new_res) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != "training") and not k.startswith("_") + } + return f"OutputLogger({clean_dict})" + + +class OutputComparisonLogger(OutputLogger): + """ + Same as OutputLogger, but also requires the original activation + in order to calculate the comparison at calibration time + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO(future PR): make the comparison function configurable + self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr + self.comparison_fn_name = "sqnr" + # precalculated comparisons of logger output versus reference + self.comparisons = [] + # precalculated comparisons function + + def forward(self, x, x_ref): # type: ignore[override] + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if not self.enabled: + return x + assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported" + if self.save_activations: + # save the activation, for debugging + self.stats.append(x.detach()) + # save the comparison + self.comparisons.append(self.comparison_fn(x, x_ref)) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != "training") and not k.startswith("_") + } + return f"OutputComparisonLogger({clean_dict})" + + +class NSTracer(quantize_fx.QuantizationTracer): + """ + Just like a regular FX quantization tracer, but treats observers and fake_quantize + modules as leaf modules. 
+ """ + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if isinstance(m, torch.ao.quantization.ObserverBase): + return True + elif isinstance(m, torch.ao.quantization.FakeQuantizeBase): + return True + return super().is_leaf_module(m, module_qualified_name) + + +def _extract_weights_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument: list[tuple[Node, str]], + results: NSResultsType, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> None: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_weights_one_model" + ) + for node, ref_name in nodes_and_names_to_instrument: + res_type = NSSingleResultValuesType.WEIGHT.value + extracted_weight = extract_weight_from_node( + node, model, op_to_type_to_weight_extraction_fn + ) + if extracted_weight: + if ref_name not in results: + results[ref_name] = {res_type: {}} + results[ref_name][res_type][model_name] = [extracted_weight] + + +def _extract_weights_impl( + model_name_a: str, + gm_a: GraphModule, + model_name_b: str, + gm_b: GraphModule, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> NSResultsType: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_weights_impl" + ) + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + + # split the subgraph pairs into one data structure for each model + nodes_and_names_to_instrument_a: list[tuple[Node, str]] = [] + nodes_and_names_to_instrument_b: list[tuple[Node, str]] = [] + for match_name, match in matched_subgraph_pairs.items(): + subgraph_a, subgraph_b = match + nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name)) + nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name)) + + # populate the results, one model at a time + results: NSResultsType = {} + _extract_weights_one_model( + model_name_a, + gm_a, + nodes_and_names_to_instrument_a, + results, + op_to_type_to_weight_extraction_fn, + ) + _extract_weights_one_model( + model_name_b, + gm_b, + nodes_and_names_to_instrument_b, + results, + op_to_type_to_weight_extraction_fn, + ) + + # fill in missing fqn entries + maybe_add_missing_fqns(results) + + # rekey on names of nodes in gm_b + results = rekey_logger_info_on_node_name_of_model(results, model_name_b) + + return results + + +def extract_weights( + model_name_a: str, + model_a: nn.Module, + model_name_b: str, + model_b: nn.Module, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> NSResultsType: + """ + Extract weights from model A and model B, and return a comparison. 
+ + Args: + model_name_a: string name of model A to use in results + model_a: model A + model_name_b: string name of model B to use in results + model_b: model B + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + op_to_type_to_weight_extraction_fn: optional override of function which extracts weight + from a type, subject to change + + Return: + NSResultsType, containing the weight comparisons + """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights") + if base_name_to_sets_of_related_ops is None: + base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() + + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _extract_weights_impl( + model_name_a, + gm_a, + model_name_b, + gm_b, + base_name_to_sets_of_related_ops, + unmatchable_types_map, + op_to_type_to_weight_extraction_fn, + ) + + +def _add_loggers_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument_inputs: list[tuple[Node, str, str]], + nodes_and_names_to_instrument_outputs: list[tuple[Node, str, str]], + logger_cls: Callable, +) -> nn.Module: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._add_loggers_one_model" + ) + + # TODO(future PR): do not observe nodes we do not care + # about (both fp32, denylist, etc) + node_to_instrument_inputs_to_ref_name: dict[Node, tuple[str, str]] = {} + node_to_instrument_outputs_to_ref_name: dict[Node, tuple[str, str]] = {} + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs: + node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type) + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs: + node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type) + + model = add_loggers_to_model( + model, + node_to_instrument_inputs_to_ref_name, + node_to_instrument_outputs_to_ref_name, + logger_cls, + model_name, + ) + return model + + +def _add_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> tuple[nn.Module, nn.Module]: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl") + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + nodes_and_names_to_instrument_inputs_a = [] + nodes_and_names_to_instrument_inputs_b = [] + nodes_and_names_to_instrument_outputs_a = [] + nodes_and_names_to_instrument_outputs_b = 
[] + for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items(): + ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a) + ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b) + # Note: for matching inputs we use start_node, such as observing + # the input of linear in linear-relu + if should_log_inputs: + nodes_and_names_to_instrument_inputs_a.append( + (subgraph_a.start_node, match_name, ref_node_type_a) + ) + nodes_and_names_to_instrument_inputs_b.append( + (subgraph_b.start_node, match_name, ref_node_type_b) + ) + # Note: for matching activations we always use end_node, + # such as observing the output of relu in linear-relu + nodes_and_names_to_instrument_outputs_a.append( + (subgraph_a.end_node, match_name, ref_node_type_a) + ) + nodes_and_names_to_instrument_outputs_b.append( + (subgraph_b.end_node, match_name, ref_node_type_b) + ) + + new_model_a = _add_loggers_one_model( + name_a, + gm_a, + nodes_and_names_to_instrument_inputs_a, + nodes_and_names_to_instrument_outputs_a, + logger_cls, + ) + new_model_b = _add_loggers_one_model( + name_b, + gm_b, + nodes_and_names_to_instrument_inputs_b, + nodes_and_names_to_instrument_outputs_b, + logger_cls, + ) + return (new_model_a, new_model_b) + + +def add_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> tuple[nn.Module, nn.Module]: + """ + Instrument model A and model B with loggers. + + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + + Return: + Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace. 
+ """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers") + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_loggers_impl( + name_a, + gm_a, + name_b, + gm_b, + logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + unmatchable_types_map=unmatchable_types_map, + ) + + +def _extract_logger_info_one_model( + model: nn.Module, + results: NSResultsType, + logger_cls: Callable, +) -> None: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_logger_info_one_model" + ) + for _gm_name, mod in model.named_modules(): + # TODO(future PR): better check when scripted + is_logger = isinstance(mod, logger_cls) or ( # type: ignore[arg-type] + isinstance(mod, torch.jit.RecursiveScriptModule) + and mod.original_name == "OutputLogger" + ) + if is_logger: + key = mod.ref_name + if key not in results: + results[key] = {} + assert mod.model_name not in results[key], ( + f"{mod.model_name} is already present in results" + ) + if mod.results_type not in results[key]: + results[key][mod.results_type] = {} + if mod.model_name not in results[key][mod.results_type]: + results[key][mod.results_type][mod.model_name] = [] + stats_to_use = mod.stats + if len(mod.stats_rnn) > 0: + stats_to_use = mod.stats_rnn + data = { + "type": mod.results_type, + "values": stats_to_use, + "ref_node_name": mod.ref_node_name, + "ref_node_target_type": mod.ref_node_target_type, + "prev_node_name": mod.prev_node_name, + "prev_node_target_type": mod.prev_node_target_type, + "index_within_arg": mod.index_within_arg, + "index_of_arg": mod.index_of_arg, + "fqn": mod.fqn, + "qconfig_str": mod.qconfig_str, + } + if hasattr(mod, "comparisons"): + data["comparisons"] = mod.comparisons + data["comparison_fn_name"] = mod.comparison_fn_name + else: + data["comparisons"] = [] + data["comparison_fn_name"] = "" + results[key][mod.results_type][mod.model_name].append(data) + # ensure the list stays sorted + results[key][mod.results_type][mod.model_name].sort( + key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}" + ) + + +# TODO(future PR): align on naming +# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs` +def extract_logger_info( + model_a: nn.Module, + model_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in `model_a` and `model_b`, and extract the logged + information. 
+ + Args: + model_a: model A + model_b: model B + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.extract_logger_info" + ) + results: NSResultsType = {} + for model in (model_a, model_b): + _extract_logger_info_one_model(model, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names + ) + return results + + +def _add_shadow_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + node_type_to_io_type_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> nn.Module: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._add_shadow_loggers_impl" + ) + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + gm_a_shadows_b = create_a_shadows_b( + name_a, + gm_a, + name_b, + gm_b, + matched_subgraph_pairs, + logger_cls, + should_log_inputs=should_log_inputs, + node_type_to_io_type_map=node_type_to_io_type_map, + ) + return gm_a_shadows_b + + +def add_shadow_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + node_type_to_io_type_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> nn.Module: + """ + Instrument model A and model B with shadow loggers. 
+ + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + should_log_inputs: whether to log inputs + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.add_shadow_loggers" + ) + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_shadow_loggers_impl( + name_a, + gm_a, + name_b, + gm_b, + logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + node_type_to_io_type_map=node_type_to_io_type_map, + unmatchable_types_map=unmatchable_types_map, + ) + + +def extract_shadow_logger_info( + model_a_shadows_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in a shadow model, and extract the logged + information. + + Args: + model_a_shadows_b: shadow model + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.extract_shadow_logger_info" + ) + results: NSResultsType = collections.defaultdict(dict) + _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names + ) + return dict(results) + + +def extend_logger_results_with_comparison( + results: NSResultsType, + model_name_1: str, + model_name_2: str, + comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + comparison_name: str, +) -> None: + """ + Compares the logged values from `model_name_2` against the corresponding + values in `model_name_1`, using `comparison_fn`. Records the result + in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace. + + Args: + results: the result data structure from `extract_logger_info` or + `extract_shadow_logger_info`. 
+ model_name_1: string name of model 1 + model_name_2: string name of model 2 + comparison_fn: function to compare two Tensors + comparison_name: string name of model to use for + layer names in the output + """ + for results_type_to_results in results.values(): + for model_name_to_results in results_type_to_results.values(): + assert model_name_1 in model_name_to_results, ( + f"{model_name_1} not found in results" + ) + assert model_name_2 in model_name_to_results, ( + f"{model_name_2} not found in results" + ) + + results_1 = model_name_to_results[model_name_1] + results_2 = model_name_to_results[model_name_2] + + for result_2 in results_2: + index_within_arg_2 = result_2["index_within_arg"] + index_of_arg_2 = result_2["index_of_arg"] + # find corresponding result_1 + result_1 = None + for cur_result_1 in results_1: + index_within_arg_1 = cur_result_1["index_within_arg"] + index_of_arg_1 = cur_result_1["index_of_arg"] + if (index_within_arg_1 == index_within_arg_2) and ( + index_of_arg_1 == index_of_arg_2 + ): + result_1 = cur_result_1 + break + assert result_1 is not None + + values_1 = result_1["values"] + values_2 = result_2["values"] + result_2[comparison_name] = [] + for value_1, value_2 in zip(values_1, values_2): + comparison_result = comparison_fn(value_1, value_2) + result_2[comparison_name].append(comparison_result) + + +def prepare_n_shadows_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_multi_mapping: QConfigMultiMapping, + backend_config: BackendConfig, + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[dict[str, Any]] = None, + custom_tracer: Any = None, +) -> GraphModule: + """ + Given a model with a graph with M ops such as + + + args_kwargs_m -> op_m -> output_m + + + And a set of N qconfigs for each op, creates a new model, with + each of the subgraph of `op_m` transformed into + + .. code:: + + |---------> op_m_n -> log_m_n + | / + args_kwargs_m ---------> op_m -> log_m_0 + + Where op_m_n is op_m wrapped in a submodule and transformed with + qconfig_n, and its inner graph looks like + + .. code:: + + args_m -------- op_m_prepared_with_qconfig_n -> out_m_n + / + kwargs_m --- + + This is useful for testing different quantization of multiple layers in + a single pass through the model. + + High level TODOs for future PRs: + * figure out a better way to name the output structure + * return a results data structure instead of printing it out + * add examples to docblocks + """ + + if custom_tracer is None: + tracer = quantize_fx.QuantizationTracer([], []) + else: + tracer = custom_tracer + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. 
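+    # (The candidate patterns come from `backend_config`; `_get_dedup_subgraphs`
+    # below collapses the raw matches into a mapping from subgraph name to the
+    # list of nodes it contains, one entry per region that will be shadowed.)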
+ modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: list[str] = [] + standalone_module_classes: list[type] = [] + custom_module_classes: list[type] = [] + matches = _find_matches( + mt.graph, + modules, + patterns, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + subgraphs_dedup: dict[str, list[Node]] = _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + # TODO(future PR): deduplicate repeating entries + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]] = [] + for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list: + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope + ) + list_of_node_name_to_qconfig.append(node_name_to_qconfig) + + # For each region in the model, do the following: + # For each qconfig for that region, do the following: + # 1. create a copy of the region wrapped in a module + # 2. pass original args, original kwargs, and expected output to module + # 3. add an output comparison logger and hook it up to compare + # actual output to expected output + # 4. run `prepare_fx` on the module + for subgraph_idx, (match_name, nodes_in_this_subgraph) in enumerate( + subgraphs_dedup.items() + ): + create_n_transformed_and_logged_copies_of_subgraph( + mt, + subgraph_idx, + match_name, + nodes_in_this_subgraph, + qconfig_multi_mapping.qconfig_mappings_list, + list_of_node_name_to_qconfig, + custom_prepare_fn, + custom_prepare_kwargs, # type: ignore[arg-type] + ) + + return mt + + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _prepare_n_shadows_add_loggers_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> torch.nn.Module: + r""" + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + + This creates a model which provides logging for the following + problem: if we quantize `model` with `qconfig_mapping` and feed + the same input through both models, log the comparisons of + corresponding intermediate layers. + + The problem is solved with a single model. Specifically, we + partition `model` into N subgraphs, create a copy of each relevant + subgraph, wrap it in a module, apply the quantization API to that + module, and hook up loggers to measure the comparisons. + + Example starting graph: + + x0 -> op0 -> x1 -> op1 -> x2 + + Example config: quantize op0 to int8, do nothing to op1. + The following graph will be created: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog + + Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized + to int8, op1_0 is op1 (appearing in the graph twice), log is a logger, + and clog is a comparison logger. 
+ """ + + tracer = quantize_fx.QuantizationTracer([], []) + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. + modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: list[str] = [] + standalone_module_classes: list[type] = [] + custom_module_classes: list[type] = [] + matches = _find_matches( + mt.graph, + modules, + patterns, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + subgraphs_dedup: dict[str, list[Node]] = _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope + ) + + # Now, mutate the graph to be the add_loggers graph with propagation + # error. + create_add_loggers_graph(mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig) + + return mt + + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _n_shadows_compare_weights( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> NSResultsType: + """ + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + """ + qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping( + [qconfig_mapping] + ) + mp = prepare_n_shadows_model( + model, example_inputs, qconfig_multi_mapping, backend_config + ) + # passing inputs through the model is necessary to populate + # observers which observe weights with real values + mp(*example_inputs) + mq = convert_n_shadows_model(mp) + weight_comparison = extract_weight_comparison(mq) + return weight_comparison + + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None: + """ + Sets the `enabled` setting on a `model`'s loggers + """ + for _, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.enabled = enabled + + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_save_activations( + model: torch.nn.Module, + save_activations: bool, +) -> None: + """ + Sets the `save_activations` setting on a `model`'s loggers + """ + for _name, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.save_activations = save_activations + + +def convert_n_shadows_model( + model: GraphModule, + custom_convert_fn: Optional[Callable] = None, + custom_convert_kwargs: Optional[dict[str, Any]] = None, +) -> GraphModule: + """ + Given a model from `prepare_n_shadows_model`, runs `convert_fx` + on each shadow submodule. 
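+
+    An illustrative end-to-end sketch of the n-shadows workflow (the model `m`
+    and `example_inputs` are assumptions, not part of this API):
+
+    .. code::
+
+        mp = prepare_n_shadows_model(m, example_inputs, qconfig_multi_mapping, backend_config)
+        mp(*example_inputs)  # calibrate
+        mq = convert_n_shadows_model(mp)
+        mq(*example_inputs)  # fill the loggers
+        results = extract_results_n_shadows_model(mq)
+        print_comparisons_n_shadows_model(results)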
+ """ + for node in model.graph.nodes: + # TODO(future PR): consider matching in a safer way than + # node name string match + if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX): + orig_mod = getattr(model, node.name) + if custom_convert_fn is None: + converted_mod = torch.ao.quantization.quantize_fx.convert_fx(orig_mod) + else: + if custom_convert_kwargs is None: + custom_convert_kwargs = {} + converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs) + setattr(model, node.name, converted_mod) + + return model + + +def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType: + """ + Extracts logger results from `model`. + """ + results: NSResultsType = {} + _extract_logger_info_one_model(model, results, OutputLogger) + return results + + +def print_comparisons_n_shadows_model(results: NSResultsType) -> None: + """ + Prints a summary of extracted `results`. + """ + results_grouped = group_results_by_subgraph(results) + results_comparison = create_results_comparison(results_grouped) + print_n_shadows_summary(results_comparison) diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__init__.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f4e114c8fdbd473b3c4eb41a43ab1d062a0614d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0644e62049bec82d5cddba230335bcbc921c8a2d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5d7f8fc3b76d7168563dda48d3e8b298eaf5813 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e1bb20b23c661a66f61738dfdbe1380b07a9f13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0855be312149582e96727182f4ba4792e89aba43 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b09146345d6e9d0b3919d235fed5f5bc5c46d4f7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..192ecfb453e9f10e082e47df43d3c1577d20ee50 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4b066565a4138d7d3c9f0e371975ef3b320283f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be4257573c4e53a853969a04e76da65e5f7724e4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8af5835bc02ff95a38bd9f6a3c935a4e7b4a86af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..00a883d92823f2b7753f9468fcc85fc56e8425ce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py @@ -0,0 +1,472 @@ +# mypy: allow-untyped-defs +import collections +import enum +from typing import Any, Optional + +import torch +from torch.ao.quantization import FakeQuantizeBase, ObserverBase +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import GraphModule +from torch.fx.graph import Graph, Node + +from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map +from .ns_types import NSNodeTargetType, NSSubgraph +from .pattern_utils import ( + end_node_matches_reversed_fusion, + get_reversed_fusions, + get_type_a_related_to_b, +) + + +toq = torch.ops.quantized + + +def _get_output_nodes(g: Graph) -> list[Node]: + return [n for n in g.nodes if n.op == "output"] + + +class _NSGraphMatchableSubgraphsIterator: + """ + Iterates through the graph of gm, starting with the output nodes + and continuing backwards. + 1. Returns matchable subgraphs, in order. A subgraph is defined by + (start_node, end_node). + 2. 
Skips over non-matchable subgraphs + """ + + def __init__( + self, + gm: GraphModule, + non_matchable_functions: set[NSNodeTargetType], + non_matchable_modules: set[NSNodeTargetType], + non_matchable_methods: set[NSNodeTargetType], + ): + self.gm: GraphModule = gm + self.non_matchable_functions: set[NSNodeTargetType] = non_matchable_functions + self.non_matchable_modules: set[NSNodeTargetType] = non_matchable_modules + self.non_matchable_methods: set[NSNodeTargetType] = non_matchable_methods + self.seen_nodes: set[Node] = set() + self.stack: list[Node] = [] + for start_node in _get_output_nodes(self.gm.graph): + self.stack.append(start_node) + + def __iter__(self): + return self + + def __next__(self) -> NSSubgraph: + """ + Returns the next matchable subgraph. + """ + while len(self.stack) > 0: + cur_end_node = self.stack.pop() + if cur_end_node in self.seen_nodes: + continue + + # for subgraphs which are single nodes, start_node == end_node + # for subgraphs with more than one node, start node != end_node + cur_start_node = cur_end_node + # Subgraphs like linear-relu have the base node as the start node. + # Subgraphs like dequantize-linear-relu-to(torch.float16) have the + # base node as the second node. + # The cur_base_op_node var will move to the actual node during + # the fusion matching later in this code block. + cur_base_op_node = cur_end_node + + # Check for potential fusions. For now, we are greedy + # and always skip all non-base nodes of a fusion. For example, + # if we match linear-relu backwards, we will always skip the + # relu node and attempt to match the linear node. This can + # be made configurable later if needed. + for _reverse_fusion_ops, base_op_idx in get_reversed_fusions(): + is_match = end_node_matches_reversed_fusion( + cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes + ) + if is_match: + # navigate to the base node + for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1): + self.seen_nodes.add(cur_start_node) + # for now, assume that there are no other nodes + # which need to be added to the stack + cur_start_node = cur_start_node.args[0] # type: ignore[assignment] + # if the base op index matches the current node, set it + rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx + if rev_fusion_idx == rev_base_op_idx: + cur_base_op_node = cur_start_node + break + + self.seen_nodes.add(cur_start_node) + # add args of previous nodes to stack + for arg in cur_start_node.all_input_nodes: + self._recursively_add_node_arg_to_stack(arg) + + # skip unmatchable nodes + # note: this check is done on the start_node, i.e. + # if we are matching linear-relu in reverse, this would do the matchable + # check on the linear + if not self._is_matchable(cur_base_op_node): + continue + + # If an observer or a fake_quant was not matched as a part of + # a pattern of multiple nodes, ignore it. One case where this is + # relevant is an observer on a graph input, which was added because + # it is necessary for the next node. + if cur_end_node.op == "call_module" and cur_start_node is cur_end_node: + maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type] + if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)): + continue + + return NSSubgraph( + start_node=cur_start_node, + end_node=cur_end_node, + base_op_node=cur_base_op_node, + ) + + raise StopIteration + + def _recursively_add_node_arg_to_stack(self, arg: Any) -> None: + """ + Adds all of the nodes in this arg to the stack, properly navigating + through list, dicts and tuples. 
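+    For example (illustrative), for an arg such as the list passed to
+    `torch.cat([x1, x2], 1)`, both `x1` and `x2` are pushed onto the stack,
+    while non-Node values such as the integer dim are ignored.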
+ """ + if isinstance(arg, Node): + self.stack.append(arg) + elif ( + isinstance(arg, torch.fx.immutable_collections.immutable_list) + or type(arg) is tuple + ): + for inner_arg in arg: + self._recursively_add_node_arg_to_stack(inner_arg) + elif isinstance(arg, torch.fx.immutable_collections.immutable_dict): + for value in arg.values(): + self._recursively_add_node_arg_to_stack(value) + + def _is_matchable(self, node: Node) -> bool: + if node.op == "call_function": + return node.target not in self.non_matchable_functions + elif node.op == "call_module": + assert isinstance(node.target, str) + target_mod = getattr_from_fqn(self.gm, node.target) + return not any( + isinstance(target_mod, t) # type: ignore[arg-type] + for t in self.non_matchable_modules + ) + elif node.op == "call_method": + return node.target not in self.non_matchable_methods + else: + return False + + +class GraphMatchingException(Exception): + """ + Exception raised when two graphs cannot be matched. + """ + + +class SubgraphTypeRelationship(enum.Enum): + # same type, known + # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d + EQUAL = enum.auto() + # same type, but the type is not known to Numerical Suite + # (user defined type, etc). + EQUAL_BUT_UKNOWN = enum.auto() + # known, same subgraph_relationship set, but not the same type + # example: F.linear and toq.linear + RELATED_BUT_NOT_EQUAL = enum.auto() + # not related + NOT_RELATED = enum.auto() + + +def _get_subgraph_relationship_type( + subgraph_a: NSSubgraph, + subgraph_b: NSSubgraph, + gm_a: GraphModule, + gm_b: GraphModule, + type_a_related_to_b: set[tuple[NSNodeTargetType, NSNodeTargetType]], +) -> SubgraphTypeRelationship: + node_a = subgraph_a.base_op_node + node_b = subgraph_b.base_op_node + + # TODO(next): make this code handle matching by what is before the base op + if node_a.op != node_b.op: + if not ( + node_a.op in ("call_function", "call_method") + and node_b.op in ("call_function", "call_method") + ): + return SubgraphTypeRelationship.NOT_RELATED + + if node_a.op in ("call_function", "call_method"): + key = (node_a.target, node_b.target) + + if key not in type_a_related_to_b: + if node_a.target == node_b.target: + return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN + else: + return SubgraphTypeRelationship.NOT_RELATED + # after this point, we are dealing with known types + + if node_a.target == node_b.target: + node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node + node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node + if node_a_has_prev and (not node_b_has_prev): + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + elif (not node_a_has_prev) and node_b_has_prev: + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + elif (not node_a_has_prev) and (not node_b_has_prev): + return SubgraphTypeRelationship.EQUAL + else: + # TODO(future PR): check for matches start_op_node and base_op_node + return SubgraphTypeRelationship.EQUAL + + if key in type_a_related_to_b: + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + else: + return SubgraphTypeRelationship.NOT_RELATED + elif node_a.op == "call_module": + assert ( + subgraph_a.base_op_node == subgraph_a.start_node + and subgraph_b.base_op_node == subgraph_b.start_node + ), ( + "Matching call_module patterns where base_op_node != start_node is not supported yet" + ) + # for call_module, we need to look up the modules to do the type check + assert isinstance(node_a.target, str) + mod_a = getattr_from_fqn(gm_a, node_a.target) + assert isinstance(node_b.target, str) 
+ mod_b = getattr_from_fqn(gm_b, node_b.target) + + key = (type(mod_a), type(mod_b)) + + if key not in type_a_related_to_b: + if type(mod_a) == type(mod_b): + return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN + else: + return SubgraphTypeRelationship.NOT_RELATED + elif type(mod_a) == type(mod_b): + return SubgraphTypeRelationship.EQUAL + else: + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + + return SubgraphTypeRelationship.NOT_RELATED + + +def _get_name_for_subgraph( + subgraph_a: NSSubgraph, + gm_a: GraphModule, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], + existing_names: set[str], +) -> str: + """ + Returns a unique name for a subgraph. This name is based on two things: + 1. the name of the set containing the underlying type of the base op in the + subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op) + 2. the number of previous subgraphs with related underlying type of the base op + + For example, in the graph + + linear0 -> relu0 -> linear1 -> relu1 + + The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate + from the output node backwards, the name given to (linear1, relu1) will be + `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0) + will be `base_op_torch.nn.functional.linear_1`. + + Why are we not just using the node name? Answer: because of two requirements: + A. fusions must be supported + B. some Numeric Suite APIs can be called without having all of the models in memory + + For example, let's say we need to match nodes of + + (1) ... -> linear0 -> relu0 -> ... + + And + + (2) ... -> linear_relu0 -> ... + + Without being able to inspect them together. With the current naming scheme, if + we iterate through both of these graphs in the same order, and assuming the rest + of the graphs match, both of these subgraphs will get the same name without + (1) and (2) knowing anything about each other. + """ + target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a) + target_base_type = None + for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items(): + if target_type in sets_of_related_ops: + target_base_type = base_name + target_base_name = "base_op_" + str(target_base_type) + counter = 0 + proposed_name = target_base_name + "_" + str(counter) + while proposed_name in existing_names: + counter += 1 + proposed_name = target_base_name + "_" + str(counter) + existing_names.add(proposed_name) + return proposed_name + + +def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]: + if node.op in ("call_function", "call_method"): + return node.target + elif node.op == "call_module": + assert isinstance(node.target, str) + mod = getattr_from_fqn(gm, node.target) + return type(mod) + return None + + +def get_matching_subgraph_pairs( + gm_a: GraphModule, + gm_b: GraphModule, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> dict[str, tuple[NSSubgraph, NSSubgraph]]: + """ + Matches matchable subgraphs of graph_a to graph_b. + + For a node, "matchable" is defined as a node which is not an observer, + fake_quants, quant or dequant. + + A subgraph can contain one or more nodes. A subgraph is matchable if + at least one node inside of it is matchable. Currently, all nodes in + a subgraph must be matchable (because we assume no observers will be + inserted in the middle of a fusion). 
+ + A subgraph is defined by (start_node, end_node). We assume that only + start_node and end_node are linked with the surrounding graph, all other + nodes in a subgraph are self-contained. + + A pair of nodes is "related" if both nodes represent the same mathematical + operation across different quantization flavors. For example, + `F.linear` and `torch.ops.quantized.linear` are related, and + `F.linear` and `torch.nn.Conv` are not related. + + For each matchable pair of nodes node_a and node_b, they will match + if node_a and node_b are related. + + For graphs A and B, they will match iff: + 1. the number of matchable subgraphs in A and B is equivalent + 2. when iterating through the matchable subgraphs of A and B in the same order, each + corresponding pair of base nodes is related. + + This enables us to find the corresponding subgraphs between + graphs of related models. For example, if we had two graphs such as: + + graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1 + w -/ + b -/ + + graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1 + packed_params_0 -/ + + This function will return the following result: + { + 'conv_0': ( # the name of the node in graph_b + (conv_0, conv_0), # (start_node_a, end_node_a) + (qconv_0, qconv_0), # (start_node_b, end_node_b) + ), + } + + Or, if we have a fusion pattern, + + graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1 + w -/ + b -/ + + graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1 + packed_params_0 -/ + + This function will return the following result: + { + 'linear_relu_0': ( # the name of the node in graph_b + (linear_0, relu_0), # (start_node_a, end_node_a) + (linear_relu_0, linear_relu_0), # (start_node_b, end_node_b) + ), + } + """ + if unmatchable_types_map is None: + unmatchable_types_map = get_unmatchable_types_map() + non_matchable_functions = unmatchable_types_map["funs_unmatchable"] + non_matchable_modules = unmatchable_types_map["mods_unmatchable"] + non_matchable_methods = unmatchable_types_map["meths_unmatchable"] + + graph_a_iterator = _NSGraphMatchableSubgraphsIterator( + gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods + ) + graph_b_iterator = _NSGraphMatchableSubgraphsIterator( + gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods + ) + results = collections.OrderedDict() + if base_name_to_sets_of_related_ops is None: + base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() + type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops) + + existing_names_a: set[str] = set() + existing_names_b: set[str] = set() + + while True: + # fetch the next subgraphs from a and b + cur_subgraph_a, cur_subgraph_b = None, None + try: + cur_subgraph_a = next(graph_a_iterator) + except StopIteration: + pass + try: + cur_subgraph_b = next(graph_b_iterator) + except StopIteration: + pass + + # look up types of a and b for useful error messages + type_start_a, type_start_b = None, None + if cur_subgraph_a is not None: + type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a) + if cur_subgraph_b is not None: + type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b) + + # check for results and determine what to do next + if cur_subgraph_a is not None and cur_subgraph_b is not None: + # both nodes were fetched, check for subgraph_relationship + # note: subgraph_relationship is checked on the start node, i.e. 
+ # if a linear-relu pattern is checked, we would check for subgraph_relationship + # of the linear + subgraph_relationship = _get_subgraph_relationship_type( + cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b + ) + if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED: + msg = f""" +The subgraphs +({cur_subgraph_a}, {type_start_a}) and +({cur_subgraph_b}, {type_start_b}) +are not related. Please ensure that the two models you pass in have the same number +of subgraphs, and each pair of subgraphs is related to each other.""" + raise GraphMatchingException(msg) + elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN: + # skip matching but unknown types + continue + key_name_a = _get_name_for_subgraph( + cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a + ) + key_name_b = _get_name_for_subgraph( + cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b + ) + assert key_name_a == key_name_b, ( + f"Subgraph names {key_name_a} and {key_name_b} do not match" + ) + results[key_name_a] = (cur_subgraph_a, cur_subgraph_b) + continue + elif cur_subgraph_a is None and cur_subgraph_b is None: + # we reached the end of both graphs + break + else: + # only one node was fetched, no match possible, throw error + msg = f""" +Attempting to match +({cur_subgraph_a}, {type_start_a}) and +({cur_subgraph_b}, {type_start_b}), +one of which is empty. Please ensure that the two models you pass in have the same number +of subgraphs.""" + raise GraphMatchingException(msg) + + # The subgraph pairs are originally created by traversing the two graphs + # from the outputs to the inputs. Reverse the results to return the + # subgraphs in their order of execution. + results = collections.OrderedDict(reversed(results.items())) + + return results diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py new file mode 100644 index 0000000000000000000000000000000000000000..2f98da3d1326db0abb237aa02ede716c267b86a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py @@ -0,0 +1,1133 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, Optional, Union + +import torch +from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.observer import _is_activation_post_process +from torch.fx import GraphModule, map_arg +from torch.fx.graph import Graph, Node + +from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph +from .utils import ( + get_arg_indices_of_inputs_to_log, + get_node_first_input_and_output_type, + get_node_input_qparams, + get_normalized_nth_input, + get_number_of_non_param_args, + get_target_type_str, + getattr_from_fqn, + NodeInputOrOutputType, + op_type_supports_shadowing, + return_first_non_observer_node, +) + + +def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]: + fqn = None + if hasattr(gm, "_node_name_to_scope"): + # fqn on observers is not present, because they do not + # exist when the fqns are created during tracing. If this is + # an observer, get the fqn of the node being observed. 
+ node_to_use_for_fqn = node + if node.op == "call_module": + assert isinstance(node.target, str) + module = getattr_from_fqn(gm, node.target) + if _is_activation_post_process(module): + node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0) + fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index] + return fqn # type: ignore[return-value] + + +def _insert_logger_after_node( + node: Node, + gm: GraphModule, + logger_cls: Callable, + logger_node_name_suffix: str, + ref_node_name: str, + model_name: str, + ref_name: str, + ref_node_target_type: str, + results_type: str, + index_within_arg: int, + index_of_arg: int, + fqn: Optional[str], +) -> Node: + """ + Given a starting graph of + + prev_node -> node -> next_node + + This function creates a new logger_cls obj and adds it + after node, resulting in + + prev_node -> node -> logger_obj -> next_node + """ + # create new name + logger_node_name = get_new_attr_name_with_prefix( + node.name + logger_node_name_suffix + )(gm) + target_type = get_target_type_str(node, gm) + # create the logger object + logger_obj = logger_cls( + ref_node_name, + node.name, + model_name, + ref_name, + target_type, + ref_node_target_type, + results_type, + index_within_arg, + index_of_arg, + fqn, + ) + # attach the logger object to the parent module + setattr(gm, logger_node_name, logger_obj) + logger_node = node.graph.create_node("call_module", logger_node_name, (node,), {}) + return logger_node + + +def add_loggers_to_model( + gm: GraphModule, + node_to_instrument_inputs_to_ref_node_name: dict[Node, tuple[str, str]], + node_to_instrument_outputs_to_ref_node_name: dict[Node, tuple[str, str]], + logger_cls: Callable, + model_name: str, +) -> GraphModule: + """ + Takes the graph of gm, adds loggers to the output + of each node in nodes_to_instrument. Returns a GraphModule with the new + graph. + """ + + new_graph = Graph() + env: dict[str, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in gm.graph.nodes: + if node.op == "output": + new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg)) + continue + + if (node in node_to_instrument_inputs_to_ref_node_name) or ( + node in node_to_instrument_outputs_to_ref_node_name + ): + fqn = _maybe_get_fqn(node, gm) + + if node in node_to_instrument_inputs_to_ref_node_name: + ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[ + node + ] + # Ops such add and mul are special because either + # one or two of the first two arguments can be tensors, + # and if one argument is a tensor it can be first or + # second (x + 1 versus 1 + x). 
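+                # For example, in `torch.add(x, 1)` only the first argument is a
+                # Node; the `type(node_arg) == Node` check below makes sure input
+                # loggers are only attached to Node (tensor) arguments and that
+                # scalar arguments are skipped.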
+ arg_indices_to_log = get_arg_indices_of_inputs_to_log(node) + for node_arg_idx in arg_indices_to_log: + node_arg = get_normalized_nth_input(node, gm, node_arg_idx) + if type(node_arg) == Node: + # create a single input logger + prev_node = env[node_arg.name] + env[node_arg.name] = _insert_logger_after_node( + prev_node, + gm, + logger_cls, + "_ns_logger_", + node.name, + model_name, + ref_name, + ref_node_type, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=0, + index_of_arg=node_arg_idx, + fqn=fqn, + ) + elif ( + type(node_arg) == torch.fx.immutable_collections.immutable_list + ): + # create N input loggers, one for each node + for arg_idx, arg in enumerate(node_arg): # type: ignore[var-annotated, arg-type] + prev_node = env[arg.name] + env[prev_node.name] = _insert_logger_after_node( + prev_node, + gm, + logger_cls, + "_ns_logger_", + node.name, + model_name, + ref_name, + ref_node_type, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=arg_idx, + index_of_arg=node_arg_idx, + fqn=fqn, + ) + else: + pass + + # ensure env is populated with base node + # Note: runs for both inputs and outputs + env[node.name] = new_graph.node_copy(node, load_arg) + + if node in node_to_instrument_outputs_to_ref_node_name: + ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[ + node + ] + # add the logger after the base node + env[node.name] = _insert_logger_after_node( + env[node.name], + gm, + logger_cls, + "_ns_logger_", + node.name, + model_name, + ref_name, + ref_node_type, + NSSingleResultValuesType.NODE_OUTPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn, + ) + + else: + env[node.name] = new_graph.node_copy(node, load_arg) + + new_gm = GraphModule(gm, new_graph) + return new_gm + + +def _insert_quantize_per_tensor_node( + prev_node_c: Node, + node_a: Node, + gm_b: GraphModule, + graph_c: Graph, + scale: Union[torch.Tensor, float], + zero_point: Union[torch.Tensor, int], + dtype_cast_name: str, +) -> Node: + # copy scale + scale_node_name = get_new_attr_name_with_prefix(node_a.name + "_input_scale_")(gm_b) + setattr(gm_b, scale_node_name, scale) + scale_node = graph_c.create_node( + "get_attr", scale_node_name, (), {}, scale_node_name + ) + # copy zero_point + zero_point_node_name = get_new_attr_name_with_prefix( + node_a.name + "_input_zero_point_" + )(gm_b) + setattr(gm_b, zero_point_node_name, zero_point) + zero_point_node = graph_c.create_node( + "get_attr", zero_point_node_name, (), {}, zero_point_node_name + ) + # create the quantize_per_tensor call + return graph_c.create_node( + "call_function", + torch.quantize_per_tensor, + (prev_node_c, scale_node, zero_point_node, torch.quint8), + {}, + dtype_cast_name, + ) + + +def _insert_dtype_cast_after_node( + node_a: Node, + node_c: Node, + prev_node_c: Union[Node, list[Node]], + gm_a: GraphModule, + gm_b: GraphModule, + graph_c: Graph, + node_name_prefix: str, + logger_cls: Callable, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]], +) -> Union[Node, list[Node]]: + """ + Given a starting graph C (derived from graph B) of + + ... -> prev_node_c -> node_c -> ... + + And a corresponding related node_a, inserts the correct dtype + cast node after prev_node_c to cast into the dtype expected + by node_a, resulting in: + + dtype_cast + / + ... -> prev_node_c -> node_c -> ... + + For example, if node_c is an int8 op and node_a is an fp32 op, this function + will insert a dequant. 
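+
+    A rough summary of the dispatch in the body: if node_a expects fp32 while
+    node_c's input is int8 or fp16, a `torch.dequantize` call is inserted; if
+    node_a expects int8 while node_c's input is fp32, `torch.quantize_per_tensor`
+    is inserted with qparams read from node_a; if node_a expects fp16 while
+    node_c's input is fp32, a `to(torch.float16)` call is inserted; and matching
+    known dtypes get a `torch.nn.Identity` passthrough.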
+ """ + dtype_cast_op = None + dtype_cast_mod_cls = None + dtype_cast_method = None + dtype_cast_method_dtype = None + dtype_cast_scale = None + dtype_cast_zero_point = None + node_input_type_a, _node_output_type_a = get_node_first_input_and_output_type( + node_a, gm_a, logger_cls, node_type_to_io_type_map + ) + node_input_type_c, _node_output_type_c = get_node_first_input_and_output_type( + node_c, gm_b, logger_cls, node_type_to_io_type_map + ) + + if ( + ( + node_input_type_a == NodeInputOrOutputType.FP32 + and node_input_type_c == NodeInputOrOutputType.INT8 + ) + or ( + node_input_type_a == NodeInputOrOutputType.FP32 + and node_input_type_c == NodeInputOrOutputType.FP16 + ) + or + # TODO(future PR): determine the actual dtype of node_c, + # the current code only works because dequantize works with + # multiple input dtypes. + ( + node_input_type_a == NodeInputOrOutputType.FP32 + and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8 + ) + ): + dtype_cast_op = torch.dequantize + elif ( + node_input_type_a == node_input_type_c + and node_input_type_a != NodeInputOrOutputType.UNKNOWN + ): + dtype_cast_mod_cls = torch.nn.Identity + elif ( + node_input_type_a == NodeInputOrOutputType.INT8 + and node_input_type_c == NodeInputOrOutputType.FP32 + ): + # int8 shadows fp32, the dtype cast needs to quantize to int8 + # with the right qparams. + node_a_input_qparams = get_node_input_qparams( + node_a, gm_a, node_type_to_io_type_map + ) + if node_a_input_qparams is not None: + dtype_cast_op = torch.quantize_per_tensor # type: ignore[assignment] + dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams + elif ( + node_input_type_a == NodeInputOrOutputType.FP16 + and node_input_type_c == NodeInputOrOutputType.FP32 + ): + dtype_cast_method = "to" + dtype_cast_method_dtype = torch.float16 + else: + raise AssertionError( + f"dtype cast from {node_input_type_c} {node_c.format_node()} to " + + f"{node_input_type_a} {node_a.format_node()} needs to be implemented" + ) + + if isinstance(prev_node_c, Node): + new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + if dtype_cast_op: + if dtype_cast_scale is not None and dtype_cast_zero_point is not None: + return _insert_quantize_per_tensor_node( + prev_node_c, + node_a, + gm_b, + graph_c, + dtype_cast_scale, + dtype_cast_zero_point, + new_dtype_cast_name, + ) + else: + return graph_c.create_node( + "call_function", + dtype_cast_op, + (prev_node_c,), + {}, + new_dtype_cast_name, + ) + elif dtype_cast_method: + return graph_c.create_node( + "call_method", + dtype_cast_method, + (prev_node_c, dtype_cast_method_dtype), + {}, + new_dtype_cast_name, + ) + else: + assert dtype_cast_mod_cls + dtype_cast_mod = dtype_cast_mod_cls() + setattr(gm_b, new_dtype_cast_name, dtype_cast_mod) + return graph_c.create_node( + "call_module", + new_dtype_cast_name, + (prev_node_c,), + {}, + new_dtype_cast_name, + ) + elif isinstance(prev_node_c, list): + results = [] + for prev_node_c_inner in prev_node_c: + new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + if dtype_cast_op: + # TODO(future PR): add handling for quantize_per_tensor + new_dtype_cast_node = graph_c.create_node( + "call_function", + dtype_cast_op, + (prev_node_c_inner,), + {}, + new_dtype_cast_name, + ) + results.append(new_dtype_cast_node) + else: + assert dtype_cast_mod_cls + dtype_cast_mod = dtype_cast_mod_cls() + setattr(gm_b, new_dtype_cast_name, dtype_cast_mod) + new_dtype_cast_node = graph_c.create_node( + "call_module", + new_dtype_cast_name, + 
(prev_node_c_inner,), + {}, + new_dtype_cast_name, + ) + results.append(new_dtype_cast_node) + return results + else: + raise AssertionError(f"type f{type(prev_node_c)} is not handled") + + +# TODO(future PR): look into using copy_node API instead +def _copy_node_from_a_to_c( + node_a: Node, + gm_a: GraphModule, + gm_b: GraphModule, + graph_c: Graph, +) -> Node: + """ + Simple copy of node_a to graph_c. + """ + if node_a.op == "get_attr": + node_a_copy_name = get_new_attr_name_with_prefix(node_a.name + "_shadow_copy_")( + gm_b + ) + node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type] + if torch.is_tensor(node_a_obj): + node_a_obj = node_a_obj.detach() + setattr(gm_b, node_a_copy_name, node_a_obj) + node_a_copy = graph_c.create_node( + node_a.op, node_a_copy_name, (), {}, node_a_copy_name + ) + return node_a_copy + elif node_a.op == "call_method": + assert node_a.target in ( + "dequantize", + "to", + ), f"target {node_a.target} is not implemented" + if node_a.target == "dequantize": + arg_copy = _copy_node_from_a_to_c( + get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c + ) # type: ignore[arg-type] + node_a_copy_name = get_new_attr_name_with_prefix( + node_a.name + "_shadow_copy_" + )(gm_b) + node_a_copy = graph_c.create_node( + node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name + ) + return node_a_copy + else: # to + arg_copy = _copy_node_from_a_to_c( + get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c + ) # type: ignore[arg-type] + node_a_copy_name = get_new_attr_name_with_prefix( + node_a.name + "_shadow_copy_" + )(gm_b) + node_a_copy = graph_c.create_node( + node_a.op, + node_a.target, + (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)), + {}, + node_a_copy_name, + ) + return node_a_copy + + else: + raise AssertionError( + f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented" + ) + + +def _can_insert_copy_of_subgraph_a( + subgraph_a: NSSubgraph, + gm_a: GraphModule, + num_non_param_args_node_a: int, +) -> bool: + """ + This function returns `False` if the input subgraph cannot be copied by + `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means + that there is a corner case logic for which copy is not yet implemented. + """ + # populate the list of nodes we need to check + nodes = [] + cur_node = subgraph_a.end_node + while cur_node != subgraph_a.start_node: + nodes.append(cur_node) + cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment] + nodes.append(cur_node) + nodes.reverse() + + def _can_insert(node_a_arg, gm_a): + if isinstance(node_a_arg, Node): + arg_a = return_first_non_observer_node(node_a_arg, gm_a) + if arg_a.op == "call_method": + return arg_a.target in ("dequantize", "to") + elif arg_a.op == "get_attr": + return True + else: + return False + elif isinstance(node_a_arg, (list, tuple)): + for el in node_a_arg: + if not isinstance(el, Node): + return False + return True + + # For each node, check if we handle the copy behavior. This follows the + # logic in `_insert_copy_of_subgraph_a_after_input_node_c`. 
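+    # In short, an extra arg is copyable if it is a get_attr Node, a
+    # `dequantize`/`to` call_method Node (possibly behind observers), or a
+    # list/tuple containing no Nodes; anything else means the shadow copy
+    # for this subgraph is skipped.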
+ for node_a in nodes: + local_num_non_param_args_node_a = ( + num_non_param_args_node_a if node_a is nodes[0] else 1 + ) + + norm_args_kwargs = node_a.normalized_arguments( + gm_a, normalize_to_only_use_kwargs=True + ) + if norm_args_kwargs is not None: + norm_args, norm_kwargs = norm_args_kwargs + else: + norm_args, norm_kwargs = node_a.args, node_a.kwargs + + cur_idx = 0 + + while cur_idx < len(norm_args): + if cur_idx == 0: + pass + elif cur_idx == 1 and local_num_non_param_args_node_a == 2: + pass + else: + if not _can_insert(norm_args[cur_idx], gm_a): + return False + cur_idx += 1 + + for kwarg_val in norm_kwargs.values(): + # stitch the inputs from base graph + if cur_idx == 0: + pass + elif cur_idx == 1 and local_num_non_param_args_node_a == 2: + pass + else: + if not _can_insert(kwarg_val, gm_a): + return False + cur_idx += 1 + + return True + + +def _insert_copy_of_subgraph_a_after_input_node_c( + input_node_c: Union[Node, list[Node]], + input_node_c_2: Optional[Union[Node, list[Node]]], + subgraph_a: NSSubgraph, + gm_a: GraphModule, + gm_b: GraphModule, + node_name_prefix: str, +) -> Node: + """ + TODO(before land): real docblock + """ + assert isinstance(input_node_c, (Node, list)) + + # create a sequential list of the subgraphs' nodes from start to end, + # because we need to add the nodes to graph C in non-reverse order + nodes_of_a = [subgraph_a.end_node] + cur_node = subgraph_a.end_node + while cur_node != subgraph_a.start_node: + cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment] + nodes_of_a.insert(0, cur_node) + + # go through nodes of a in order, and insert them into the graph of c + # sequentially + cur_node_a = nodes_of_a[0] + cur_node_c = _insert_copy_of_node_a_after_input_node_c( + input_node_c, input_node_c_2, cur_node_a, gm_a, gm_b, node_name_prefix + ) + for cur_idx_a in range(1, len(nodes_of_a)): + cur_node_a = nodes_of_a[cur_idx_a] + prev_node_c = cur_node_c # previous added node is the input to next node + cur_node_c = _insert_copy_of_node_a_after_input_node_c( + prev_node_c, + # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph + None, + cur_node_a, + gm_a, + gm_b, + node_name_prefix, + ) + # return the last inserted node + return cur_node_c + + +def _insert_copy_of_node_a_after_input_node_c( + input_node_c: Union[Node, list[Node]], + input_node_c_2: Optional[Union[Node, list[Node]]], + node_a: Node, + gm_a: GraphModule, + gm_b: GraphModule, + node_name_prefix: str, +) -> Node: + """ + Assume that node_a from graph_a has + args (input, (input2)?, arg1, ...), and + kwargs {kw0: kwarg0, ...} + + Note: input2 is optional. If it equals to None, we assume that the op + has a single non-param input. If it is specified, we assume that the op + has two non-param inputs. + + Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b, + and creates the corresponding nodes in graph_c. Note: observers are ignored, + so if an arg is an observer we navigate up until we find a non-observer parent. + + If node_a is a call_module, points the module pointed to by node_a to gm_b. + + Creates the copy of node_a in graph_c, with input as the first arg, + and all other args and kwargs pointing to the copies of the objects + in gm_b created above. 
+ + An example in pictures: + + graph A: + ======== + + input -------------> node_a + / / / + (input_2)?----------/ / / + / / + weight -> weight_obs / + / + bias ---------------- + + graph C (derived from B): + ========================= + + input_node_c --> node_a_copy + / / / + (input_node_c_2)? / / + / / + weight_copy ----/ / + / + bias_copy ------/ + """ + if isinstance(input_node_c, Node): + graph_c = input_node_c.graph + else: + assert isinstance(input_node_c, list) + graph_c = input_node_c[0].graph + + norm_args_kwargs = node_a.normalized_arguments( + gm_a, normalize_to_only_use_kwargs=True + ) + if norm_args_kwargs is not None: + norm_args, norm_kwargs = norm_args_kwargs + else: + norm_args, norm_kwargs = node_a.args, node_a.kwargs + + new_args = [] + new_kwargs = {} + + def _copy_arg(arg): + # copy the other inputs from the other graph + if isinstance(arg, Node): + arg = return_first_non_observer_node(arg, gm_a) + arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c) + return arg + elif isinstance(arg, (int, float, torch.dtype)): + return arg + elif isinstance(kwarg_val, (list, tuple)): + for el in kwarg_val: + assert not isinstance(el, Node), ( + "handling of Node inside list is not implemented" + ) + return arg + else: + raise AssertionError( + f"handling for kwarg of type {type(kwarg_val)} is not implemented" + ) + + cur_idx = 0 + + while cur_idx < len(norm_args): + if cur_idx == 0: + new_arg = input_node_c + elif cur_idx == 1 and input_node_c_2 is not None: + new_arg = input_node_c_2 + else: + new_arg = _copy_arg(norm_args[cur_idx]) + new_args.append(new_arg) + cur_idx += 1 + + for kwarg_name, kwarg_val in norm_kwargs.items(): + # stitch the inputs from base graph + if cur_idx == 0: + new_kwargs[kwarg_name] = input_node_c + elif cur_idx == 1 and input_node_c_2 is not None: + new_kwargs[kwarg_name] = input_node_c_2 + else: + new_kwargs[kwarg_name] = _copy_arg(kwarg_val) + cur_idx += 1 + + new_args = tuple(new_args) # type: ignore[assignment] + + node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + + if node_a.op == "call_module": + # if target is a module, we point to the module from gm_b + new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + # fetch the corresponding module from gm_a + assert isinstance(node_a.target, str) + mod_a = getattr_from_fqn(gm_a, node_a.target) + setattr(gm_b, new_mod_copy_name, mod_a) + node_a_shadows_c = graph_c.create_node( + node_a.op, + new_mod_copy_name, + new_args, # type: ignore[arg-type] + new_kwargs, # type: ignore[arg-type] + node_a_shadows_c_name, + ) + return node_a_shadows_c + else: + assert node_a.op in ("call_function", "call_method") + node_a_shadows_c = graph_c.create_node( + node_a.op, + node_a.target, + new_args, # type: ignore[arg-type] + new_kwargs, # type: ignore[arg-type] + node_a_shadows_c_name, + ) + return node_a_shadows_c + + +def create_a_shadows_b( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], + logger_cls: Callable, + should_log_inputs: bool, + node_type_to_io_type_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> GraphModule: + """ + Creates a new GraphModule consisting of the graph of C, with the meaningful + nodes of A shadowing the corresponding nodes of B. 
For example, + + Graph A: + a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2 + + Graph B: + b0 -> op0_int8 -> b1 -> op1_int8 -> b2 + + matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)} + + Graph C (A shadows B): + + / dequant0 -> op0_fp32 -> logger_a_0 / dequant_1 -> op1_fp32 -> logger_a_1 + / / + b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1 + + In a nutshell, this function does the following for each node pair: + * copies the necessary attributes and modules from gm_a to gm_b, + keeping names unique + * adds a dtype cast op (dequant, quant, etc) + * adds a copy of node_a in gm_b's graph + * adds loggers to the outputs of node_a and node_b + """ + + if node_type_to_io_type_map is None: + node_type_to_io_type_map = get_node_type_to_io_type_map() + + # graph_c is the graph created from copying the nodes of graph_b and inserting + # the shadows with the nodes copied from graph_a + graph_c = Graph() + env_c: dict[str, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env_c[node.name]) + + start_node_b_to_matched_subgraph_a_and_name = {} + end_node_b_to_matched_subgraph_a_and_name = {} + for match_name, match in matched_subgraph_pairs.items(): + subgraph_a, subgraph_b = match + ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a) + ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b) + start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = ( + subgraph_a, + match_name, + ref_node_type_a, + ref_node_type_b, + ) + end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = ( + subgraph_a, + match_name, + ref_node_type_a, + ref_node_type_b, + ) + + for node_b in gm_b.graph.nodes: + if node_b.op == "output": + graph_c.output(map_arg(node_b.args[0], load_arg)) + continue + + # calculate the flags to determine what to do with this node + node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name + node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name + + if node_b_is_start_node or node_b_is_end_node: + if node_b_is_start_node: + ( + subgraph_a, + ref_name, + ref_node_type_a, + ref_node_type_b, + ) = start_node_b_to_matched_subgraph_a_and_name[node_b] + else: + assert node_b_is_end_node + ( + subgraph_a, + ref_name, + ref_node_type_a, + ref_node_type_b, + ) = end_node_b_to_matched_subgraph_a_and_name[node_b] + + all_op_types_support_shadowing = op_type_supports_shadowing( + subgraph_a.start_node + ) and op_type_supports_shadowing(node_b) + if not all_op_types_support_shadowing: + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unsupported" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + # For both start_node and end_node verify that we know how to do + # the dtype cast. If we do not, skip. 
+ ( + node_input_type_a, + node_output_type_a, + ) = get_node_first_input_and_output_type( + subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map + ) + ( + node_input_type_b, + node_output_type_b, + ) = get_node_first_input_and_output_type( + node_b, gm_b, logger_cls, node_type_to_io_type_map + ) + node_io_types_known_a_and_b = ( + node_input_type_a != NodeInputOrOutputType.UNKNOWN + and node_output_type_a != NodeInputOrOutputType.UNKNOWN + and node_input_type_b != NodeInputOrOutputType.UNKNOWN + and node_output_type_b != NodeInputOrOutputType.UNKNOWN + ) + if not node_io_types_known_a_and_b: + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unknown dtype cast" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + # If we are shadowing from fp32 to int8, we need to insert + # quantize_per_tensor call with qparams from the previous node. + # Only do this if we are able to infer these qparams from the graph. + if ( + node_input_type_a == NodeInputOrOutputType.INT8 + and node_input_type_b == NodeInputOrOutputType.FP32 + ): + node_a_input_qparams = get_node_input_qparams( + subgraph_a.start_node, gm_a, node_type_to_io_type_map + ) + if not node_a_input_qparams: + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unknown input qparams" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + num_non_param_args_node_a = get_number_of_non_param_args( + subgraph_a.start_node, gm_a + ) + if not _can_insert_copy_of_subgraph_a( + subgraph_a, gm_a, num_non_param_args_node_a + ): + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unhandled logic in subgraph copy" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a) + fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b) # type: ignore[possibly-undefined] + + if node_b_is_start_node: + # if necessary, log the input of node_c + if should_log_inputs: + prev_node_b = get_normalized_nth_input(node_b, gm_b, 0) + if isinstance(prev_node_b, Node): + prev_node_c = env_c[prev_node_b.name] + env_c[prev_node_c.name] = _insert_logger_after_node( + prev_node_c, + gm_b, + logger_cls, + "_ns_logger_b_inp_", + node_b.name, + name_b, + ref_name, + ref_node_type_b, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_b, + ) + elif isinstance(prev_node_b, list): + # first, save the prev_node instances, because they + # will be overwritten in the env after the first logger + # is added + prev_node_c_list = [env_c[arg.name] for arg in prev_node_b] + + for arg_idx, arg in enumerate(prev_node_b): + prev_node_c = prev_node_c_list[arg_idx] + env_c[prev_node_c.name] = _insert_logger_after_node( + prev_node_c, + gm_b, + logger_cls, + "_ns_logger_b_inp_", + node_b.name, + name_b, + ref_name, + ref_node_type_b, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=arg_idx, + index_of_arg=0, + fqn=fqn_base_b, + ) + else: + # logging of inputs which are not lists is not supported yet + raise AssertionError( + f"type {type(prev_node_b)} is not handled yet" + ) + # subgraph so far: + # + # (prev_node_c)+ -> (logger_c_input)? 
+ + # Note: this if statement is always True, spelling it out to clarify code + # intent. + if node_b_is_start_node or node_b_is_end_node: + # ensure env_c is populated with base node + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + node_c = env_c[node_b.name] + + # after this point, + # + # node_a is the original node from graph_a, with parent module gm_a + # node_b is the original node from graph_b, with parent module gm_b + # node_c is the copy of node_b in graph_c + # + # subgraph so far: + # + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + if node_b_is_start_node: + # cast dtype from the dtype of node_c's input to the dtype of + # node_a's input (dequant, etc) + # prev_node_c = node_c.args[0] + prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) # type: ignore[possibly-undefined] + if should_log_inputs: + # skip the input logger when inserting a dtype cast + if isinstance(prev_node_c, Node): + prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) + elif isinstance(prev_node_c, list): + prev_node_c = [ + get_normalized_nth_input(arg, gm_b, 0) + for arg in prev_node_c + ] + dtype_cast_node = _insert_dtype_cast_after_node( + subgraph_a.start_node, + node_c, + prev_node_c, + gm_a, + gm_b, + graph_c, + node_b.name + "_dtype_cast_", + logger_cls, + node_type_to_io_type_map, + ) + # note: not inserting to env_c because all nodes which use the dtype + # casts are copied from graph_a + # + # subgraph so far: + # + # (dtype_cast_node)+ + # / + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + # if input logging is enabled, log the input to the subgraph + if should_log_inputs: + # TODO: explain this + ref_node_name = "" + if isinstance(dtype_cast_node, Node): + dtype_cast_node = _insert_logger_after_node( + dtype_cast_node, + gm_b, + logger_cls, + "_ns_logger_a_inp_", + ref_node_name, + name_a, + ref_name, + ref_node_type_a, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_a, + ) + input_logger: Union[Node, list[Node]] = dtype_cast_node + else: + assert isinstance(dtype_cast_node, list) + new_loggers = [] + for dtype_cast_idx, dtype_cast_node_inner in enumerate( + dtype_cast_node + ): + dtype_cast_logger = _insert_logger_after_node( + dtype_cast_node_inner, + gm_b, + logger_cls, + "_ns_logger_a_inp_", + ref_node_name, + name_a, + ref_name, + ref_node_type_a, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=dtype_cast_idx, + index_of_arg=0, + fqn=fqn_base_a, + ) + new_loggers.append(dtype_cast_logger) + dtype_cast_node = new_loggers + input_logger = dtype_cast_node + # subgraph so far: + # + # (dtype_cast_node)+ -> (logger_a_input)? + # / + # prev_node_c -> (logger_c_input)? -> node_start_c + + # hook up the new mod_a copy to be in the graph, receiving the + # same inputs as mod_b does, with dtype cast to match a + # Some ops, such as LSTMs, have two non-param inputs. If we have + # such an op, pass the second param as well. Note: dtype casting + # for the second param is not implemented yet, it can be added + # later if there is a use case. 
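+                    # Illustrative example (not part of the original source):
+                    # for an nn.LSTM start node called as lstm(x, (h0, c0)),
+                    # only x is routed through the dtype cast above; the
+                    # (h0, c0) tuple is forwarded unchanged as the second
+                    # non-param argument below.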
+ node_c_second_non_param_arg = None + num_non_param_args_node_a = get_number_of_non_param_args( + subgraph_a.start_node, gm_a + ) + if num_non_param_args_node_a == 2: + # node_c_second_non_param_arg = node_c.args[1] + node_c_second_non_param_arg = get_normalized_nth_input( + node_c, gm_b, 1 + ) + node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c( + dtype_cast_node, + node_c_second_non_param_arg, + subgraph_a, + gm_a, + gm_b, + node_c.name + "_shadow_copy_", + ) + env_c[node_a_shadows_c.name] = node_a_shadows_c + # subgraph so far: + # + # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown) + # / + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + if should_log_inputs: + # When we created the input logger, we left the ref_node_name + # as an empty string, because the subgraph copy did not exist + # yet. Now that the subgraph copy exists, we modify this name + # to its true value. + # Note: the alternative to this is to create the input logger + # after creating the subgraph, which is slightly more + # complicated. This is the lesser of two evils. + # input_logger = env_c[dtype_cast_node.name] + # Find the first node in the subgraph + cur_node = node_a_shadows_c + while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined] + cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment] + if isinstance(input_logger, Node): + input_logger_mod = getattr(gm_b, input_logger.name) + input_logger_mod.ref_node_name = cur_node.name + else: + assert isinstance(input_logger, list) + for input_logger_inner in input_logger: + input_logger_mod = getattr(gm_b, input_logger_inner.name) + input_logger_mod.ref_node_name = cur_node.name + + # hook up a logger to the mod_a copy + env_c[node_a_shadows_c.name] = _insert_logger_after_node( + env_c[node_a_shadows_c.name], + gm_b, + logger_cls, + "_ns_logger_a_", + node_a_shadows_c.name, + name_a, + ref_name, + ref_node_type_a, + NSSingleResultValuesType.NODE_OUTPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_a, + ) + # subgraph so far: + # + # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a + # / + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + if node_b_is_end_node: + # hook up a logger to the mod_b copy + env_c[node_b.name] = _insert_logger_after_node( + env_c[node_b.name], + gm_b, + logger_cls, + "_ns_logger_b_", + node_b.name, + name_b, + ref_name, + ref_node_type_b, + NSSingleResultValuesType.NODE_OUTPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_b, + ) + # subgraph so far: + # + # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a + # / + # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c + # + # Note: node_start_c may be the same node as node_end_c, or they + # may have nodes inbetween. 
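+        # Illustrative note (not part of the original source): any node of
+        # graph B which is not the start or end of a matched subgraph is
+        # handled by the `else` branch below and simply copied into graph C
+        # without shadows or loggers (the `continue` statements above handle
+        # skipped subgraphs the same way).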
+ + else: + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + + gm_c = GraphModule(gm_b, graph_c) + return gm_c diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/mappings.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..cc512a62c562825fe3da9b2fd3e71707413fceb5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/mappings.py @@ -0,0 +1,757 @@ +import operator +from typing import Callable, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend +import torch.ao.quantization.quantization_mappings as quantization_mappings +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.backend_config import get_native_backend_config + +from .ns_types import NSNodeTargetType + + +toq = torch.ops.quantized + + +def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]: + # note: this set is modified below by items from backend_config + sets_of_related_ops: list[set[NSNodeTargetType]] = [ + # conv modules + { + nn.Conv1d, + }, + { + nn.Conv2d, + }, + { + nn.Conv3d, + }, + # conv functionals + { + F.conv1d, + }, + { + F.conv2d, + }, + { + F.conv3d, + }, + # linear modules + { + nn.Linear, + }, + # linear functionals + { + F.linear, + }, + # average pool + { + nn.AvgPool1d, + torch.avg_pool1d, + }, + { + nn.AvgPool2d, + torch._C._nn.avg_pool2d, + }, + { + nn.AvgPool3d, + torch._C._nn.avg_pool3d, + }, + # adaptive average pool + { + nn.AdaptiveAvgPool1d, + F.adaptive_avg_pool1d, + }, + { + nn.AdaptiveAvgPool2d, + F.adaptive_avg_pool2d, + }, + { + nn.AdaptiveAvgPool3d, + F.adaptive_avg_pool3d, + }, + # LSTM + { + nn.LSTM, + }, + # add + { + torch.add, + operator.add, # x + y + }, + # cat + { + torch.cat, + }, + # mul + { + torch.mul, + operator.mul, + }, + # relu + { + F.relu, + nn.ReLU, + "relu", + "relu_", + torch.relu, + }, + # maxpool + { + nn.MaxPool1d, + F.max_pool1d, + }, + { + nn.MaxPool2d, + F.max_pool2d, + }, + { + nn.MaxPool3d, + F.max_pool3d, + }, + # sigmoid + { + torch.sigmoid, + "sigmoid", + "sigmoid_", + nn.Sigmoid, + F.sigmoid, + }, + # BatchNorm + { + nn.BatchNorm2d, + }, + { + nn.BatchNorm3d, + }, + # ConvTranspose + { + nn.ConvTranspose1d, + }, + { + nn.ConvTranspose2d, + }, + { + nn.ConvTranspose3d, + }, + # functional transposed conv + { + F.conv_transpose1d, + }, + { + F.conv_transpose2d, + }, + { + F.conv_transpose3d, + }, + # ELU + { + nn.ELU, + }, + # Embedding + { + nn.Embedding, + }, + # EmbeddingBag + { + nn.EmbeddingBag, + }, + # GroupNorm + { + nn.GroupNorm, + }, + # Hardswish + { + nn.Hardswish, + }, + # InstanceNorm + { + nn.InstanceNorm1d, + }, + { + nn.InstanceNorm2d, + }, + { + nn.InstanceNorm3d, + }, + # LayerNorm + { + nn.LayerNorm, + }, + # LeakyReLU + { + nn.LeakyReLU, + }, + # ReLU6 + { + nn.ReLU6, + F.relu6, + }, + # F.elu + { + F.elu, + }, + # F.hardswish + { + F.hardswish, + }, + # F.group_norm + { + F.group_norm, + }, + # F.instance_norm + { + F.instance_norm, + }, + # F.layer_norm + { + F.layer_norm, + }, + # F.leaky_relu + { + F.leaky_relu, + }, + # F.silu + { + nn.SiLU, + F.silu, + }, + # F.mish + { + nn.Mish, + F.mish, + }, 
+ # F.tanh + { + nn.Tanh, + F.tanh, + torch.tanh, + "tanh_", + "tanh", + }, + # F.hardsigmoid + { + "hardsigmoid_", + "hardsigmoid", + F.hardsigmoid, + nn.Hardsigmoid, + }, + # F.hardtanh + { + nn.Hardtanh, + F.hardtanh, + F.hardtanh_, + }, + # floordiv + { + operator.floordiv, + }, + # unsqueeze + { + torch.unsqueeze, + }, + # stack + { + torch.stack, + }, + # squeeze + { + torch.squeeze, + }, + # sort + { + torch.sort, + }, + # repeat_interleave + { + torch.repeat_interleave, + }, + # min + { + torch.min, + }, + # mean + { + torch.mean, + }, + # max + { + torch.max, + }, + # transpose + { + torch.transpose, + }, + # flatten + { + torch.flatten, + }, + # clamp + { + torch.clamp, + }, + # chunk + { + torch.chunk, + }, + # interpolate + { + torch.nn.functional.interpolate, + }, + # dropout + { + nn.Dropout, + }, + # F.dropout + { + F.dropout, + }, + # matmul + { + torch.matmul, + }, + # Softmax + { + nn.Softmax, + }, + # PReLU + { + nn.PReLU, + nnq.PReLU, + }, + # F.prelu + { + F.prelu, + toq.prelu, + }, + # pixel shuffle + { + nn.PixelShuffle, + }, + { + F.pixel_shuffle, + }, + # pixel unshuffle + { + nn.PixelUnshuffle, + }, + { + F.pixel_unshuffle, + }, + # narrow + { + torch.narrow, + }, + ] + + # for each floating point op, add versions of the op added by + # backend_config + backend_config = get_native_backend_config() + + new_connections: list[tuple[Callable, Callable]] = [ + # technical debt edge case + (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear), + ] + + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + # pattern format: (c, (b, a)) + first_element = pattern + # look from the end, because pattern is in reverse order + while isinstance(first_element, (list, tuple)): + first_element = first_element[-1] + + if config.fused_module is not None: + # case 1: pattern fuses a pattern of ops into an op + # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d + new_connections.append((first_element, config.fused_module)) + + if config.qat_module is not None: + # case 2: pattern swaps a module into a QAT module + # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d + new_connections.append((first_element, config.qat_module)) + + if config.reference_quantized_module is not None: + # case 3: reference version of floating point module, such as + # nn.Conv2d and nnqr.Conv2d + new_connections.append((first_element, config.reference_quantized_module)) + + # + # Add reference module swaps from default lowering path + # + + for source_to_target in ( + _lower_to_native_backend.STATIC_LOWER_MODULE_MAP, + _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP, + _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP, + _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP, + ): + for source, target in source_to_target.items(): # type: ignore[attr-defined] + new_connections.append((source, target)) + + for source_to_double_target in ( + _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP, + _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP, + _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP, + ): + for source, (target1, target2) in source_to_double_target.items(): # type: ignore[attr-defined] + new_connections.append((source, target1)) + new_connections.append((source, target2)) + + # + # Add function swaps from default lowering path + # + + for source, ( # type:ignore[assignment] + target1, + target2, + ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): + new_connections.append((source, target1)) + 
new_connections.append((source, target2)) + + for source_to_target in ( + _lower_to_native_backend.QBIN_OP_MAPPING, + _lower_to_native_backend.QBIN_RELU_OP_MAPPING, + quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, + ): + for source, target in source_to_target.items(): # type:ignore[assignment] + new_connections.append((source, target)) + + # + # Add other swaps, ideally in the future this could be removed + # after the lowering code stops using these. + # + for source_to_target in ( + quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, + ): + for source, target in source_to_target.items(): # type:ignore[assignment] + new_connections.append((source, target)) + + # add the new connections from backend_config + for item1, item2 in new_connections: + for set_of_related_ops in sets_of_related_ops: + if item1 in set_of_related_ops or item2 in set_of_related_ops: + set_of_related_ops.add(item1) + set_of_related_ops.add(item2) + break + + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] = {} + + for counter, set_of_related_ops in enumerate(sets_of_related_ops): + base_name = str(counter) + base_name_to_sets_of_related_ops[base_name] = set_of_related_ops + + return base_name_to_sets_of_related_ops + + +def get_base_name_for_op( + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], + op: NSNodeTargetType, +) -> Optional[str]: + for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items(): + if op in set_of_related_ops: + return base_name + return None + + +def add_op_to_sets_of_related_ops( + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], + op: NSNodeTargetType, + related_op: Optional[NSNodeTargetType], +) -> None: + if related_op is not None: + for set_of_related_ops in base_name_to_sets_of_related_ops.values(): + if related_op in set_of_related_ops: + set_of_related_ops.add(op) + return + # if we got here, related_op was not found + raise AssertionError(f"{related_op} was not found") + else: + counter = 0 + while str(counter) in base_name_to_sets_of_related_ops: + counter += 1 + base_name_to_sets_of_related_ops[str(counter)] = {op} + + +# TODO(future PR): clean this up +def get_node_type_to_io_type_map() -> dict[str, set[NSNodeTargetType]]: + FUNS_IO_TYPE_FP32: set[NSNodeTargetType] = { + F.linear, + F.conv1d, + F.conv2d, + F.conv3d, + torch.cat, + F.elu, + F.hardswish, + F.instance_norm, + F.layer_norm, + F.leaky_relu, + F.dropout, + F.silu, + F.mish, + operator.add, + torch.add, + operator.mul, + torch.mul, + torch.sum, + F.prelu, + } + + FUNS_IO_TYPE_FP16: set[NSNodeTargetType] = set() + + FUNS_IO_TYPE_INT8: set[NSNodeTargetType] = { + toq.linear, + toq.linear_relu, + toq.conv1d, + toq.conv1d_relu, + toq.conv2d, + toq.conv2d_relu, + toq.conv3d, + toq.conv3d_relu, + toq.cat, + toq.elu, + toq.hardswish, + toq.instance_norm, + toq.layer_norm, + toq.leaky_relu, + toq.dropout, + toq.prelu, + # TODO(future PR): implement shadowing for binary ops and + # uncomment below + # toq.add, + # toq.mul, + } + + FUNS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = { + F.relu, + F.tanh, + torch.tanh, + F.sigmoid, + torch.sigmoid, + F.hardsigmoid, + operator.floordiv, + torch.adaptive_avg_pool1d, + F.adaptive_avg_pool2d, + F.adaptive_avg_pool3d, + F.dropout, + F.hardtanh, + F.hardtanh_, + F.interpolate, + F.max_pool1d, + F.max_pool2d, + F.max_pool3d, + F.relu6, + F.pixel_shuffle, + F.pixel_unshuffle, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.cat, + torch.chunk, + 
torch.clamp, + torch.flatten, + torch.transpose, + torch.max, + torch.mean, + torch.min, + torch.narrow, + torch.repeat_interleave, + torch.sort, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.add, + } + + MODS_IO_TYPE_FP32: set[NSNodeTargetType] = { + nn.Linear, + nnqat.Linear, + nnqatd.Linear, + nnqd.Linear, + torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + nnqat.Embedding, + nnqat.EmbeddingBag, + nn.LSTM, + # note: nnqd.Linear is an instance of nnq.Linear, so this + # check has to happen before the int8 module check + nnqd.LSTM, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.Dropout, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.ELU, + nn.GroupNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.LayerNorm, + nn.Hardswish, + nn.LeakyReLU, + nn.ReLU6, + nn.SiLU, + nn.Mish, + nn.Softmax, + nn.PReLU, + nni.BNReLU2d, + nni.BNReLU3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.LinearReLU, + nni.LinearBn1d, + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nniqat.ConvBn1d, + nniqat.ConvBn2d, + nniqat.ConvBn3d, + nniqat.ConvBnReLU1d, + nniqat.ConvBnReLU2d, + nniqat.ConvBnReLU3d, + nniqat.ConvReLU1d, + nniqat.ConvReLU2d, + nniqat.ConvReLU3d, + nniqat.LinearReLU, + nniqat.LinearBn1d, + nniqd.LinearReLU, + nni.LinearLeakyReLU, + nni.LinearTanh, + nni.ConvAdd2d, + nni.ConvAddReLU2d, + } + + MODS_IO_TYPE_INT8: set[NSNodeTargetType] = { + nnq.Linear, + nnq.Conv1d, + nnq.Conv2d, + nnq.Conv3d, + nnq.BatchNorm2d, + nnq.BatchNorm3d, + nnq.Dropout, + nnq.ConvTranspose1d, + nnq.ConvTranspose2d, + nnq.ELU, + nnq.InstanceNorm1d, + nnq.InstanceNorm2d, + nnq.InstanceNorm3d, + nnq.LayerNorm, + nnq.Hardswish, + nnq.LeakyReLU, + nnq.Embedding, + nnq.EmbeddingBag, + nnq.Dropout, + nnq.Softmax, + nnq.PReLU, + nniq.BNReLU2d, + nniq.BNReLU3d, + nniq.ConvReLU1d, + nniq.ConvReLU2d, + nniq.ConvReLU3d, + nniq.LinearReLU, + nniq.LinearLeakyReLU, + nniq.LinearTanh, + nniq.ConvAdd2d, + nniq.ConvAddReLU2d, + } + + MODS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = { + nn.ReLU, + nn.Tanh, + nn.Sigmoid, + nn.Hardsigmoid, + nn.AdaptiveAvgPool1d, + nn.AdaptiveAvgPool2d, + nn.AdaptiveAvgPool3d, + nn.AvgPool1d, + nn.AvgPool2d, + nn.AvgPool3d, + nn.Dropout, + nn.Hardtanh, + nn.Identity, + nn.MaxPool1d, + nn.MaxPool2d, + nn.MaxPool3d, + nn.PixelShuffle, + nn.PixelUnshuffle, + nn.ReLU6, + } + + METHS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = { + "sigmoid_", + "sigmoid", + "tanh_", + "tanh", + "hardsigmoid_", + "hardsigmoid", + "relu_", + "relu", + } + + return { + "funs_io_type_fp32": FUNS_IO_TYPE_FP32, + "funs_io_type_fp16": FUNS_IO_TYPE_FP16, + "funs_io_type_int8": FUNS_IO_TYPE_INT8, + "funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8, + "mods_io_type_fp32": MODS_IO_TYPE_FP32, + "mods_io_type_int8": MODS_IO_TYPE_INT8, + "mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8, + "meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8, + } + + +def get_unmatchable_types_map() -> dict[str, set[NSNodeTargetType]]: + FUNS_UNMATCHABLE: set[NSNodeTargetType] = { + torch.quantize_per_tensor, + operator.getitem, + } + + MODS_UNMATCHABLE: set[NSNodeTargetType] = { + nn.Identity, + } + + METHS_UNMATCHABLE: set[NSNodeTargetType] = { + "to", + "dequantize", + "reshape", + "view", + "unsqueeze_", + "unsqueeze", + "transpose", + "squeeze_", + "squeeze", + "size", + "shape", + "resize_", + "repeat_interleave", + "repeat", + "permute", + "numel", + "mean", + "detach_", + 
"detach", + "contiguous", + "clamp", + "chunk", + } + + return { + "funs_unmatchable": FUNS_UNMATCHABLE, + "mods_unmatchable": MODS_UNMATCHABLE, + "meths_unmatchable": METHS_UNMATCHABLE, + } diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/n_shadows_utils.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/n_shadows_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df6d4ad02d13604ec0f90200b1287f05b2a5ae17 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/n_shadows_utils.py @@ -0,0 +1,1378 @@ +# mypy: allow-untyped-defs +import collections +import copy +import operator +from typing import Any, Callable, Optional + +import torch +import torch.fx +from torch.ao.ns.fx.graph_passes import _maybe_get_fqn +from torch.ao.ns.fx.ns_types import NSResultsType, NSSingleResultValuesType +from torch.ao.ns.fx.utils import ( # TODO(future PR): make this work correctly for methods + get_normalized_nth_input, + get_target_type_str, +) +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.fx.match_utils import _MatchResult +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import Graph, GraphModule, Node +from torch.utils._pytree import tree_map + + +SHADOW_NODE_NAME_PREFIX = "shadow" +SHADOW_WRAPPER_NODE_NAME_PREFIX = "shadow_wrapper" + +# TODO(future PR): reuse existing mapping instead of creating a new one +BINARY_FUNCTIONS = { + torch.add, + torch.Tensor.add, + operator.add, + torch.mul, + torch.Tensor.mul, + operator.mul, +} + + +def _get_attr_name(subgraph_idx, subgraph_candidate_idx): + return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" + + +def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx): + return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" + + +class OutputProp: + """ + Output propagation (modeled from shape propagation). + + Given a GraphModule and an example input, saves the output flowing + through each node on `node.traced_result`. 
+ + Code based on the example from + https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern + """ + + def __init__(self, mod): + self.mod = mod + self.graph = mod.graph + self.modules = dict(self.mod.named_modules()) + + def propagate(self, *args): + args_iter = iter(args) + env: dict[str, Node] = {} + + def load_arg(a): + return torch.fx.graph.map_arg(a, lambda n: env[n.name]) + + def fetch_attr(target: str): + target_atoms = target.split(".") + attr_itr = self.mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + for node in self.graph.nodes: + if node.op == "placeholder": + result = next(args_iter) + elif node.op == "get_attr": + result = fetch_attr(node.target) + elif node.op == "call_function": + result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == "call_method": + self_obj, *args = load_arg(node.args) + kwargs = load_arg(node.kwargs) + result = getattr(self_obj, node.target)(*args, **kwargs) + elif node.op == "call_module": + result = self.modules[node.target]( + *load_arg(node.args), **load_arg(node.kwargs) + ) + + if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined] + node.traced_result = result + + env[node.name] = result + + return None + + +def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Node]]: + # the original matches variable is unique by node, make it unique by subgraph + # instead + seen_nodes = set() + subgraphs_dedup = {} + + # Dict items are not reversible until Python 3.8, so we hack it + # to be compatible with previous Python versions + # TODO(future PR): try reversed(list(matches.items())) + matches_items_reversed: list[tuple[str, _MatchResult]] = list( + reversed(matches.items()) + ) + + # Note: the order is important. `matches` currently provides the matches + # in reverse order. We would like to process the matches in non-reverse + # order, so that we can create an intuitive naming scheme, such as + # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)` + for name, cur_match in matches_items_reversed: # type: ignore[call-overload] + was_seen = False + for node_or_tuple in cur_match[1]: + # Cur_match[1] has an unusual type. It says that it's a `List[Node]`, + # but it is really not. Furthermore, the contents of this field + # can change from match results of multiple nodes of the same pattern + # + # For example, for conv -> bn -> relu, we see + # match_results = { + # 'conv': (relu, [(bn, conv), relu], ...), + # 'bn': (relu, [(bn, conv), relu], ...), + # 'relu': (relu, [(bn, conv), relu], ...), + # } + # + # Ideally we should clean up the `find_matches` function to make + # this more intuitive. For the purposes of this prototype, we hack + # around it. 
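+            # Illustrative note (not part of the original source): only node
+            # membership is checked in this inner loop. For the conv -> bn ->
+            # relu example above, two of the three entries are detected as
+            # duplicates and dropped, leaving roughly one entry such as
+            # {'conv': [conv, bn, relu]} after the reordering logic below.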
+ + if isinstance(node_or_tuple, Node): + if node_or_tuple in seen_nodes: + was_seen = True + seen_nodes.add(node_or_tuple) + + else: + assert isinstance(node_or_tuple, tuple) + for node in node_or_tuple: + assert isinstance(node, Node) + if node in seen_nodes: + was_seen = True + seen_nodes.add(node) + + if was_seen: + continue + + # Start with the unusual type, convert it to [op_0, ..., op_n] + list_of_nodes = [] + + if len(cur_match[1]) == 1: + list_of_nodes = cur_match[1] + else: + assert len(cur_match[1]) == 2 + # either (a, b), or ((a, b), c) or (c, (a, b)) + # cannot make any assumptions on order, not clear what the + # _find_matches function is doing to populate this + # TODO(future PR): make this code less confusing, see discussion + # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836 + + def _order_nodes(node_a, node_b, node_c) -> list[Node]: + nodes = [node_a, node_b, node_c] + first_node = None + mid_node = None + last_node = None + for n in nodes: + prev_n = n.args[0] + next_n = next(iter(n.users)) + if prev_n not in nodes: + first_node = n + elif next_n not in nodes: + last_node = n + else: + mid_node = n + assert ( + first_node is not None + and mid_node is not None + and last_node is not None + ) + assert mid_node.args[0] is first_node + assert last_node.args[0] is mid_node + return [last_node, mid_node, first_node] + + if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node): + # (a, b) + list_of_nodes = cur_match[1] + elif isinstance(cur_match[1][0], tuple): + # ((a, b), c) + node_a, node_b = cur_match[1][0] + node_c = cur_match[1][1] + list_of_nodes = _order_nodes(node_a, node_b, node_c) + elif isinstance(cur_match[1][1], tuple): + # (a, (b, c)) + node_a, node_b = cur_match[1][1] + node_c = cur_match[1][0] + list_of_nodes = _order_nodes(node_a, node_b, node_c) + + # [node_n, ..., node_0], note that the order is reversed + # to make it chronological for simple subgraphs + list_of_nodes.reverse() + subgraphs_dedup[name] = list_of_nodes + + return subgraphs_dedup + + +def _get_logger_for_subgraph( + model: GraphModule, + first_node: Node, + last_node: Node, + subgraph_idx: int, + subgraph_candidate_idx: int, + qconfig_str: str, + logger_cls: Callable, + fqn: Optional[str], +) -> torch.nn.Module: + """ + Given a model and a linear subgraph starting from `first_node` and + ending with `last_node`, creates a logger for the end of this + subgraph. + """ + if fqn is None: + fqn = "" + logger_mod_orig = logger_cls( + first_node.name, # ref_node_name + last_node.name, # prev_node_name + f"subgraph_{subgraph_idx}_{subgraph_candidate_idx}", # model_name + "model", # ref_name + get_target_type_str(last_node, model), # prev_node_target_type + get_target_type_str(first_node, model), # ref_node_target_type + NSSingleResultValuesType.NODE_OUTPUT.value, # results_type + 0, # index_within_arg + 0, # index_of_arg + fqn, # fqn + qconfig_str, + ) + # Usually we expect the user to add loggers, then calibrate, then convert, + # and then populate loggers. This is why the loggers start disabled. + # TODO(future PR): reconsider the design to make this more intuitive. + logger_mod_orig.enabled = False + return logger_mod_orig + + +def create_submodule_from_subgraph( + model: torch.nn.Module, + first_node: Node, + last_node: Node, +) -> GraphModule: + """ + Input: a model, and a linear subgraph within the model from first_node to + last_node. 
+ + Output: a new submodule containing a copy of the subgraph, with the inputs + to the first node becoming the inputs to the submodule, and all other + nodes in the subgraph being copied. + + Example inputs: + + `model`: a module with graph + + x0 -> op1 -> x1 -> op2 -> x2 + | + arg1 + + `first_node`: op1 + `last_node`: op2 + + Example output: a new module with graph + + input1 -> op1_copy -> x1 -> op2_copy -> output1 + | + arg1 + """ + + # + # create a blank GraphModule with an empty graph + # + + class M(torch.nn.Module): + def forward(self, x): + pass + + m = M() + gm = torch.fx.symbolic_trace(m) + g = gm.graph + for node in reversed(gm.graph.nodes): + g.erase_node(node) + + # + # modify the graph to have a copy of our subgraph + # + + cur_node_orig = first_node + + cur_name_idx = 0 + + iteration_limit = 100 + cur_iteration = 0 + + while True: + if cur_node_orig is first_node: + # we are at the first node, we need to set up graph inputs + # TODO(future): some graphs could have placeholders which are unrelated + # to the first node, need to handle this + cur_args_copy = [] + cur_kwargs_copy = {} + seen_names: set[str] = set() + old_name_to_new_node: dict[str, Node] = {} + + def _add_placeholder( + g: Graph, node: Node, seen_names, old_name_to_new_node + ): + # note: for graphs starting with patterns such as `y = x + x`, we + # need to ensure we do not add multiple placeholders with the + # same name + counter = 0 + while node.name + "_" + str(counter) in seen_names: + counter += 1 + cur_name = node.name + "_" + str(counter) + seen_names.add(cur_name) + placeholder = g.placeholder(cur_name) + old_name_to_new_node[node.name] = placeholder + return placeholder + + for arg in cur_node_orig.args: + if isinstance(arg, Node): + p = _add_placeholder(g, arg, seen_names, old_name_to_new_node) + cur_args_copy.append(p) + elif isinstance(arg, (list, tuple)): + new_arg = [] + for inner_arg in arg: + if isinstance(inner_arg, Node): + new_arg.append( + _add_placeholder( + g, inner_arg, seen_names, old_name_to_new_node + ) + ) + else: + new_arg.append(inner_arg) + cur_args_copy.append(new_arg) + else: + cur_args_copy.append(arg) + + # TODO(future PR): handle non-normalized kwargs + for kwarg_name, kwarg in cur_node_orig.kwargs.items(): + if isinstance(kwarg, Node): + cur_kwargs_copy[kwarg_name] = _add_placeholder( + g, kwarg, seen_names, old_name_to_new_node + ) + elif isinstance(kwarg, (list, tuple)): + new_kwarg = [] + for inner_kwarg in kwarg: + p = _add_placeholder( + g, + inner_kwarg, # type: ignore[arg-type] + seen_names, + old_name_to_new_node, + ) + new_kwarg.append(p) + cur_kwargs_copy[kwarg_name] = new_kwarg + else: + cur_kwargs_copy[kwarg_name] = kwarg + + cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] + else: + # we are not at first node, first arg is from the previous node, + # and all other args are copied + + # the current implementation is simplistic and cannot handle + # ops with two or more arguments which need to be passed from + # the previous op, so we assert them out + assert cur_node_orig.target not in BINARY_FUNCTIONS + + # at this point in the code, cur_node_copy is pointing to the copy + # of the previous node + # TODO(future PR): this is not handling complicated graphs correctly, need to + # look at actual relationships instead of assuming sequential graph + # TODO(future PR): this is ignoring kwargs, will need to support kwargs + # for any fusion pattern which has them for a node that is not the + # first node. 
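+            # Illustrative example (not part of the original source): in a
+            # linear -> relu subgraph, the copied relu's args become
+            # (linear_copy,), i.e. only the output of the previous copied node
+            # is threaded through as the first argument.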
+ cur_args_copy = [cur_node_copy] # type: ignore[has-type, possibly-undefined] # noqa: F821 + + if len(cur_node_orig.args) > 1: + for arg in cur_node_orig.args[1:]: + if isinstance(arg, torch.nn.Parameter): + new_arg = arg.detach().clone() # type: ignore[assignment] + mod_name = f"mod_{cur_name_idx}" + cur_name_idx += 1 + setattr(gm, mod_name, new_arg) + new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator] + cur_args_copy.append(new_arg_placeholder) + elif isinstance(arg, (float, int, torch.dtype)): + cur_args_copy.append(arg) + else: + raise AssertionError(f"arg of type {type(arg)} not handled yet") + cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] + + # copy the node + if cur_node_orig.op == "call_module": + orig_mod = getattr_from_fqn(model, cur_node_orig.target) # type: ignore[arg-type] + orig_mod_copy = copy.deepcopy(orig_mod) + mod_name = f"mod_{cur_name_idx}" + setattr(gm, mod_name, orig_mod_copy) + cur_name_idx += 1 + cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined,arg-type] + + elif cur_node_orig.op == "call_function": + cur_node_copy = g.call_function( + cur_node_orig.target, # type: ignore[arg-type] + cur_args_copy, # type: ignore[arg-type] + cur_kwargs_copy, # type: ignore[possibly-undefined] + ) + + elif cur_node_orig.op == "call_method": + cur_node_copy = g.call_method( + cur_node_orig.target, # type: ignore[arg-type] + cur_args_copy, # type: ignore[arg-type] + cur_kwargs_copy, # type: ignore[possibly-undefined] + ) + + else: + raise AssertionError(f"{cur_node_orig.op} not supported yet") + + if cur_node_orig is last_node: + break + + # go to next node + assert len(cur_node_orig.users.keys()) == 1, ( + f"{cur_node_orig} has more than 1 users, not supported yet" + ) + cur_node_orig = next(iter(cur_node_orig.users.keys())) + cur_iteration += 1 + if cur_iteration > iteration_limit: + raise AssertionError("iteration limit exceeded") + + # set up outputs + g.output(cur_node_copy) + + gm.recompile() + return gm + + +def create_one_transformed_and_logged_copy_of_subgraph( + mt: GraphModule, + subgraph_idx: int, + subgraph_candidate_idx: int, + first_node: Node, + last_node: Node, + fqn: Optional[str], + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]], + example_inputs: Any, + last_added_shadow_node_list: list[Optional[Node]], + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[dict[str, Any]] = None, +) -> None: + """ + Given a subgraph in `mt` and a subgraph candidate idx, inserts the + subgraph candidate copy and instruments it with loggers. + + If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just + add a logger to the end. + + If subgraph_candidate_idx is not 0, we create a copy of the subgraph and + prepare it with `prepare_fx`. 
+ """ + + # TODO(future PR): move logger classes to utils to remove circular dependency + from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger + + if subgraph_candidate_idx == 0: + # idx = 0 is the floating point (original) version of the subgraph + # We keep the subgraph as is, and add a logger at the end + + qconfig_str = "" + logger_mod_orig = _get_logger_for_subgraph( + mt, + first_node, + last_node, + subgraph_idx, + subgraph_candidate_idx, + qconfig_str, + OutputLogger, + fqn, + ) + + attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) + assert not hasattr(mt, attr_name) + setattr(mt, attr_name, logger_mod_orig) + with mt.graph.inserting_after(last_node): + new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={}) + last_added_shadow_node_list[0] = new_node + + else: + # idx > 0 means we have a candidate qconfig to try, so we need + # to make a copy of the subgraph, feed it with the right inputs, + # and add a logger at the end + + # get the qconfig + # subtract one because the first candidate is the floating point + # version of the subgraph + node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] + qconfig = node_name_to_qconfig[first_node.name] + + # if no quantization is requested, skip + # TODO(future PR): deduplicate equivalent qconfigs that come from + # different qconfig mapping objects + if qconfig is None: + return + + qconfig_mapping = QConfigMapping().set_global(qconfig) + + # create a copy of the submodule, wrapped in a separate module + orig_mod_copy_wrapped = create_submodule_from_subgraph( + mt, first_node, last_node + ) + + # add a call to prepare_fx on the wrapper module + if custom_prepare_fn is None: + orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx( + orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs + ) + else: + if custom_prepare_kwargs is None: + custom_prepare_kwargs = {} + for kwarg_name in [ + "example_inputs", + "prepare_custom_config", + "qconfig_mapping", + ]: + assert kwarg_name not in custom_prepare_kwargs, ( + f"cannot specify {kwarg_name} in custom_prepare_kwargs" + ) + prepare_kwargs: dict[str, Any] = { + "example_inputs": example_inputs, + "qconfig_mapping": qconfig_mapping, + } + prepare_kwargs.update(custom_prepare_kwargs) + orig_mod_copy_wrapped = custom_prepare_fn( + orig_mod_copy_wrapped, **prepare_kwargs + ) + + # attach the wrapper to the model + attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx) + assert not hasattr(mt, attr_name) + setattr(mt, attr_name, orig_mod_copy_wrapped) + + # add a call to the wrapper module from the parent graph + insert_after_node = last_added_shadow_node_list[0] + with mt.graph.inserting_after(insert_after_node): + # TODO(future PR): handle fusion patterns where non-first nodes + # need inputs + + # pass in all node args and kwargs + + new_args = [] + for arg in first_node.args: + if isinstance(arg, Node): + new_args.append(arg) + elif ( + isinstance(arg, (list, tuple)) + and len(arg) + and isinstance(arg[0], Node) + ): + new_args.extend( + inner_arg for inner_arg in arg if isinstance(inner_arg, Node) + ) + + new_kwargs = {} + for name, old_kwarg in first_node.kwargs.items(): + if isinstance(old_kwarg, Node): + new_kwargs[name] = old_kwarg + elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg): + # TODO(future PR): clarify why we are adding kwargs to args + new_args.extend(old_kwarg) # type: ignore[arg-type] + + new_args = tuple(new_args) # type: ignore[assignment] 
+ + new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs) # type: ignore[arg-type] + + # add a logger to parent graph to observe the shadow wrapper + logger_mod_orig = _get_logger_for_subgraph( + mt, + first_node, + last_node, + subgraph_idx, + subgraph_candidate_idx, + str(qconfig), + OutputComparisonLogger, + fqn, + ) + + attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) + assert not hasattr(mt, attr_name) + setattr(mt, attr_name, logger_mod_orig) + with mt.graph.inserting_after(new_node): + logger = mt.graph.call_module( + attr_name, args=(new_node, last_node), kwargs={} + ) + last_added_shadow_node_list[0] = logger + + mt.recompile() + + +def create_n_transformed_and_logged_copies_of_subgraph( + mt: GraphModule, + subgraph_idx: int, + match_name: str, + nodes_in_this_subgraph: list[Any], + qconfig_mappings: list[QConfigMapping], + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]], + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[dict[str, Any]] = None, +) -> None: + """ + Given a model `mt` and a subgraph_idx, creates the needed copies + of the subgraph for all qconfigs, and instruments them with loggers. + """ + # for now, assume that + # 1. the first node has one input + # 2. the last node has one output + + # for now, ignore all subgraphs that contain non-nodes (tuples, etc) + # TODO(future PR): implement this + if any(not isinstance(node, Node) for node in nodes_in_this_subgraph): + return + + first_node = nodes_in_this_subgraph[0] + last_node = nodes_in_this_subgraph[-1] + # We used output propagation to populate example values on each + # node. Use the example values from the previous node as the input + # to the current node. + prev_node = get_normalized_nth_input(first_node, mt, 0) + if isinstance(prev_node, list): + example_inputs = [x.traced_result for x in prev_node] + elif isinstance(prev_node, tuple): + example_inputs = (x.traced_result for x in prev_node) # type: ignore[assignment] + else: + # currently some customer models do not have a traced_result in + # every node, so we have to guard for this case since we cannot + # quantize without an example input + # TODO(future PR): add a test case for this once we have an easy + # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489 + # for additional context + if hasattr(prev_node, "traced_result"): + example_inputs = (prev_node.traced_result,) # type: ignore[attr-defined, assignment] + else: + print( + "unable to get example input for node " + + f"{first_node.format_node()}, skipping" + ) + return + + # If there are no quantization configs for this subgraph, skip adding + # loggers. This reduces memory usage for models where not all layers are + # quantized. + # TODO(future): consider making this configurable + found_at_least_one_qconfig = False + for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): + if subgraph_candidate_idx == 0: + # fp32 baseline does not need a qconfig + continue + + # a. we have N shadows, so len(qconfig_mappings) is N + # b. we will have the fp32 layer + N shadows, so overall number of + # (original_op) + (*shadows) will be N+1 + # c. 
since `subgraph_candidate_idx` represents (b), we need + # to subtract 1 to query from (a) + node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] + qconfig = node_name_to_qconfig[first_node.name] + if qconfig is not None: + found_at_least_one_qconfig = True + break + if not found_at_least_one_qconfig: + print( + "unable to find at least one qconfig for node " + + f"{first_node.format_node()}, skipping" + ) + return + + fqn = _maybe_get_fqn(first_node, mt) + + # We want the results to contain the subgraphs in natural order, + # and the graph to also contain shadow wrappers and shadow loggers + # in natural order. + # If we just iterate in reverse, the graph will be in natural + # order but the eventual results will be in reverse order. + # So, we keep track of the last shadow logger we added and + # always insert after it. + last_added_shadow_node_list: list[Optional[Node]] = [None] + for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): + create_one_transformed_and_logged_copy_of_subgraph( + mt, + subgraph_idx, + subgraph_candidate_idx, + first_node, + last_node, + fqn, + list_of_node_name_to_qconfig, + example_inputs, + last_added_shadow_node_list, + custom_prepare_fn, + custom_prepare_kwargs, + ) + + +def create_add_loggers_graph( + model: GraphModule, + subgraphs_dedup: dict[str, list[Node]], + qconfig_mapping: QConfigMapping, + node_name_to_qconfig: dict[str, QConfigAny], +) -> None: + r""" + Given a model, a model graph partition (currently a set of matched + subgraphs) and instructions how to transform each subgraph + (currently quantizing it according to qconfig_mapping), modifies + the model graph to create an alternate path through the original graph, + with each of the subgraphs quantized. This is useful to compare + propagation error of a transformation such as quantization. + + For example, given layer op0 and op1, there are four cases when handling op1: + 1. op0 and op1 quantized + 2. op0 and op1 unquantized + 3. op0 quantized, op1 unquantized + 4. op0 unquantized, op1 quantized + + Example input, case 1: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog op1_1 -> x2_1 ----> clog + + Example output, case 1: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog + + """ + # TODO(future PR): move logger classes to utils to remove circular dependency + from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger + + def _get_subgraph_containing_node(node, subgraphs_dedup): + for subgraph in subgraphs_dedup.values(): + if node in subgraph: + return subgraph + return None + + # First, we need to create shadow branches, going from + # + # x0 -> op0 -> x1 -> ... + # + # + # to + # + # x0 -> op0_0 -> x1_0 -> log -> ... + # \ \ + # -> op0_1 -> x1_1 -> clog + # + # Later, the outputs of each shadow will be rerouted to calculate + # propagation error. + + # Note: we cannot iterate over matched subgraphs because some nodes + # may not be matched. So, we iterate over nodes in the graph, and + # associate them to matched subgraphs if possible. 
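+    # Illustrative sketch (not part of the original source) of the loop below:
+    #   for each graph node n that has not already been consumed,
+    #     * if n starts a matched subgraph whose qconfig is not None, insert a
+    #       prepared (quantized) shadow copy of the subgraph plus its loggers
+    #       via create_n_transformed_and_logged_copies_of_subgraph
+    #     * otherwise, clone the subgraph's FX nodes directly (no parameters
+    #       are copied) and attach an OutputLogger / OutputComparisonLogger pair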
+ + nodes_to_skip = set() + # for each subgraph, save a mapping from first node of subgraph + # to first and last node of the shadow of this subgraph + orig_first_node_to_shadow_in_node = {} + orig_first_node_to_shadow_out_node = {} + # need to record original list because we will mutate the graph as we go + orig_nodes = list(model.graph.nodes) # type: ignore[union-attr, arg-type] + cur_subgraph_idx = 0 + for n in orig_nodes: + if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip: + continue + + maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) + insert_submodule_copy = False + if maybe_subgraph is not None: + first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] + nodes_to_skip.update(maybe_subgraph) + qconfig = node_name_to_qconfig[first_node.name] + if qconfig is not None: + insert_submodule_copy = True + else: + first_node, last_node = n, n + + if insert_submodule_copy: + match_name = first_node.name + create_n_transformed_and_logged_copies_of_subgraph( + model, + cur_subgraph_idx, + match_name, + maybe_subgraph, + [qconfig_mapping], + [node_name_to_qconfig], + None, + None, # type: ignore[arg-type] + ) + # find the created shadow module and record it so we + # can find it easily in step 2 + expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1" + new_shadow_mod = None + for maybe_shadow_mod in model.graph.nodes: + if ( + maybe_shadow_mod.op == "call_module" + and maybe_shadow_mod.target == expected_shadow_target + ): + new_shadow_mod = maybe_shadow_mod + break + assert new_shadow_mod is not None + orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod + orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod + + else: + # create a copy of the subgraph by only copying FX nodes + # but not copying any parameters, to minimize memory usage + subgraph_to_use = ( + maybe_subgraph if maybe_subgraph is not None else [first_node] + ) + + # add a regular logger after last_node + qconfig_str = "" + subgraph_candidate_idx = 0 + fqn = _maybe_get_fqn(first_node, model) + logger_mod_orig = _get_logger_for_subgraph( + model, + first_node, + last_node, + cur_subgraph_idx, + subgraph_candidate_idx, + qconfig_str, + OutputLogger, + fqn, + ) + attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) + assert not hasattr(model, attr_name) + setattr(model, attr_name, logger_mod_orig) + insertion_point = last_node + with model.graph.inserting_after(insertion_point): + logger = model.graph.call_module( + attr_name, args=(last_node,), kwargs={} + ) + insertion_point = logger + + # create a copy of the subgraph + cur_node_orig = first_node + cur_node_copy = None + first_node_copy = None + while cur_node_orig in subgraph_to_use: + # TODO(future PR): make this support all possible args/kwargs + if cur_node_orig is first_node: + new_args = cur_node_orig.args + new_kwargs = cur_node_orig.kwargs + else: + first_arg_for_copy: Optional[Node] = cur_node_copy + new_args = (first_arg_for_copy, *cur_node_orig.args[1:]) + new_kwargs = cur_node_orig.kwargs + # make a copy of cur_node_orig + with model.graph.inserting_after(insertion_point): + cur_node_copy = model.graph.create_node( + cur_node_orig.op, + cur_node_orig.target, + new_args, + new_kwargs, + # cur_node_orig.name, # TODO(future PR): set name explicitly + ) + if first_node_copy is None: + first_node_copy = cur_node_copy + # since now only linear subgraphs are supported, all nodes + # except the last one must have only one user + if cur_node_orig != last_node: + assert 
len(cur_node_orig.users.keys()) == 1 + cur_node_orig = next(iter(cur_node_orig.users.keys())) + assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX) + insertion_point = cur_node_copy + + # add a comparison logger after last_node's copy + subgraph_candidate_idx = 1 + logger_mod_orig = _get_logger_for_subgraph( + model, + first_node, + last_node, + cur_subgraph_idx, + subgraph_candidate_idx, + qconfig_str, + OutputComparisonLogger, + fqn, + ) + attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) + assert not hasattr(model, attr_name) + setattr(model, attr_name, logger_mod_orig) + with model.graph.inserting_after(insertion_point): + logger = model.graph.call_module( + attr_name, args=(cur_node_copy, last_node), kwargs={} + ) + + # save the final node so we can use it in step 2 + orig_first_node_to_shadow_in_node[first_node] = first_node_copy + orig_first_node_to_shadow_out_node[first_node] = cur_node_copy + + cur_subgraph_idx += 1 + + model.recompile() + + # Now, we go from + # + # x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ... + # \ \ \ + # -> op0_1 -> x1_1 -> clog -> op1_1 -> ... + # + # to + # + # x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ... + # \ \ + # -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ... + # + # sample values of key internal variables for the example above: + # + # orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1} + # orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1} + # + # note: for subgraphs with more than one node, in_node will be different + # compared to out_node + + nodes_to_skip = set() + for n in orig_nodes: + if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip: + continue + + maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) + if maybe_subgraph is not None: + first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] + nodes_to_skip.update(maybe_subgraph) + else: + first_node, last_node = n, n + + def maybe_remap_node_to_shadow(node): + """ + If unshadowed `node` has a shadow version, return that. If not, + return `node`. + """ + if not isinstance(node, Node): + # handle scalars + return node + + if node.op in ("placeholder", "get_attr"): + return node + + # Find the shadowed version of this arg from the previous + # subgraph. For this, we need to: + # 1. navigate to the first node of the previous subgraph + # 2. get the output of the shadow wrapper which has (1) as an input + + # For now, assume the arg is in matched subgraphs. In the + # future we may have to handle the case where this is not true. + prev_subgraph = _get_subgraph_containing_node(node, subgraphs_dedup) + if prev_subgraph is None: + prev_subgraph = [node] + prev_first_node = prev_subgraph[0] + prev_shadow_output = orig_first_node_to_shadow_out_node[prev_first_node] + return prev_shadow_output + + cur_shadow_input = orig_first_node_to_shadow_in_node[first_node] + assert cur_shadow_input is not None + cur_shadow_input.args = tree_map( + maybe_remap_node_to_shadow, cur_shadow_input.args + ) + cur_shadow_input.kwargs = tree_map( + maybe_remap_node_to_shadow, cur_shadow_input.kwargs + ) + + model.recompile() + + +def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module): + # input: shadow wrapper module + # output if shadow wrapper module has a weighted op: + # (quantize_fn, (quantize_fn_args)) + # output if shadow wrapper module doesn't have a weighted op: + # None + + # For now, assume that the weight is the second input + # to the shadow module. 
If that changes, we can fix it later. + placeholders_seen = 0 + for shadow_n in shadow_wrapper.graph.nodes: # type: ignore[union-attr] + if shadow_n.op != "placeholder": + continue + + placeholders_seen += 1 + if placeholders_seen != 2: + continue + + # the subgraph looks like + # + # _input_scale_1 = self._input_scale_1 + # _input_zero_point_1 = self._input_zero_point_1 + # quantize_per_channel = torch.quantize_per_channel( + # w2_0, _input_scale_1, _input_zero_point_1, + # 0, torch.qint8) + # + # we have `w2_0`, and are navigating this subgraph + # to get `_input_scale_1` and `_input_zero_point_1` + + assert len(shadow_n.users) == 1 + quant_node = next(iter(shadow_n.users.keys())) + new_args: Any = None + if quant_node.target == torch.quantize_per_channel: + _weight, scale_node, zp_node, axis, dtype = quant_node.args + scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target) + zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target) + new_args = (scale_val, zp_val, axis, dtype) + else: + assert quant_node.target == torch.quantize_per_tensor + _weight, scale_node, zp_node, dtype = quant_node.args + scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target) + zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target) + new_args = (scale_val, zp_val, dtype) + return (quant_node.target, new_args) + + return None + + +def extract_weight_comparison(m: GraphModule) -> NSResultsType: + # example graph: + # + # w1 = self.w1 + # b1 = self.b1 + # linear = torch._C._nn.linear(x, w1, b1) + # shadow_0_0 = self.shadow_0_0(linear) + # shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1) + # shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear) + # + # algorithm: + # 1. for each call_function node matching our allowlist: + # 2. if corresponding shadow wrapper exists, extract the weight pair + # + # Note: this is not super robust, but that's ok because this is + # just for legacy customers who depend on the previous two-model version + # of this API. TBD if we need to make this robust. + # Note: modules are not supported, since existing customers only + # use functions. 
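+    # Illustrative shape of the returned results (not part of the original
+    # source), assuming a single matched linear with a shadow wrapper:
+    #
+    #   {
+    #       'model': {
+    #           'weight': {
+    #               'subgraph_0_0': [{'values': [fp32_weight], 'comparisons': [sqnr], ...}],
+    #               'subgraph_0_1': [{'values': [quantized_weight], 'comparisons': [sqnr], ...}],
+    #           },
+    #       },
+    #   }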
+ + # TODO(future PR): move this to config + weighted_ops = { + torch.nn.functional.linear, + } + + results: NSResultsType = {"model": {NSSingleResultValuesType.WEIGHT.value: {}}} + + for n in m.graph.nodes: # type: ignore[union-attr] + if not (n.op == "call_function" and n.target in weighted_ops): + continue + + # Check if we have a corresponding shadow wrapper + # TODO(future PR, if needed): support kwargs + # TODO(future PR, if needed): support multiple shadow users + first_arg = n.args[0] + shadow_wrapper_node = None + for user in first_arg.users: + # TODO(before land): fix string match + if user.op == "call_module" and user.target.startswith("shadow_wrapper"): + shadow_wrapper_node = user + break + + if shadow_wrapper_node is None: + continue + + shadow_wrapper = getattr_from_fqn(m, shadow_wrapper_node.target) # type: ignore[arg-type] + weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper) + if weight_info is None: + continue + + # get weight + w_node = n.args[1] + w_obj = getattr_from_fqn(m, w_node.target).detach() + + # get a quantized version of weight + quant_fn, quant_fn_args_except_first = weight_info + new_args = (w_obj, *quant_fn_args_except_first) + w_obj_q = quant_fn(*new_args) + + # add a comparison + ref_node_name = n.name + prev_node_name = n.name + ref_node_type = get_target_type_str(n, m) + prev_node_type = ref_node_type + fqn = None + if hasattr(m, "_node_name_to_scope"): + fqn = m._node_name_to_scope[n.name][0] # type: ignore[index] + comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q) + result_fp32 = { + "res_type": NSSingleResultValuesType.WEIGHT.value, + "values": [w_obj], + "prev_node_name": prev_node_name, + "prev_node_target_type": prev_node_type, + "ref_node_name": ref_node_name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + "qconfig_str": "", + "comparisons": [comparison], + "comparison_fn_name": "sqnr", + } + result_q = { + "res_type": NSSingleResultValuesType.WEIGHT.value, + "values": [w_obj_q], + "prev_node_name": prev_node_name, + "prev_node_target_type": prev_node_type, + "ref_node_name": ref_node_name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + "qconfig_str": "", + "comparisons": [comparison], + "comparison_fn_name": "sqnr", + } + + # go from subgraph_n_1 to subgraph_n_0 + _1, _2, node_idx, _3 = shadow_wrapper_node.target.split("_") + name_fp32 = f"subgraph_{node_idx}_0" + name_q = f"subgraph_{node_idx}_1" + + results["model"][NSSingleResultValuesType.WEIGHT.value][name_fp32] = [ + result_fp32 + ] + results["model"][NSSingleResultValuesType.WEIGHT.value][name_q] = [result_q] + + return results + + +# TODO(future PR): redesign this to make it easier to consume outputs +def group_results_by_subgraph(results: NSResultsType) -> Any: + """ + Creates a comparison of results + + Input: + + { + 'model': { + 'node_output': { + 'subgraph_0_0': [ + 'values': [torch.tensor(...), ...], ... + 'ref_node_name': ..., + 'ref_node_target_type': ..., + 'qconfig_str': ..., + 'comparisons': [], ... + 'comparison_fn_name': '', + 'fqn': '...', + ], + 'subgraph_0_1': [ + 'values': [torch.tensor(...), ...], ... + 'ref_node_name': ..., + 'ref_node_target_type': ..., + 'qconfig_str': ..., + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + ], + ... 
+ }, + }, + } + + Output: + { + 'subgraph_0': { + '0': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': None, + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + }, + '1': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '...', + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + }, + }, + } + + """ + subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict) + + # node_output or weight + key_to_use = next(iter(results["model"].keys())) + + for subgraph_name_with_idx, subgraph_candidate_results in results["model"][ + key_to_use + ].items(): + # convert from `subgraph_m_n` to `subgraph_m` and `n` + ( + subgraph_str, + subgraph_idx, + subgraph_candidate_idx, + ) = subgraph_name_with_idx.split("_") + subgraph_name = f"{subgraph_str}_{subgraph_idx}" + + subgraph_results = { + "ref_node_name": subgraph_candidate_results[0]["ref_node_name"], + "ref_node_target_type": subgraph_candidate_results[0][ + "ref_node_target_type" + ], + "fqn": subgraph_candidate_results[0]["fqn"], + "values": subgraph_candidate_results[0]["values"], + "qconfig_str": subgraph_candidate_results[0]["qconfig_str"], + "comparisons": subgraph_candidate_results[0]["comparisons"], + "comparison_fn_name": subgraph_candidate_results[0]["comparison_fn_name"], + } + + subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = ( + subgraph_results + ) + + return dict(subgraph_name_to_subgraph_results) + + +# TODO(future PR): redesign this to make it easier to consume outputs +def create_results_comparison( + results_grouped, +) -> Any: + """ + Input: + + { + 'subgraph_0': { + '0': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '', + 'comparisons': [], + 'comparison_fn_name': '', + 'fqn': '...', + }, + '1': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '...', + 'comparisons': [torch.tensor(...), ...], + 'comparison_fn_name': 'sqnr', + 'fqn': '...', + }, + }, + } + + Output: + { + 'subgraph_0': { + 'ref_node_name': '...', + 'ref_node_target_type': '...', + 'fqn': '...', + 'candidates': { + '1': { + 'qconfig_str': ..., + 'comparison_fn_name': 'sqnr', + 'cmp_raw': [..., ...], + 'cmp_mean': ..., + }, + ..., + }, + }, + } + """ + + results_comparison = {} + + for subgraph_name, subgraph_results in results_grouped.items(): + candidates = {} + for subgraph_inner_name, subgraph_inner_result in subgraph_results.items(): + # skip comparing baseline to baseline + if subgraph_inner_name == "0": + continue + + # we expect the comparisons to be precalculated from + # calibration, so we just fetch them here + cmp_raw = subgraph_inner_result["comparisons"] + cmp_raw_tensor = torch.stack(cmp_raw) + + candidates[subgraph_inner_name] = { + "qconfig_str": subgraph_inner_result["qconfig_str"], + "comparison_fn_name": subgraph_inner_result["comparison_fn_name"], + "cmp_raw": cmp_raw_tensor, + "cmp_mean": torch.mean(cmp_raw_tensor), + } + + results_comparison[subgraph_name] = { + "ref_node_name": subgraph_results["0"]["ref_node_name"], + "ref_node_target_type": subgraph_results["0"]["ref_node_target_type"], + "fqn": subgraph_results["0"]["fqn"], + "candidates": candidates, + } + + return results_comparison + + +# TODO(future PR): redesign this to make it easier to consume outputs 
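# Minimal end-to-end consumption sketch (variable names are illustrative),
# tying together the helpers defined above with the summary printer below:
#
#   results = extract_weight_comparison(shadowed_model)
#   grouped = group_results_by_subgraph(results)
#   comparison = create_results_comparison(grouped)
#   print_n_shadows_summary(comparison)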
+def print_n_shadows_summary( + results_comparison, +) -> None: + """ + Input: + + { + 'subgraph_0': { + 'ref_node_name': 'linear1', + 'ref_node_target_type': '...', + 'fqn': '...', + 'candidates': { + '1': { + 'qconfig_str': ..., + 'comparison_fn_name': ..., + 'cmp_raw': [45.0, 55.0], + 'cmp_mean': 50.0, + }, + ..., + }, + }, + } + + Prints: + + node_name | node_type | fqn | 0 | 1 | ... + linear1 | ... | ... | 45.0 | 50.0 | ... + """ + + try: + from tabulate import tabulate + except ImportError: + print( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) + return + + results = [] + for subgraph_data in results_comparison.values(): + mean_all_candidates = [ + candidate["cmp_mean"] + for candidate_name, candidate in subgraph_data["candidates"].items() + ] + + data_row = [ + subgraph_data["ref_node_name"], + subgraph_data["ref_node_target_type"], + subgraph_data["fqn"], + *mean_all_candidates, + ] + results.append(data_row) + + max_candidate_idx_len = -1 + for data_row in results: + max_candidate_idx_len = max(max_candidate_idx_len, len(data_row[1])) + candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)] + + headers = ["node_name", "node_type", "fqn", *candidate_idx_headers] + print(tabulate(results, headers=headers)) diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/ns_types.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/ns_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a43bafbfb22078a35bac27ec2a08dfb962489763 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/ns_types.py @@ -0,0 +1,65 @@ +import enum +from typing import Any, Callable, NamedTuple, Union + +from torch.fx.graph import Node + + +class NSSingleResultValuesType(str, enum.Enum): + WEIGHT = "weight" + NODE_OUTPUT = "node_output" + NODE_INPUT = "node_input" + + +class NSSubgraph(NamedTuple): + start_node: Node + end_node: Node + base_op_node: Node + + +# TODO(future PR): see if we can use typing_extensions's TypedDict instead +# to properly type the various keys +# { +# # one of NSSingleResultValuesType +# 'type': 'weight', +# # the values of type specified above +# 'values': [torch.tensor(...), ...], +# # name of the node directly before the logger +# 'prev_node_name': 'linear1', +# # type of the underlying function or module +# 'prev_node_target_type': torch.nn.functional.linear # or torch.nn.Linear, etc +# # name of the node responsible for adding this logger +# # Note: this may differ from prev_node_name if we are logging inputs +# 'ref_node_name': 'linear1', +# # index of this node within the arg of the input/output node +# # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 +# 'index_within_arg': 0, +# # index of this node within the args of the input/output node +# # for example, in add(x1, x2), x2 would have index_of_arg == 1 +# 'index_of_arg': 0, +# # precomputed comparisons of logger values to reference values +# 'comparisons': [torch.tensor(...), ...] 
+# # name of function used for precomputed comparisons +# 'comparison_fn_name': 'sqnr', +# # string representation of qconfig responsible for creating this logger +# 'qconfig_str': 'QConfig(...)', +# } +NSSingleResultType = dict[str, Any] + +# { +# 'layer_name_1': { # subgraph name +# 'node_output': { # results type (node_output, node_input, weight) +# 'model_name_a': # model name +# [NSSingleResultType, ...], # results, ordered by index_within_arg +# 'model_name_b': +# [NSSingleResultType, ...], +# }, +# }, +# } +# +NSResultsType = dict[str, dict[str, dict[str, list[NSSingleResultType]]]] + +# Defines the underlying target type of a node, for example: +# `F.conv1d` for a `call_function` conv node +# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module +# `'sigmoid'` for a `call_method` node calling `x.sigmoid()` +NSNodeTargetType = Union[Callable, str] diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/pattern_utils.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0fbd46968d36242871ae1dab38f104bff0c6289 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/pattern_utils.py @@ -0,0 +1,209 @@ +from typing import Any, Callable, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization import FakeQuantizeBase, ObserverBase +from torch.ao.quantization.backend_config import get_native_backend_config +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .ns_types import NSNodeTargetType + + +toq = torch.ops.quantized + + +def get_type_a_related_to_b( + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], +) -> set[tuple[NSNodeTargetType, NSNodeTargetType]]: + # TODO(future PR): allow customizations + # TODO(future PR): reuse existing quantization mappings + # TODO(future PR): add the rest of modules and ops here + type_a_related_to_b: set[tuple[NSNodeTargetType, NSNodeTargetType]] = set() + + for s in base_name_to_sets_of_related_ops.values(): + s_list = list(s) + # add every bidirectional pair + for idx_0 in range(0, len(s_list)): + for idx_1 in range(idx_0, len(s_list)): + type_a_related_to_b.add((s_list[idx_0], s_list[idx_1])) + type_a_related_to_b.add((s_list[idx_1], s_list[idx_0])) + + return type_a_related_to_b + + +NSFusionElType = Union[ + Callable, # call_function or call_module type, example: F.linear or nn.Conv2d + str, # call_method name, example: "dequantize" + tuple[ + str, Any + ], # call_method name and first argument, example: ("to", torch.float16) +] +NSFusionType = Union[ + tuple[NSFusionElType, NSFusionElType], + tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType], +] + + +def get_reversed_fusions() -> list[tuple[NSFusionType, int]]: + """ + Set of potential fusions, in reverse order. The order is reversed + to match how fusion patterns are defined in quantization code. + + Fusion format: + ((fusion_op_0, fusion_op_1), base_op_idx) + + Where base_op_idx is the idx of the op we should use to match other related + ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx + of 0 represents the first op in regular (non-reverse) order, 1 represents the + second op, etc. 
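    For example, an entry ((nn.ReLU, nn.Conv2d), 0) describes the forward
    pattern conv -> relu stored in reverse order, with the conv (the first
    op in forward order) serving as the base op for matching.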
+ """ + results: list[tuple[NSFusionType, int]] = [] + + # Possible syntaxes: + # * single op: torch.nn.Conv2d + # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d) + # For fusions, we only care about patterns composed of multiple ops. + # TODO(future PR): allow customizations from default patterns. + all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) + + default_base_op_idx = 0 + for quant_pattern in all_quant_patterns.keys(): + # TODO: this is a temporary hack to flatten the patterns from quantization so + # that it works with the ns matcher function, maybe we should use `_is_match` + # in torch.ao.quantization.fx.match_utils to match the patterns + if ( + isinstance(quant_pattern, tuple) + and len(quant_pattern) == 2 + and isinstance(quant_pattern[1], tuple) + and len(quant_pattern[1]) == 2 + ): + # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1]) + + # Only patterns of multiple ops are fusions, ignore + # patterns which contain a single ops (they get matched + # without caring about fusions). + if isinstance(quant_pattern, tuple): + results.append((quant_pattern, default_base_op_idx)) # type: ignore[arg-type] + + # For each pattern, add additional patterns with observers and + # fake quants at the end. + # TODO(future PR): if needed, implement matching for a node + # having multiple output observers. + for cls in (ObserverBase, FakeQuantizeBase): + if isinstance(quant_pattern, tuple): + new_pattern = (cls, *quant_pattern) + else: + new_pattern = (cls, quant_pattern) + results.append((new_pattern, default_base_op_idx)) # type: ignore[arg-type] + + # After this point, results contains values such as + # [..., ((torch.nn.Relu, torch.nn.Conv2d), 0), ...] + + # Patterns for matching fp16 emulation are not specified in the quantization + # fusion mappings. For now, define them here. + fp16_em_base_op_idx = 1 + patterns_to_add = [ + # linear-relu fp16 emulation: + # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16 + ( + (("to", torch.float16), F.relu, F.linear, "dequantize"), + fp16_em_base_op_idx, + ), + # Conv-BN fusion (this happens outside of quantization patterns, + # which is why it is defined separately here). + ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx), + ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx), + ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx), + ] + for p in patterns_to_add: + results.append(p) # type: ignore[arg-type] + results.append(((ObserverBase, *p[0]), p[1])) # type: ignore[arg-type] + results.append(((FakeQuantizeBase, *p[0]), p[1])) # type: ignore[arg-type] + + return results + + +def end_node_matches_reversed_fusion( + end_node: Node, + reversed_fusion: NSFusionType, + gm: GraphModule, + seen_nodes: set[Node], +) -> bool: + """ + Returns true if a pattern ending with `end_node` matches + the fusion pattern. 
+ """ + cur_node = end_node + for fusion_idx in range(len(reversed_fusion)): + # each node can only belong to one matched pattern + if cur_node in seen_nodes: + return False + + cur_fusion_el = reversed_fusion[fusion_idx] + + if cur_node.op == "call_function": + fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and ( + not isinstance(cur_fusion_el, type) + ) + if fusion_el_is_fun: + if cur_node.target != cur_fusion_el: + return False + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + + elif cur_node.op == "call_module": + fusion_el_is_mod = isinstance(cur_fusion_el, type) + if fusion_el_is_mod: + assert isinstance(cur_node.target, str) + target_mod = getattr_from_fqn(gm, cur_node.target) + if not isinstance(cur_fusion_el, type): + return False + if not isinstance(target_mod, cur_fusion_el): + return False + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + + elif cur_node.op == "call_method": + fusion_el_is_meth_with_second_arg = ( + isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2 + ) + fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str) + if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg: + if fusion_el_is_meth_without_args: + if cur_node.target != cur_fusion_el: + return False + else: + assert isinstance(cur_fusion_el, tuple) + if cur_node.target != cur_fusion_el[0]: + return False + elif len(cur_node.args) < 2: + return False + elif cur_node.args[1] != cur_fusion_el[1]: + return False + + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + else: + return False + + return True diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..b3aead99fd21834e8a11a0a0b6092ffaf89fb917 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -0,0 +1,249 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy +from typing import Any, Callable, TYPE_CHECKING, Union + +import torch +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfigAny + +__all__ = ["QConfigMultiMapping"] + +_QCONFIG_STYLE_TO_METHOD: dict[str, str] = { + "global_qconfig": "set_global", + "object_type_qconfigs": "set_object_type", + "module_name_regex_qconfigs": "set_module_name_regex", + "module_name_qconfigs": "set_module_name", + "module_name_object_type_order_qconfigs": "set_module_name_object_type_order", +} + + +def _remove_duplicates_and_none(qconfig_list: list[QConfigAny]) -> None: + to_remove = [] + for index, cur_qconfig in enumerate(qconfig_list): + if cur_qconfig is None: + to_remove.append(index) + break + for checked_qconfig in qconfig_list[:index]: + if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig): + to_remove.append(index) + break + for index in to_remove[::-1]: + qconfig_list.pop(index) + + +class QConfigMultiMapping: + """ + This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s + so that multiple QConfigs can be specified for each QConfig matching 
style. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfigs + + ``set_object_type`` : sets the QConfigs for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string + + ``set_module_name`` : sets the QConfigs for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a + single QConfig. + + Example usage:: + + qconfig_mapping = QConfigMultiMapping() + .set_global([qconfig1, qconfig2]) + .set_object_type(torch.nn.Linear, [qconfig2, qconfig3]) + .set_object_type(torch.nn.ReLU, [qconfig1]) + .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2]) + .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3]) + .set_module_name("module1", [None]) + .set_module_name("module2", [qconfig2]) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3]) + + """ + + def __init__(self) -> None: + # initialize this with 1 QConfigMapping to avoid corner cases + self.qconfig_mappings_list: list[QConfigMapping] = [QConfigMapping()] + + def _handle_list_size_mismatch( + self, qconfig_list: list[QConfigAny], style: str + ) -> None: + # this method handles cases where the size of qconfig_list does not match + # the size of qconfig_mappings_list. + # Issue: Consider a user inserting global_qconfig A and B first, then inserting + # qconfig C as an object_type_qconfig for conv ops. If we internally store + # 1 QConfigMapping with A and C and another with just B, then the + # second QConfigMapping will match B to conv ops (which is not wanted), since B is global. + + # we avoid this by maintaining the invariant that if any QConfigMapping + # has a qconfig style+key with a qconfig in it, all QConfigMappings must + # have either a qconfig or None for that same style+key. 
In the above + # example, a None qconfig would prevent the unwanted match in the + # second QConfigMapping + + if len(qconfig_list) > len(self.qconfig_mappings_list): + # Case: we have more qconfigs (in qconfig_list) than QConfigMappings + + # Add new QConfigMappings (initialized so we maintain the `invariant`) + + new_qconfig_mapping = QConfigMapping() + # searches other QConfigMappings for qconfig style+keys + # that need to be inserted as `None` into the new QConfigMapping + for qconfig_mapping in self.qconfig_mappings_list: + # global_qconfig has None by default + for check_style in _QCONFIG_STYLE_ORDER[1:]: + qconfigs_dict = getattr(qconfig_mapping, check_style) + target_qconfigs_dict = getattr(new_qconfig_mapping, check_style) + for key in qconfigs_dict: + target_qconfigs_dict[key] = None + break + + # insert copies of this new QConfigMapping until all entires + # in qconfig_list can fit among the QConfigMappings + while len(qconfig_list) > len(self.qconfig_mappings_list): + self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) + else: + # Case: we have fewer qconfigs in qconfig_list than QConfigMappings + + # pad qconfig_list with `None` until length is same + while len(qconfig_list) < len(self.qconfig_mappings_list): + qconfig_list.append(None) + + # this function applies the insertion method across each QConfigMapping + def _insert_qconfig_list( + self, + style: str, + args: list[Union[str, int, Callable]], + qconfig_list: list[QConfigAny], + ) -> None: + # we remove duplicates and None to make the ordering of qconfigs + # deterministic upon insertion. + _remove_duplicates_and_none(qconfig_list) + + self._handle_list_size_mismatch(qconfig_list, style) + method_name = _QCONFIG_STYLE_TO_METHOD[style] + for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list): + # uses QConfigMapping set method to insert qconfig + set_method = getattr(qconfig_mapping, method_name) + set_method(*args, qconfig) + + def set_global(self, global_qconfig_list: list[QConfigAny]) -> QConfigMultiMapping: + """ + Set global QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info + """ + self._insert_qconfig_list("global_qconfig", [], global_qconfig_list) + return self + + def set_object_type( + self, object_type: Union[Callable, str], qconfig_list: list[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set object type QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info + """ + self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list) + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig_list: list[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name_regex QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info + """ + self._insert_qconfig_list( + "module_name_regex_qconfigs", [module_name_regex], qconfig_list + ) + return self + + def set_module_name( + self, module_name: str, qconfig_list: list[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info + """ + self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list) + return self + + def set_module_name_object_type_order( + self, + module_name: str, + object_type: Callable, + index: int, + qconfig_list: list[QConfigAny], + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see 
:func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info + """ + self._insert_qconfig_list( + "module_name_object_type_order_qconfigs", + [module_name, object_type, index], + qconfig_list, + ) + return self + + def __repr__(self): + return ( + self.__class__.__name__ + + " [" + + "".join( + f"\n{qconfig_mapping.__repr__()}," + for qconfig_mapping in self.qconfig_mappings_list + ) + + "\n]" + ) + + @classmethod + def from_list_qconfig_mapping( + cls, qconfig_mapping_list: list[QConfigMapping] + ) -> QConfigMultiMapping: + """ + Creates a QConfigMultiMapping from a list of QConfigMappings + """ + new_qconfig_multi_mapping = cls() + + new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy( + qconfig_mapping_list + ) + + # we need to avoid the issue described in _handle_list_size_mismatch, + # so we reinsert all the qconfigs using the QConfigMultiMapping + # set methods + + # go through all qconfig styles + # note: global can be ignored since it is None by default + for style in _QCONFIG_STYLE_ORDER[1:]: + # gather all key+qconfigs for current style + # into qconfig_dict_list + qconfig_dict_list: dict[Any, list[QConfigAny]] = {} + for qconfig_mapping in qconfig_mapping_list: + qconfig_dict = getattr(qconfig_mapping, style) + for key, qconfig in qconfig_dict.items(): + if key not in qconfig_dict_list: + qconfig_dict_list[key] = [] + qconfig_dict_list[key].append(qconfig) + + # reinsert all gathered key+qconfigs + set_method_name = _QCONFIG_STYLE_TO_METHOD[style] + set_method = getattr(new_qconfig_multi_mapping, set_method_name) + for key, qconfig_list in qconfig_dict_list.items(): + if isinstance(key, tuple): + set_method(*key, qconfig_list) + else: + set_method(key, qconfig_list) + + return new_qconfig_multi_mapping diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/utils.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7f7100abd78827e8c57b7e0602d0f789b428bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/utils.py @@ -0,0 +1,541 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import enum +import operator +from typing import Callable, Optional, Union + +import torch +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.quantized as nnq +import torch.nn as nn +from torch.ao.quantization import FakeQuantizeBase, ObserverBase +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .ns_types import NSNodeTargetType, NSResultsType + + +toq = torch.ops.quantized + + +# TODO(future PR): consider deleting this enum and using the torch types +# directly. This might be tricky because it is not a one to one mapping. +class NodeInputOrOutputType(enum.Enum): + FP32 = enum.auto() # torch.float + INT8 = enum.auto() # torch.qint8 or torch.quint8 + FP16 = enum.auto() # torch.float16 + UNKNOWN = enum.auto() # we cannot determine input/output dtype + # TODO(future PR): while these functions can support multiple dtypes, + # for the purposes of numerical debugging we want to get the actual + # dtype used in the model. We will likely need some kind of dtype + # propagation to estimate this. 
+ FP32_OR_INT8 = enum.auto() # either torch.float or torch.quint8 or torch.qint8 + # TODO(future PRs): dynamic quant, fake quant, etc + + +def get_node_first_input_and_output_type( + node: Node, + gm: GraphModule, + logger_cls: Callable, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]], +) -> tuple[NodeInputOrOutputType, NodeInputOrOutputType]: + # TODO(future PR): clean this up + FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"] + FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"] + FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"] + FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"] + MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"] + MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"] + MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"] + METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"] + + if node.op == "call_function": + if node.target in FUNS_IO_TYPE_FP32: + return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32) + if node.target in FUNS_IO_TYPE_FP16: + return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16) + elif node.target in FUNS_IO_TYPE_INT8: + return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8) + elif node.target in FUNS_IO_TYPE_FP32_OR_INT8: + first_arg = get_normalized_nth_input(node, gm, 0) + assert isinstance(first_arg, Node) + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + first_arg, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, prev_node_output_type) + else: + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + + elif node.op == "call_module": + assert node.op == "call_module" + assert isinstance(node.target, str) + mod = getattr_from_fqn(gm, node.target) + is_known_fp32_or_int8_input_module = any( + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32_OR_INT8 + ) + if ( + isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)) # type: ignore[arg-type] + or is_known_fp32_or_int8_input_module + ): + # A logger or observer's input and output type is the output + # type of the preceding node. + first_arg = get_normalized_nth_input(node, gm, 0) + assert isinstance(first_arg, Node) + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + first_arg, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, prev_node_output_type) + is_known_fp32_input_module = any( + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32 + ) + is_known_int8_input_module = any( + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_INT8 + ) + if is_known_fp32_input_module: + return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32) + elif is_known_int8_input_module: + return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8) + else: + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + + elif node.op == "call_method": + if node.target == "dequantize": + # Dequantize is a special node because it allows multiple input types. + # So, we look up the output type of the previous node and return that + # as the input type of this node instance. 
+ prev_node = get_normalized_nth_input(node, gm, 0) + assert isinstance(prev_node, Node) + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + prev_node, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, NodeInputOrOutputType.FP32) + + elif node.target == "to": + # to is a special node because it allows multiple input types. + # So, we look up the output type of the previous node and return that + # as the input type of this node instance. We also look up the target + # of to and return the correct output type. + prev_node = get_normalized_nth_input(node, gm, 0) + assert isinstance(prev_node, Node) + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + prev_node, gm, logger_cls, node_type_to_io_type_map + ) + + cur_node_dtype_target = get_normalized_nth_input(node, gm, 1) + assert cur_node_dtype_target is torch.float16, ( + f"{cur_node_dtype_target} handling needs to be added" + ) + + return (prev_node_output_type, NodeInputOrOutputType.FP16) + + elif node.target in METHS_IO_TYPE_FP32_OR_INT8: + first_arg = get_normalized_nth_input(node, gm, 0) + assert isinstance(first_arg, Node) + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + first_arg, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, prev_node_output_type) + + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + else: + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + + +def get_node_input_qparams( + node: Node, + gm: GraphModule, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]], +) -> Optional[tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]: + """ + Returns the qparams (scale, zero_point) of the first input to `node`, + if they can be inferred from the graph. 
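    For example, if the first input is produced by
    torch.quantize_per_tensor(x, scale, zero_point, dtype), the scale and
    zero_point attributes are looked up on the graph module and returned;
    if it is produced by a quantized module such as nnq.Linear, the module's
    `scale` and `zero_point` are returned directly.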
+ """ + prev_node = get_normalized_nth_input(node, gm, 0) + + if not isinstance(prev_node, Node): + return None + + MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"] + + def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx): + scale_node = get_normalized_nth_input(node, gm, scale_arg_idx) + zp_node = get_normalized_nth_input(node, gm, zp_arg_idx) + assert isinstance(scale_node, Node) and isinstance(scale_node.target, str) + assert isinstance(zp_node, Node) and isinstance(zp_node.target, str) + scale_obj = getattr_from_fqn(gm, scale_node.target) + zp_obj = getattr_from_fqn(gm, zp_node.target) + return (scale_obj, zp_obj) + + if prev_node.op == "call_function": + # quantize - read the args directly + if prev_node.target == torch.quantize_per_tensor: + return _get_scale_zp_from_function_args(prev_node, gm, 1, 2) + elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu): + return _get_scale_zp_from_function_args(prev_node, gm, 2, 3) + + return None + # TODO(future PR): handle more functionals + # TODO(future PR): handle functional ops which inherit qparams from input + + elif prev_node.op == "call_module": + # get type of the module + assert isinstance(prev_node.target, str) + module_obj = getattr_from_fqn(gm, prev_node.target) + if isinstance( + module_obj, + ( + nnq.Linear, + nnq.Conv1d, + nnq.Conv2d, + nniq.ConvReLU2d, + nnq.Conv3d, + nnq.BatchNorm2d, + nnq.BatchNorm3d, + nnq.ConvTranspose1d, + nnq.ConvTranspose2d, + nnq.ELU, + nnq.GroupNorm, + nnq.InstanceNorm1d, + nnq.InstanceNorm2d, + nnq.InstanceNorm3d, + nnq.LayerNorm, + nnq.Hardswish, + nnq.LeakyReLU, + nnq.ReLU6, + nniq.BNReLU2d, + nniq.BNReLU3d, + nniq.ConvReLU1d, + nniq.ConvReLU2d, + nniq.ConvReLU3d, + nniq.LinearReLU, + ), + ): + return (module_obj.scale, module_obj.zero_point) # type: ignore[return-value] + + is_known_fp32_or_int8_input_module = any( + isinstance(module_obj, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32_OR_INT8 + ) + if is_known_fp32_or_int8_input_module: + return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map) + + return None + + +def return_first_non_observer_node( + node: Node, + gm: GraphModule, +) -> Node: + """ + If node is not an observer, returns it. If node is an observer, + navigates up the graph and returns the first parent which is not an + observer. For example, + + graph: (node_non_obs), node = node_non_obs : returns node_non_obs + graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs + graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs + """ + if node.op == "call_module": + node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type] + if _is_activation_post_process(node_obj): + assert len(node.args) == 1 + assert isinstance(node.args[0], Node) + node = node.args[0] + # code duplication intended, not worth refactoring + assert isinstance(node.target, str) + node_obj = getattr_from_fqn(gm, node.target) + if _is_activation_post_process(node_obj): + assert len(node.args) == 1 + assert isinstance(node.args[0], Node) + node = node.args[0] + return node + + +def get_number_of_non_param_args( + node: Node, + gm: GraphModule, +) -> int: + """ + Assumes that all non-param args occur first. Returns the number of + non-param args expected for a node. For example, for + + F.linear(x, weight, bias) + + Returns 1, because x is a non-param arg and weight and bias are params. + For + + lstm_mod(x, hid) + + Returns 2, because both x and hid are non-param args. 
+ """ + if node.op == "call_module": + node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type] + if isinstance(node_obj, nn.LSTM): + return 2 + + # default is 1 + return 1 + + +def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]: + """ + Returns the indices of args of the node which we should attach + loggers to, if input logging is enabled. + + For example, + * for (x + y), returns [0, 1] + * for (1 + y), returns [1] + * for (x + 1), returns [0] + * for (linear(x, w, b)) returns [0] + * by default, returns [0] + """ + if len(node.args) == 0: + return [] + if node.op == "call_function" and ( + # TODO(future PR): use relationship map instead of hardcoding + node.target in (torch.add, torch.ops.quantized.add, operator.add) + or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul) + ): + result = [i for i in range(2) if type(node.args[i]) == Node] + return result + return [0] + + +def get_target_type_str(node: Node, gm: GraphModule) -> str: + """ + Returns a string representation of the type of the function or module + pointed to by this node, or '' for other node types. + """ + target_type = "" + if node.op in ("call_function", "call_method"): + target_type = torch.typename(node.target) + elif node.op == "call_module": + assert isinstance(node.target, str) + target_mod = getattr_from_fqn(gm, node.target) + target_type = torch.typename(target_mod) + return target_type + + +def rekey_logger_info_on_node_name_of_model( + results: NSResultsType, + model_name: str, +) -> NSResultsType: + """ + Rekeys the layer name of a results dictionary to use node names + from `model_name`. + + For example, transforms + + {'base_op_1_0': {'node_output': {'model_a': + [{'ref_node_name': 'linear1', ...}]}}} + + into + + {'linear1': {'node_output': {'model_a': + [{'ref_node_name': 'linear1', ...}]}}} + + Note: we cannot use these node names directly because they are not + guaranteed to be consistent across models. This is why we extract + the results first and rekey afterwards. + """ + new_results = {} + for old_layer_name, result_type_to_results in results.items(): + new_layer_name = None + for model_name_to_results in result_type_to_results.values(): + for cur_model_name, list_of_results in model_name_to_results.items(): + if cur_model_name == model_name: + assert len(list_of_results) + new_layer_name = list_of_results[0]["ref_node_name"] + else: + continue + if new_layer_name is not None: + new_results[new_layer_name] = result_type_to_results + else: + new_results[old_layer_name] = result_type_to_results + return new_results + + +def maybe_add_missing_fqns(results: NSResultsType) -> None: + """ + If `fqn` entries are filled in for one of the models in `results`, copies + them over to any models which do not have them filled out. + + A common use case benefitting from this is comparing a model prepared by + quantization to a quantized model. In this case, the model prepared by + quantization would have `fqn` entries, and the quantized model would not. + """ + + # Check in the first result to find any model with fqn entries defined. 
+ model_name_with_fqns = None + for result_type_to_results in results.values(): + for model_name_to_results in result_type_to_results.values(): + for model_name, model_results in model_name_to_results.items(): + if len(model_results) > 0: + if model_results[0]["fqn"] is not None: + model_name_with_fqns = model_name + break + break + break + + if model_name_with_fqns: + for result_type_to_results in results.values(): + for model_name_to_results in result_type_to_results.values(): + ref_model_results = model_name_to_results[model_name_with_fqns] + for model_name, model_results in model_name_to_results.items(): + if model_name == model_name_with_fqns: + continue + for i in range(len(model_results)): + fqn = ref_model_results[i]["fqn"] + model_results[i]["fqn"] = fqn + + +def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f): + def inner(*args, **kwargs): + a0, a1, *a_other = args + + if (isinstance(a0, tuple) and isinstance(a1, tuple)) or ( + isinstance(a0, list) and isinstance(a1, list) + ): + results = [] + for el0, el1 in zip(a0, a1): + new_args = (el0, el1, *a_other) + results.append(inner(*new_args, **kwargs)) + return results + + elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor): + if a0.is_quantized: + a0 = a0.dequantize() + if a1.is_quantized: + a1 = a1.dequantize() + + # for the purposes of this util, only handle floats + if a0.dtype != torch.float or a1.dtype != torch.float: + return None + + new_args = (a0, a1, *a_other) + return f(*new_args, **kwargs) + + return inner + + +@maybe_dequantize_first_two_tensor_args_and_handle_tuples +def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the SQNR between `x` and `y`. + + Args: + x: Tensor or tuple of tensors + y: Tensor or tuple of tensors + + Return: + float or tuple of floats + """ + Ps = torch.norm(x) + Pn = torch.norm(x - y) + return 20 * torch.log10(Ps / Pn) + + +@maybe_dequantize_first_two_tensor_args_and_handle_tuples +def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the normalized L2 error between `x` and `y`. + + Args: + x: Tensor or tuple of tensors + y: Tensor or tuple of tensors + + Return: + float or tuple of floats + """ + return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum()) + + +@maybe_dequantize_first_two_tensor_args_and_handle_tuples +def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the cosine similarity between `x` and `y`. + + Args: + x: Tensor or tuple of tensors + y: Tensor or tuple of tensors + + Return: + float or tuple of floats + """ + # For convolutions, the shape of the quantized weight has one additional + # dimension compared to the shape of the fp32 weight. Match the shapes + # to enable cosine similarity comparison. + x = x.reshape(1, -1) + y = y.reshape(1, -1) + return torch.nn.functional.cosine_similarity(x, y) + + +def op_type_supports_shadowing(node: Node) -> bool: + if node.op == "call_function": + if node.target in ( + torch.add, + torch.mul, + operator.add, + operator.mul, + torch.cat, + torch.stack, + ): + # shadowing for ops with multiple tensor inputs is not implemented yet + return False + return True + + +def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node: + """ + Given a node, gets the n'th input to that node, normalizing + args and kwargs to the best of its ability. 
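    For example, for a call_function node representing F.linear(x, weight=w),
    idx=1 returns the weight node regardless of whether it was passed
    positionally or as a keyword; if normalization is not possible (e.g. for
    overloaded ops such as torch.add), the raw args followed by kwargs are
    indexed instead.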
+ """ + try: + norm_args_and_kwargs = node.normalized_arguments( + gm, normalize_to_only_use_kwargs=True + ) + if norm_args_and_kwargs is not None: + norm_args, norm_kwargs = norm_args_and_kwargs + assert len(norm_args) + len(norm_kwargs) > idx + if idx < len(norm_args): + return norm_args[idx] + else: + # note: in Python 3.7+ dicts are ordered + return list(norm_kwargs.values())[idx] + else: + assert len(node.args) + len(node.kwargs) > idx + if idx < len(node.args): + return node.args[idx] # type: ignore[return-value] + else: + kwargs_idx = idx + len(node.args) + return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value] + except RuntimeError: + # this RuntimeError happens when node argument normalization + # requires typehints to proceed, such as for torch.add where + # either the first, second or both arguments could be tensors + assert len(node.args) + len(node.kwargs) > idx + if idx < len(node.args): + return node.args[idx] # type: ignore[return-value] + else: + kwargs_idx = idx + len(node.args) + return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value] diff --git a/phivenv/Lib/site-packages/torch/ao/ns/fx/weight_utils.py b/phivenv/Lib/site-packages/torch/ao/ns/fx/weight_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..46af82a6608d75a6a46b62218d833daefebc24a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/ns/fx/weight_utils.py @@ -0,0 +1,284 @@ +from typing import Callable, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.nn as nn +import torch.nn.functional as F +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .ns_types import NSSingleResultType, NSSingleResultValuesType +from .utils import get_target_type_str, getattr_from_fqn, return_first_non_observer_node + + +toq = torch.ops.quantized + + +def mod_weight_detach(mod: nn.Module) -> torch.Tensor: + return mod.weight.detach() # type: ignore[operator] + + +def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor: + return mod[0].weight.detach() # type: ignore[index] + + +def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor: + return mod._weight_bias()[0] # type: ignore[operator] + + +def get_lstm_weight(mod: nn.Module) -> list[torch.Tensor]: + res = [] + for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type] + if "weight_ih_l" in param_name or "weight_hh_l" in param_name: + param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr] + res.append(param_value) + return res + + +def get_qlstm_weight(mod: nn.Module) -> list[torch.Tensor]: + res = [] + for weight_value in mod._all_weight_values: # type: ignore[union-attr] + res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]) + res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]) + return res + + +def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor: + if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + return mod.weight.detach() + elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)): + return mod[0].weight.detach() # type: ignore[operator] + else: + return mod._weight_bias()[0] # type: ignore[operator] + + +def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor: + if isinstance(mod, nn.Linear): + return mod.weight.detach() + elif 
isinstance(mod, nni.LinearReLU): + return mod[0].weight.detach() # type: ignore[operator] + else: + return mod._weight_bias()[0] # type: ignore[operator] + + +def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]: + # TODO(future PR): make more generic, handle everything + if isinstance(mod, nn.LSTM): + res = [] + for idx, param_name in enumerate(mod._flat_weights_names): + if "weight_ih_l" in param_name or "weight_hh_l" in param_name: + param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr] + res.append(param_value) + return res + else: + assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet" + res = [] + for weight_value in mod._all_weight_values: + res.append( + weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0] # type: ignore[index] + ) + res.append( + weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0] # type: ignore[index] + ) + return res + + +def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # traverse backwards from the weight arg, accounting for any observers + weight_arg_node = node.args[1] + assert isinstance(weight_arg_node, Node) + weight_node = return_first_non_observer_node(weight_arg_node, gm) + assert isinstance(weight_node, Node) + assert weight_node.op == "get_attr" + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + return weight.detach() + + +def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # qconv state is arg 1 + qconv_state_node = node.args[1] + assert isinstance(qconv_state_node, Node) + assert qconv_state_node.op == "get_attr" + qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type] + return qconv_state_obj.weight() + + +def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # traverse backwards from the weight arg, accounting for any observers + # supported patterns: + # weight -> obs -> linear + # weight -> to(torch.float16) -> dequantize -> linear + linear_second_arg = node.args[1] + assert isinstance(linear_second_arg, Node) + + if linear_second_arg.op == "call_module": + # weight -> obs -> linear + weight_arg_node = node.args[1] + assert isinstance(weight_arg_node, Node) + weight_node = weight_arg_node.args[0] + assert isinstance(weight_node, Node) + assert weight_node.op == "get_attr" + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + return weight.detach() + elif linear_second_arg.op == "call_method": + # weight -> to(torch.float16) -> dequantize -> linear + assert linear_second_arg.op == "call_method" + dequant_node = node.args[1] + assert isinstance(dequant_node, Node) + to_fp16_node = dequant_node.args[0] + assert isinstance(to_fp16_node, Node) + # extract the dtype, so we can cast to it before returning + target_dtype = to_fp16_node.args[1] + weight_node = to_fp16_node.args[0] + assert isinstance(weight_node, Node) + assert weight_node.op == "get_attr" + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + # return the weight with fp16 cast + return weight.detach().to(target_dtype) + else: + assert linear_second_arg.op == "get_attr" + weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type] + return weight.detach() + + +def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # packed weight is arg 1 + packed_weight_node = node.args[1] + assert isinstance(packed_weight_node, Node) + assert packed_weight_node.op == "get_attr" + packed_weight = getattr_from_fqn(gm, 
packed_weight_node.target) # type: ignore[arg-type] + # TODO(future PR): why does packed_weight.unpack() not work? + (weight, _bias), _name = packed_weight.__getstate__() + return weight + + +def get_op_to_type_to_weight_extraction_fn() -> dict[str, dict[Callable, Callable]]: + op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] = { + "call_module": { + # Conv1d + nn.Conv1d: mod_weight_detach, + nni.ConvReLU1d: mod_0_weight_detach, + nnq.Conv1d: mod_weight_bias_0, + nnqat.Conv1d: mod_weight_detach, + nniqat.ConvBn1d: mod_weight_detach, + nniqat.ConvBnReLU1d: mod_weight_detach, + nniqat.ConvReLU1d: mod_weight_detach, + nniq.ConvReLU1d: mod_weight_bias_0, + # Conv2d + nn.Conv2d: mod_weight_detach, + nni.ConvReLU2d: mod_0_weight_detach, + nnq.Conv2d: mod_weight_bias_0, + nnqat.Conv2d: mod_weight_detach, + nniqat.ConvBn2d: mod_weight_detach, + nniqat.ConvBnReLU2d: mod_weight_detach, + nniqat.ConvReLU2d: mod_weight_detach, + nniq.ConvReLU2d: mod_weight_bias_0, + # Conv3d + nn.Conv3d: mod_weight_detach, + nni.ConvReLU3d: mod_0_weight_detach, + nnq.Conv3d: mod_weight_bias_0, + nnqat.Conv3d: mod_weight_detach, + nniqat.ConvBn3d: mod_weight_detach, + nniqat.ConvBnReLU3d: mod_weight_detach, + nniqat.ConvReLU3d: mod_weight_detach, + nniq.ConvReLU3d: mod_weight_bias_0, + # Linear + nn.Linear: mod_weight_detach, + nnq.Linear: mod_weight_bias_0, + nni.LinearReLU: mod_0_weight_detach, + nniq.LinearReLU: mod_weight_bias_0, + nnqat.Linear: mod_weight_detach, + nnqd.Linear: mod_weight_bias_0, + nniqat.LinearReLU: mod_weight_detach, + nniqat.LinearBn1d: mod_weight_detach, + nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach, + # LSTM + nn.LSTM: get_lstm_weight, + nnqd.LSTM: get_qlstm_weight, + }, + "call_function": { + # Conv + F.conv1d: get_conv_fun_weight, + F.conv2d: get_conv_fun_weight, + F.conv3d: get_conv_fun_weight, + toq.conv1d: get_qconv_fun_weight, + toq.conv2d: get_qconv_fun_weight, + toq.conv3d: get_qconv_fun_weight, + toq.conv1d_relu: get_qconv_fun_weight, + toq.conv2d_relu: get_qconv_fun_weight, + toq.conv3d_relu: get_qconv_fun_weight, + # Linear + F.linear: get_linear_fun_weight, + toq.linear: get_qlinear_fun_weight, + toq.linear_relu: get_qlinear_fun_weight, + }, + } + + return op_to_type_to_weight_extraction_fn + + +def extract_weight_from_node( + node: Node, + gm: GraphModule, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> Optional[NSSingleResultType]: + res_type = NSSingleResultValuesType.WEIGHT.value + + # Not all graphmodules have _node_name_to_scope, so only fill it + # out if it exists. 
+ fqn = None + if hasattr(gm, "_node_name_to_scope"): + fqn = gm._node_name_to_scope[node.name][0] # type: ignore[index] + + if op_to_type_to_weight_extraction_fn is None: + op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn() + + ref_node_type = get_target_type_str(node, gm) + # for extracting weights, these are always the same + prev_node_type = ref_node_type + + if node.op == "call_function": + function_mapping = op_to_type_to_weight_extraction_fn["call_function"] + for target_fn_type, weight_extraction_fn in function_mapping.items(): + if node.target == target_fn_type: + weight = weight_extraction_fn(node, gm) + return { + "type": res_type, + "values": [weight], + "prev_node_name": node.name, + "prev_node_target_type": prev_node_type, + "ref_node_name": node.name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + } + + elif node.op == "call_module": + # for call_module, we need to look up the modules to do the type check + assert isinstance(node.target, str) + mod = getattr_from_fqn(gm, node.target) + module_mapping = op_to_type_to_weight_extraction_fn["call_module"] + for target_mod_type, weight_extraction_fn in module_mapping.items(): + if type(mod) == target_mod_type: + weight = weight_extraction_fn(mod) + return { + "type": res_type, + "values": [weight], + "prev_node_name": node.name, + "prev_node_target_type": prev_node_type, + "ref_node_name": node.name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + } + + return None diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f790c8892355e6396d558da2410ffbabaf0925e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/__init__.py @@ -0,0 +1,23 @@ +# Variables +from ._mappings import ( + get_dynamic_sparse_quantized_mapping, + get_static_sparse_quantized_mapping, +) + +# Scheduler +from .scheduler.base_scheduler import BaseScheduler +from .scheduler.cubic_scheduler import CubicSL +from .scheduler.lambda_scheduler import LambdaSL + +# Sparsifier +from .sparsifier.base_sparsifier import BaseSparsifier +from .sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier + +# Parametrizations +from .sparsifier.utils import ( + FakeSparsity, + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) +from .sparsifier.weight_norm_sparsifier import WeightNormSparsifier diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33432b71ad1fe74b7bb0dc3866e899754a13f3c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0db85e4b43a55d71f161cbc0b9f4c4a918705538 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a80d95a4bdbd8446b70a637dc4b5220abaa3bbf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db8321a90a224420a02d6193a53450261c6c2ee2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0ceb43e789f3b781a876eea7c96b577f02a26a2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..90dfee0cefecaac14783127ae1767ee7aad9e785 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -0,0 +1,476 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import defaultdict +from typing import Any, Optional + +import torch +from torch import nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +__all__ = ["ActivationSparsifier"] + + +class ActivationSparsifier: + r""" + The Activation sparsifier class aims to sparsify/prune activations in a neural + network. The idea is to attach the sparsifier to a layer (or layers) and it + zeroes out the activations based on the mask_fn (or sparsification function) + input by the user. + The mask_fn is applied once all the inputs are aggregated and reduced i.e. + mask = mask_fn(reduce_fn(aggregate_fn(activations))) + + Note:: + The sparsification mask is computed on the input **before it goes through the attached layer**. + + Args: + model (nn.Module): + The model whose layers will be sparsified. The layers that needs to be + sparsified should be added separately using the register_layer() function + aggregate_fn (Optional, Callable): + default aggregate_fn that is used if not specified while registering the layer. + specifies how inputs should be aggregated over time. 
+ The aggregate_fn should usually take two torch tensors and return the aggregated tensor. + Example + def add_agg_fn(tensor1, tensor2): return tensor1 + tensor2 + reduce_fn (Optional, Callable): + default reduce_fn that is used if not specified while registering the layer. + reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after + calling agg_fn() on all inputs. + Example + def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) + mask_fn (Optional, Callable): + default mask_fn that is used to create the sparsification mask using the tensor obtained after + calling the reduce_fn(). This is used by default if a custom one is not passed to + register_layer(). + Note that the mask_fn() definition should accept the sparse arguments that are passed in via the + sparse_config arguments. + features (Optional, list): + default selected features to sparsify. + If this is non-empty, then the mask_fn will be applied for each feature of the input. + For example, + mask = [mask_fn(reduce_fn(aggregate_fn(input[feature]))) for feature in features] + feature_dim (Optional, int): + default dimension of input features. Again, features along this dim will be chosen + for sparsification. + sparse_config (Dict): + Default configuration for the mask_fn. This config will be passed + to the mask_fn() + + Example: + >>> # xdoctest: +SKIP + >>> model = SomeModel() + >>> act_sparsifier = ActivationSparsifier(...) # init activation sparsifier + >>> # Initialize aggregate_fn + >>> def agg_fn(x, y): + >>> return x + y + >>> + >>> # Initialize reduce_fn + >>> def reduce_fn(x): + >>> return torch.mean(x, dim=0) + >>> + >>> # Initialize mask_fn + >>> def mask_fn(data): + >>> return torch.eye(*data.shape).to(data.device) + >>> + >>> + >>> act_sparsifier.register_layer( + ... model.some_layer, + ... aggregate_fn=agg_fn, + ... reduce_fn=reduce_fn, + ... mask_fn=mask_fn, + ...
) + >>> + >>> # start training process + >>> for _ in [...]: + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends + >>> act_sparsifier.step() + >>> # end training process + >>> sparsifier.squash_mask() + """ + + def __init__( + self, + model: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + self.model = model + self.defaults: dict[str, Any] = defaultdict() + self.defaults["sparse_config"] = sparse_config + + # functions + self.defaults["aggregate_fn"] = aggregate_fn + self.defaults["reduce_fn"] = reduce_fn + self.defaults["mask_fn"] = mask_fn + + # default feature and feature_dim + self.defaults["features"] = features + self.defaults["feature_dim"] = feature_dim + + self.data_groups: dict[str, dict] = defaultdict( + dict + ) # contains all relevant info w.r.t each registered layer + + self.state: dict[str, Any] = defaultdict(dict) # layer name -> mask + + @staticmethod + def _safe_rail_checks(args): + """Makes sure that some of the functions and attributes are not passed incorrectly""" + + # if features are not None, then feature_dim must not be None + features, feature_dim = args["features"], args["feature_dim"] + if features is not None: + assert feature_dim is not None, "need feature dim to select features" + + # all the *_fns should be callable + fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"] + for key in fn_keys: + fn = args[key] + assert callable(fn), "function should be callable" + + def _aggregate_hook(self, name): + """Returns hook that computes aggregate of activations passing through.""" + + # gather some data + feature_dim = self.data_groups[name]["feature_dim"] + features = self.data_groups[name]["features"] + agg_fn = self.data_groups[name]["aggregate_fn"] + + def hook(module, input) -> None: + input_data = input[0] + + data = self.data_groups[name].get("data") # aggregated data + if features is None: + # no features associated, data should not be a list + if data is None: + data = torch.zeros_like(input_data) + self.state[name]["mask"] = torch.ones_like(input_data) + out_data = agg_fn(data, input_data) + else: + # data should be a list [aggregated over each feature only] + if data is None: + out_data = [ + 0 for _ in range(0, len(features)) + ] # create one incase of 1st forward + self.state[name]["mask"] = [0 for _ in range(0, len(features))] + else: + out_data = data # a list + + # compute aggregate over each feature + for feature_idx in range(len(features)): + # each feature is either a list or scalar, convert it to torch tensor + feature_tensor = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + data_feature = torch.index_select( + input_data, feature_dim, feature_tensor + ) + if data is None: + curr_data = torch.zeros_like(data_feature) + self.state[name]["mask"][feature_idx] = torch.ones_like( + data_feature + ) + else: + curr_data = data[feature_idx] + out_data[feature_idx] = agg_fn(curr_data, data_feature) + self.data_groups[name]["data"] = out_data + + return hook + + def register_layer( + self, + layer: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + r""" + Registers a layer for sparsification. The layer should be part of self.model. + Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn + and store the aggregated activations that is input over each step. 
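+ As an illustrative sketch of that flow (the functions, the layer handle `model.some_layer` and the `threshold` value below are hypothetical, not part of the API), a registration could look like:: + + def agg_fn(a, b): # running sum of the activations seen so far + return a + b + + def reduce_fn(agg): # mean over the batch dimension + return agg.mean(dim=0) + + def mask_fn(reduced, threshold): # keep activations above a threshold + return (reduced.abs() > threshold).float() + + act_sparsifier.register_layer( + model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, + mask_fn=mask_fn, threshold=0.1, # threshold reaches mask_fn via sparse_config + )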
+ + Note:: + - There is no need to pass in the name of the layer as it is automatically computed as per + the fqn convention. + + - All the functions (fn) passed as argument will be called at a dim, feature level. + """ + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the model" # satisfy mypy + + if name in self.data_groups: # unregister layer if already present + warnings.warn( + "layer already attached to the sparsifier, deregistering the layer and registering with new config" + ) + self.unregister_layer(name=name) + + local_args = copy.deepcopy(self.defaults) + update_dict = { + "aggregate_fn": aggregate_fn, + "reduce_fn": reduce_fn, + "mask_fn": mask_fn, + "features": features, + "feature_dim": feature_dim, + "layer": layer, + } + local_args.update( + (arg, val) for arg, val in update_dict.items() if val is not None + ) + local_args["sparse_config"].update(sparse_config) + + self._safe_rail_checks(local_args) + + self.data_groups[name] = local_args + agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) + + self.state[name]["mask"] = ( + None # mask will be created when model forward is called. + ) + + # attach agg hook + self.data_groups[name]["hook"] = agg_hook + + # for serialization purposes, we know whether aggregate_hook is attached + # or sparsify_hook() + self.data_groups[name]["hook_state"] = "aggregate" # aggregate hook is attached + + def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None): + """ + Returns mask associated to the layer. + + The mask is + - a torch tensor is features for that layer is None. + - a list of torch tensors for each feature, otherwise + + Note:: + The shape of the mask is unknown until model.forward() is applied. + Hence, if get_mask() is called before model.forward(), an + error will be raised. + """ + assert name is not None or layer is not None, ( + "Need at least name or layer obj to retrieve mask" + ) + + if name is None: + assert layer is not None + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the specified model" + + if name not in self.state: + raise ValueError("Error: layer with the given name not found") + + mask = self.state[name].get("mask", None) + + if mask is None: + raise ValueError( + "Error: shape unknown, call layer() routine at least once to infer mask" + ) + return mask + + def unregister_layer(self, name): + """Detaches the sparsifier from the layer""" + + # detach any hooks attached + self.data_groups[name]["hook"].remove() + + # pop from the state dict + self.state.pop(name) + + # pop from the data groups + self.data_groups.pop(name) + + def step(self): + """Internally calls the update_mask() function for each layer""" + with torch.no_grad(): + for name, configs in self.data_groups.items(): + data = configs["data"] + self.update_mask(name, data, configs) + + self.data_groups[name].pop("data") # reset the accumulated data + + def update_mask(self, name, data, configs): + """ + Called for each registered layer and does the following- + 1. apply reduce_fn on the aggregated activations + 2. 
use mask_fn to compute the sparsification mask + + Note: + the reduce_fn and mask_fn is called for each feature, dim over the data + """ + mask = self.get_mask(name) + sparse_config = configs["sparse_config"] + features = configs["features"] + reduce_fn = configs["reduce_fn"] + mask_fn = configs["mask_fn"] + if features is None: + data = reduce_fn(data) + mask.data = mask_fn(data, **sparse_config) + else: + for feature_idx in range(len(features)): + data_feature = reduce_fn(data[feature_idx]) + mask[feature_idx].data = mask_fn(data_feature, **sparse_config) + + def _sparsify_hook(self, name): + """Returns hook that applies sparsification mask to input entering the attached layer""" + mask = self.get_mask(name) + features = self.data_groups[name]["features"] + feature_dim = self.data_groups[name]["feature_dim"] + + def hook(module, input): + input_data = input[0] + if features is None: + # apply to all the features + return input_data * mask + else: + # apply per feature, feature_dim + for feature_idx in range(0, len(features)): + feature = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + sparsified = ( + torch.index_select(input_data, feature_dim, feature) + * mask[feature_idx] + ) + input_data.index_copy_(feature_dim, feature, sparsified) + return input_data + + return hook + + def squash_mask(self, attach_sparsify_hook=True, **kwargs): + """ + Unregisters aggregate hook that was applied earlier and registers sparsification hooks if + attach_sparsify_hook = True. + """ + for name, configs in self.data_groups.items(): + # unhook agg hook + configs["hook"].remove() + configs.pop("hook") + self.data_groups[name]["hook_state"] = "None" + if attach_sparsify_hook: + configs["hook"] = configs["layer"].register_forward_pre_hook( + self._sparsify_hook(name) + ) + configs["hook_state"] = ( + "sparsify" # signals that sparsify hook is now attached + ) + + def _get_serializable_data_groups(self): + """Exclude hook and layer from the config keys before serializing + + TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. + For time-being, functions are treated the same way as other attributes + """ + data_groups: dict[str, Any] = defaultdict() + for name, config in self.data_groups.items(): + new_config = { + key: value + for key, value in config.items() + if key not in ["hook", "layer"] + } + data_groups[name] = new_config + return data_groups + + def _convert_mask(self, states_dict, sparse_coo=True): + r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. + If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor + """ + states = copy.deepcopy(states_dict) + for state in states.values(): + if state["mask"] is not None: + if isinstance(state["mask"], list): + for idx in range(len(state["mask"])): + if sparse_coo: + state["mask"][idx] = state["mask"][idx].to_sparse_coo() + else: + state["mask"][idx] = state["mask"][idx].to_dense() + else: + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + return states + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the sparsifier as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. 
+ * data_groups - a dictionary containing all config information for each + layer + * defaults - the default config while creating the constructor + """ + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return {"state": state, "data_groups": data_groups, "defaults": self.defaults} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + """ + state = state_dict["state"] + data_groups, defaults = state_dict["data_groups"], state_dict["defaults"] + + self.__set_state__( + {"state": state, "data_groups": data_groups, "defaults": defaults} + ) + + def __get_state__(self) -> dict[str, Any]: + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": data_groups, + } + + def __set_state__(self, state: dict[str, Any]) -> None: + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert mask to dense tensor + self.__dict__.update(state) + + # need to attach layer and hook info into the data_groups + for name, config in self.data_groups.items(): + # fetch layer + layer = fqn_to_module(self.model, name) + assert layer is not None # satisfy mypy + + # if agg_mode is True, then layer in aggregate mode + if "hook_state" in config and config["hook_state"] == "aggregate": + hook = layer.register_forward_pre_hook(self._aggregate_hook(name)) + + elif "hook_state" in config and config["hook_state"] == "sparsify": + hook = layer.register_forward_pre_hook(self._sparsify_hook(name)) + + config["layer"] = layer + config["hook"] = hook # type: ignore[possibly-undefined] + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for name, config in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(config.keys()): + if key in ["data", "hook", "reduce_fn", "mask_fn", "aggregate_fn"]: + continue + format_string += f"\t {key}: {config[key]}\n" + format_string += ")" + return format_string diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cca2961f5ffa95ec6e610ddcc71e8d6b43a4f4e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py @@ -0,0 +1,6 @@ +from .base_data_scheduler import BaseDataScheduler + + +__all__ = [ + "BaseDataScheduler", +] diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20ca7a8ef02b8190ef61e7ee7fb186269c84d1c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..7aabe9e40d001180eb0cd8b9a7e17064ec4a85dc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3e331862d96e0ca7eeac1fd3f342923b149c4a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -0,0 +1,197 @@ +# mypy: allow-untyped-defs +import abc +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier + + +__all__ = ["BaseDataScheduler"] + + +class BaseDataScheduler: + r""" + The BaseDataScheduler is the abstract scheduler class specifically for the + BaseDataSparsifier class. This class controls a specific hyperparameter of + the sparsifier class and varies it across the training process (or across time). + + Args: + data_sparsifier (instance of BaseDataSparsifier) + A concrete data sparsifier class, i.e. one in which update_mask() is implemented + schedule_param (str) + A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied + last_epoch (int, default=-1) + This is passed when training needs to be resumed from a particular + point. + verbose (bool, default=False) + Verbosity of the BaseDataScheduler + + The *get_schedule_param()* function needs to be implemented by the user. + """ + + def __init__( + self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False + ): + # Attach sparsifier + if not isinstance(data_sparsifier, BaseDataSparsifier): + raise TypeError( + f"{type(data_sparsifier).__name__} is not an instance of torch.ao.pruning.BaseDataSparsifier" + ) + self.data_sparsifier = data_sparsifier + self.schedule_param = schedule_param + + # Initialize epoch and base hyper-params + self.base_param = { + name: config.get(schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist.
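+ # The `_with_counter` flag set below serves two purposes: with_counter() skips + # re-wrapping a step() that already carries it, and the scheduler's step() checks + # for it to warn when scheduler.step() runs before data_sparsifier.step().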
+ wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.data_sparsifier.step = with_counter(self.data_sparsifier.step) # type: ignore[assignment] + self.data_sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sp_called_within_step: bool = False # sp -> schedule parameter + self.step() + + @abc.abstractmethod + def get_schedule_param(self): + r""" + Abstract method that needs to be implemented by the child class. + The expected return type should is a dictionary of name to schedule_param value + The returned values will be updated in sparsifier when the scheduler step() function + is called. + + Example: + >>> def get_schedule_param(self): + ... new_param = {} + ... for name in self.sparsifier.data_groups.keys(): + ... new_param[name] = ( + ... self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + ... ) + ... return new_param + + When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] + would be halved + """ + raise NotImplementedError + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Data Sparsifier {self.data_sparsifier}\n" + format_string += f" {self.schedule_param}: {self.base_param}\n" + format_string += ")" + return format_string + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + + Note: + The scheduler class does not track the state of the data_sparsifier. + Make sure to store the state of the sparsifier before storing the + state of the scheduler + """ + return { + key: value + for key, value in self.__dict__.items() + if key != "data_sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Note: + Remember to restore the state of the data_sparsifier before the scheduler. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_param(self): + return self._last_param + + def step(self): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.data_sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `data_sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.data_sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `data_sparsifier.step()`. 
" + "You have to make sure you run the data_sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + ) + self._step_count += 1 + + class _enable_get_sp_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sp_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sp_called_within_step = False + + with _enable_get_sp_call(self): + self.last_epoch += 1 + updated_scheduler_params = self.get_schedule_param() + + for name, param in updated_scheduler_params.items(): + self.data_sparsifier.data_groups[name][self.schedule_param] = param + if self.verbose: + print(f"Adjusting {self.schedule_param} for group {name} to {param}") + + self._last_param = { + name: config.get(self.schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + self.data_sparsifier.enable_mask_update = True diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc0f0853b78d7bb4276e7e3bd7e544497d63ecf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py @@ -0,0 +1,8 @@ +from .base_data_sparsifier import BaseDataSparsifier +from .data_norm_sparsifier import DataNormSparsifier + + +__all__ = [ + "BaseDataSparsifier", + "DataNormSparsifier", +] diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14af098bfbbd9012e8d7c1e96fe1b1b89e36a2cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6677a2ab5f090b597c6ee63e08e383b000f39713 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae39085b7d888df71a47902a52102255abb3a7f9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..838bfd417c65ce8082e6d7ad7c53af2d1edbd521 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc 
differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..f375df356d1ac49748d5865882a7a3a9f1eb3a04 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -0,0 +1,331 @@ +# mypy: allow-untyped-defs +import abc +import copy +import sys +import warnings +from collections import defaultdict +from typing import Any, Optional + +import torch +from torch import nn +from torch.ao.pruning.sparsifier import base_sparsifier, utils +from torch.nn.utils import parametrize + + +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. + warnings.simplefilter("once") + +__all__ = ["BaseDataSparsifier"] + +EMBEDDING_TYPES = { + nn.Embedding, + nn.EmbeddingBag, +} + +SUPPORTED_TYPES = { + torch.Tensor, + nn.Parameter, + *EMBEDDING_TYPES, +} + + +class _Container(nn.Module): + pass + + +class BaseDataSparsifier(base_sparsifier.BaseSparsifier): + r""" + Base Data Sparsifier class for all Data sparsifiers. + The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above) + to prepare for sparsification. + In this case, mask (and parametrizations) is owned by the class and not by the user. + Specifically, the container object inside the class maintains the mask and parametrizations of the input data + + Args: + data_list (list of tuples) + list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES + for type of data. Internally, a container module handles the data sparsification. + + defaults (dict) + default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + Example:: + >>> # xdoctest: +SKIP + >>> data_list = [('tensor_1', torch.randn(3,3)), ('tensor_2', torch.randn(4,4))] + >>> defaults = {'sparsity_level': 0.7} + >>> sparsifier = DerivedDataSparsifier(data_list = data_list, **defaults) # Some sparsifier that inherits BaseDataSparsifier + >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5,5), 'sparsity_level': 0.3} + >>> sparsifier.add_data(**new_tensor_to_add) + >>> # tensor_1 and tensor_2 will have sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3 + """ + + def __init__(self, data_list: Optional[list[tuple[str, Any]]] = None, **defaults): + super().__init__(defaults=defaults) + + self._container = _Container() + + self.data_groups: dict[str, dict] = defaultdict(dict) # name -> {**config} + if data_list is not None: + # add data with default config here + [self.add_data(name, data, **self.defaults) for name, data in data_list] + + def prepare(self, model, config): + raise NotImplementedError("this function is undefined for this class") + + def _extract_weight(self, data): + # extract the weight parameter instead of underlying data + if type(data) in [torch.Tensor, nn.Parameter]: + return data + elif type(data) in EMBEDDING_TYPES: + return data.weight + + def add_data(self, name: str, data, reuse_mask=True, **config): + r"""Configures and parametrizes the internal container model with name and data. + + **Note**: + 1. If the data with name already exists, it replaces the data. + 2. While replacing, the old mask is reused when `reuse_mask=True` + 3. 
If `reuse_mask=True`, then the replacing data needs to have the same shape as that of old data. + 4. By default, the config of the replaced data is used as config for the replacing data, unless something + is specified in the config dictionary. + """ + assert type(data) in SUPPORTED_TYPES, ( + "specified data type not supported at the moment" + ) + local_args = copy.deepcopy(self.defaults) + local_args.update(config) + weight = self._extract_weight(data) + + # Bookkeeping in the container class + mask = local_args.get("mask", torch.ones_like(weight)) + param_class = local_args.get("parametrization", utils.FakeSparsity) + + if name in self.state: + # If the named data already exists - replace + warnings.warn( + "Replacing existing data of the same name. - Did you mean a different name?" + ) + + # reuse old config + old_args = self.data_groups[name] + local_args = copy.deepcopy(old_args) + local_args.update(config) + + if reuse_mask: + current_data = self.get_data(name=name) + assert weight.shape == current_data.shape, ( + "to retain the old mask, the shape of the new data must be the same as the previous one" + ) + mask = self.get_mask( + name=name + ) # reuse mask instead of creating a new one + + self._delete_data(name=name) + + # parameter creates a deepcopy of the weight inside, so create a buffer + self._container.register_buffer(name=name, tensor=weight) + parametrize.register_parametrization(self._container, name, param_class(mask)) + self.state[name]["mask"] = mask + self.data_groups[name] = local_args + return getattr(self._container, name) + + def get_data(self, name: str, return_original: bool = True): + r"""Returns weight tensor (or data) + Args: + - name: name of the data to be returned + - return_original returns weight tensor without applying parametrization if True + else - returns the sparsified version (parametrized) + """ + if name not in self.data_groups: + raise ValueError("data with specified name does not exist") + + if return_original: + if not parametrize.is_parametrized(self._container, name): + raise ValueError("mask squashed - original mask value does not exist") + data = getattr(self._container.parametrizations, name).original + return data + else: + return getattr(self._container, name) + + def _convert_mask(self, states, sparse_coo=True): + r"""Converts the mask to sparse coo or dense tensors depending on the `sparse_coo` argument.""" + states = copy.deepcopy(states) + for state in states.values(): + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + + return states + + def state_dict(self): + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a list containing all sparsity configuration groups + with the key name specifying the name of the data + * container_state_dict - the state dictionary of the internal + container model used for sparsification + """ + state = self._convert_mask(self.state) + return { + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def _load_container_from_state(self, states, data_groups, container_state_dict): + r"""This restores the state of the container specifically based on the data present in state and data_groups + If the data was parametrized, then the data would be added to the container and then parametrized, + else it would just add the attribute the container. 
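+ + For reference, a minimal sketch of the two layouts this helper expects (the key pattern + follows torch.nn.utils.parametrize; the names and shapes here are purely illustrative):: + + container_state_dict = { + "tensor_1": torch.randn(3, 3), # mask was squashed, stored as a plain buffer + "parametrizations.tensor_2.original": torch.randn(4, 4), # still parametrized + }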
+ """ + for name, state in states.items(): + config_name = data_groups.get(name, None) + if config_name is None: + raise RuntimeError(f"Error loading {name}") + + # check if the data with such a name was parametrized, if so parametrize + # otherwise just set the attribute and continue + parametrized_name = f"parametrizations.{name}.original" + parametrized = False + data = container_state_dict.get(name, None) + if name in container_state_dict: + # the parametrization was probably removed for this + data = container_state_dict.get(name) + + elif parametrized_name in container_state_dict: + # so the weight was parametrized + data = container_state_dict.get(parametrized_name) + parametrized = True + + else: + raise RuntimeError(f"Error loading {name}") + + self._container.register_buffer(name=name, tensor=data) + + if parametrized: + # register parameter if parametrized + mask = state.get("mask", torch.ones_like(data)) + param_class = data_groups.get( + "parametrization", utils.FakeSparsity + ) # change once public_api for utils is fixed! + parametrize.register_parametrization( + self._container, name, param_class(mask) + ) + + def load_state_dict(self, state_dict, strict=True): + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + * strict - If True - the sparsifier is reset and is restored exactly to the state in state_dict. + If False - the current sparsifier is not reset before loading the state_dict i.e. data added + before loading the state_dict is not erased. + """ + states = copy.deepcopy(state_dict["state"]) + data_groups = copy.deepcopy(state_dict["data_groups"]) + container_state_dict = copy.deepcopy(state_dict["_container"]) + + states = self._convert_mask( + states, sparse_coo=False + ) # convert sparse coo mask to dense + if strict: + # if strict load -> then reset container + self._container = _Container() + + self._load_container_from_state(states, data_groups, container_state_dict) + + if not strict: + states.update(self.state) + data_groups.update(self.data_groups) + + self.__setstate__({"state": states, "data_groups": data_groups}) + + def __setstate__(self, state): + if "_container" in state: # If container object is in state then load model + container_dict = state.pop("_container") + self._container = _Container() + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert sparse coo mask to dense + self._load_container_from_state( + state["state"], state["data_groups"], container_dict + ) + + self.__dict__.update(state) + + def __getstate__(self): + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def __repr__(self): # type:ignore[override] + format_string = self.__class__.__name__ + " (" + for name, sparse_args in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(sparse_args.keys()): + if key == "data": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def get_mask(self, name: str): + if name not in self.state: + raise ValueError("data with specified name does not exist") + return self.state[name]["mask"] + + def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): + r"""Squashes the sparse 
masks into the appropriate tensors. Also, accepts list of strings + to squash mask for. If none, squashes mask for all the keys + kwargs: + * names: list of strings to squash mask for + * sparsified: if true - applies the mask before squashing + if false - does not apply the mask before squashing + """ + if names is None: + names = list(self.data_groups.keys()) + for name in names: + parametrize.remove_parametrizations( + self._container, name, leave_parametrized=leave_parametrized + ) + + def step(self): # type:ignore[override] + if not self.enable_mask_update: + return + with torch.no_grad(): + for name, config in self.data_groups.items(): + # get non-sparsified data + data = self.get_data(name) + # need name for the mask otherwise can directly pass mask? + self.update_mask(name, data, **config) + + @abc.abstractmethod + def update_mask(self, name, data, **kwargs): # type: ignore[override] + pass + + def _delete_data(self, name): + """Detaches some data from the sparsifier. + + Args: + name (str) + Name of the data to be removed from the sparsifier + + Note: + Currently private. Kind of used as a helper function when replacing data of the same name + """ + self.squash_mask( + names=[name], leave_parametrized=False + ) # do not apply the mask while deleting + delattr(self._container, name) + self.state.pop(name) + self.data_groups.pop(name) diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4c0c3c51bdd095d5c663df2b4b2dfed1d49894 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -0,0 +1,203 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import Any, Optional + +import torch +from torch.nn import functional as F + +from .base_data_sparsifier import BaseDataSparsifier + + +__all__ = ["DataNormSparsifier"] + + +class DataNormSparsifier(BaseDataSparsifier): + r"""L1-Norm Sparsifier + This sparsifier computes the *L1-norm* of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + Args: + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block + zeros_per_block: Number of zeros in a sparse block + Note:: + All arguments to the DataNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `add_data` step. 
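+ + Example (an illustrative sketch; the tensor name, shape and settings are arbitrary):: + >>> # xdoctest: +SKIP + >>> sparsifier = DataNormSparsifier(sparsity_level=0.8, sparse_block_shape=(1, 4), norm="L2") + >>> sparsifier.add_data(name="linear_weight", data=torch.randn(8, 16)) + >>> sparsifier.step() # compute and apply the block-wise mask + >>> sparsifier.squash_mask() # materialize the zeros into the tensor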
+ """ + + def __init__( + self, + data_list: Optional[list[tuple[str, Any]]] = None, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: Optional[int] = None, + norm: str = "L1", + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + + assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment" + + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + self.norm = norm + super().__init__(data_list=data_list, **defaults) + + def __get_scatter_folded_mask( + self, data, dim, indices, output_size, sparse_block_shape + ): + mask = torch.ones_like(data) + mask.scatter_(dim=dim, index=indices, value=0) # zeroing out + mask = F.fold( + mask, + output_size=output_size, + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + mask = mask.to(torch.int8) + return mask + + def __get_block_level_mask(self, data, sparse_block_shape, zeros_per_block): + # Assume data is a squeezed tensor + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + values_per_block = block_height * block_width + + # just return zeros if zeroing all elements in block + if values_per_block == zeros_per_block: + return torch.zeros_like(data, dtype=torch.int8) + + # creating additional height and width to support padding + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones( + height + dh, width + dw, dtype=data.dtype, device=data.device + ) + padded_data = ( + padded_data * torch.nan + ) # can also be replaced with 0 to stop the removal of edge data + padded_data[0:height, 0:width] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + _, sorted_idx = torch.sort(unfolded_data, dim=1) + sorted_idx = sorted_idx[ + :, :zeros_per_block, : + ] # zero out zeros_per_block number of elements + + mask = self.__get_scatter_folded_mask( + data=unfolded_data, + dim=1, + indices=sorted_idx, + output_size=padded_data.shape, + sparse_block_shape=sparse_block_shape, + ) + + mask = ( + mask.squeeze(0).squeeze(0)[:height, :width].contiguous() + ) # remove padding and make contiguous + return mask + + def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + data_norm = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + + values_per_block = reduce(operator.mul, sparse_block_shape) + + data_norm = data_norm.flatten() + num_blocks = len(data_norm) + + data_norm = data_norm.repeat( + 1, values_per_block, 1 + ) # get similar shape after unfold + _, sorted_idx = torch.sort(data_norm, dim=2) + + threshold_idx = round(sparsity_level * num_blocks) # number of blocks to remove + sorted_idx = sorted_idx[:, :, :threshold_idx] + + mask = self.__get_scatter_folded_mask( + data=data_norm, + dim=2, + indices=sorted_idx, + output_size=(height + dh, width + dw), + sparse_block_shape=sparse_block_shape, + ) + + mask = mask.squeeze(0).squeeze(0)[ + :height, :width + ] # squeeze only the 
first 2 dimension + return mask + + def update_mask( # type: ignore[override] + self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than " + "the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + if self.norm == "L1": + data_norm = torch.abs(data).squeeze() # absolute value based (L1) + else: + data_norm = (data * data).squeeze() # square every element for L2 + + if len(data_norm.shape) > 2: # only supports 2 dimensional data at the moment + raise ValueError("only supports 2-D at the moment") + + elif len(data_norm.shape) == 1: # in case the data is bias (or 1D) + data_norm = data_norm[None, :] + + mask = self.get_mask(name) + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + + # Fetch the high level mask that zeros out entire blocks + data_lvl_mask = self.__get_data_level_mask( + data=data_norm, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + + # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block + block_lvl_mask = self.__get_block_level_mask( + data=data_norm, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + + # zero out the entries inside those blocks whose block is sparsified + mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask) diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e923d74c20c15c07f93130c6467e1ad44170440 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3102061c768077ac8b6190de09c6339094c7592a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ae6a694064b686c81d0d9f2932330a5408ed254 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c266bc2f384b4c3de66ad574832d0a20f25ce432 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cacf3c813adccff91c6e2c0646d08665e03f18b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import logging + +from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import ( + SUPPORTED_TYPES, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None): + """Attaches a data sparsifier to all the layers of the module. + Essentially, loop over all the weight parameters in the module and + attach it to the data sparsifier. + Note:: + The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below) + before attaching to the sparsifier. This is because, the data + sparsifier uses a dummy model inside to store the weight parameters. + """ + if config is None: + config = {} + for name, parameter in module.named_parameters(): + if type(parameter) in SUPPORTED_TYPES: + valid_name = _get_valid_name(name) + # will be defaulted to default configs + data_sparsifier.add_data( + name=valid_name, data=parameter, **config.get(valid_name, {}) + ) + + +def _get_valid_name(name): + return name.replace(".", "_") # . 
is not allowed as a name + + +def _log_sparsified_level(model, data_sparsifier) -> None: + # Show the level of sparsity AFTER step: + for name, parameter in model.named_parameters(): + if type(parameter) not in SUPPORTED_TYPES: + continue + valid_name = _get_valid_name(name) + mask = data_sparsifier.get_mask(name=valid_name) + sparsity_level = 1.0 - mask.float().mean() + logger.info("Sparsity in layer %s = %.2f", name, sparsity_level) diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2c02f1066c362f687399d3311592114528fba2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from copy import deepcopy +from typing import Any, Optional, TYPE_CHECKING + +import pytorch_lightning as pl # type: ignore[import] + +from ._data_sparstity_utils import ( + _attach_model_to_data_sparsifier, + _get_valid_name, + _log_sparsified_level, +) + + +if TYPE_CHECKING: + import torch + + +class PostTrainingDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables post-training sparsity. + + This callback aims to sparsify the model inside the lightning module after training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + once the training completes. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + Hooks implemented: + on_fit_end() + 1. copies the model and attaches it to the sparsifier + 2. sparsifier step() is called + 3. squash_mask() is called + """ + + def __init__(self, data_sparsifier_class, data_sparsifier_args): + super().__init__() + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + self.data_sparsifier: Any = None + self.sparsified: Optional[torch.nn.Module] = None + + def on_fit_end(self, trainer, pl_module) -> None: + self.sparsified = deepcopy(pl_module.model).eval() + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier) + + self.data_sparsifier.step() + + self.data_sparsifier.squash_mask() # currently squashes params for all masks + + _log_sparsified_level(self.sparsified, self.data_sparsifier) + + +class TrainingAwareDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables in-training sparsity. + + This callback aims to sparsify the model inside the lightning module during training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts.
+ Note: Objects should not be passed in here as they are created + when the training starts. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + data_scheduler_class (some implemented class of BaseDataScheduler) + The data scheduler of this class is created when the training starts + Note: Objects should not be passed in here as they are created + when the training starts. + + data_scheduler_args(Dict) + Dictionary of args to be passed to the data scheduler. + **Note: data_sparsifier arg should be ignored as the recipe + creates and pass sparsifier object into the class** + + Hooks implemented: + on_train_start() + Data sparsifier and scheduler objects are created. + Pytorch model attached to the sparsifier + + on_train_epoch_start() + Loads the state_dict of the data sparsifier + + on_train_epoch_end() + 1. Copies the model and attaches it to the sparsifier + 2. sparsifier step() and scheduler step() + 3. Dump state_dict of the current sparsifier + + on_train_end() + squash mask + """ + + def __init__( + self, + data_sparsifier_class, + data_sparsifier_args, + data_scheduler_class, + data_scheduler_args, + ): + super().__init__() + # data sparsifier objects + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + + # scheduler objects + self.data_scheduler_class = data_scheduler_class + self.data_scheduler_args = data_scheduler_args + + # fields + self.data_sparsifier: Any = None + self.data_scheduler: Any = None + self.sparsified: Optional[torch.nn.Module] = None + + self.data_sparsifier_state_dict: Any = None + + def on_train_start(self, trainer, pl_module) -> None: + # create sparsifier + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + self.sparsified = deepcopy(pl_module.model) + + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier + ) # just to populate the base_sl in the scheduler + + # create scheduler + args = deepcopy(self.data_scheduler_args) + args["data_sparsifier"] = self.data_sparsifier + self.data_scheduler = self.data_scheduler_class(**args) + + def on_train_epoch_start(self, trainer, pl_module): + if self.data_sparsifier_state_dict is None: + return # probably first epoch + + # load the existing config for each data + self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict) + + def __create_config_based_on_state(self, pl_module): + config: dict = defaultdict() + if self.data_sparsifier_state_dict is None: + return config + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + config[valid_name] = self.data_sparsifier.data_groups[valid_name] + + return config + + def on_train_epoch_end(self, trainer, pl_module): + self.sparsified = deepcopy(pl_module.model) + config = self.__create_config_based_on_state(pl_module) + + # attach model to the data sparsifier + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier, config=config + ) + self.data_sparsifier.step() + self.data_scheduler.step() + + self.data_sparsifier_state_dict = self.data_sparsifier.state_dict() + + def on_train_end(self, trainer, pl_module): + self.data_sparsifier.squash_mask() diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..ec50e72f620a59e2b730d43dc351aaa254e4d6eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.nn as nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag} + + +def _fetch_all_embeddings(model): + """Fetches Embedding and EmbeddingBag modules from the model""" + embedding_modules = [] + stack = [model] + while stack: + module = stack.pop() + for _, child in module.named_children(): + fqn_name = module_to_fqn(model, child) + if type(child) in SUPPORTED_MODULES: + embedding_modules.append((fqn_name, child)) + else: + stack.append(child) + return embedding_modules + + +def post_training_sparse_quantize( + model, + data_sparsifier_class, + sparsify_first=True, + select_embeddings: Optional[list[nn.Module]] = None, + **sparse_config, +): + """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags. + The quantization step can happen before or after sparsification depending on the `sparsify_first` argument. + + Args: + - model (nn.Module) + model whose embeddings needs to be sparsified + - data_sparsifier_class (type of data sparsifier) + Type of sparsification that needs to be applied to model + - sparsify_first (bool) + if true, sparsifies first and then quantizes + otherwise, quantizes first and then sparsifies. + - select_embeddings (List of Embedding modules) + List of embedding modules to in the model to be sparsified & quantized. + If None, all embedding modules with be sparsified + - sparse_config (Dict) + config that will be passed to the constructor of data sparsifier object. + + Note: + 1. When `sparsify_first=False`, quantization occurs first followed by sparsification. + - before sparsifying, the embedding layers are dequantized. + - scales and zero-points are saved + - embedding layers are sparsified and `squash_mask` is applied + - embedding weights are requantized using the saved scales and zero-points + 2. When `sparsify_first=True`, sparsification occurs first followed by quantization. 
+ - embeddings are sparsified first + - quantization is applied on the sparsified embeddings + """ + data_sparsifier = data_sparsifier_class(**sparse_config) + + # if select_embeddings is None, perform it on all embeddings + if select_embeddings is None: + embedding_modules = _fetch_all_embeddings(model) + + else: + embedding_modules = [] + assert isinstance(select_embeddings, list), ( + "the embedding_modules must be a list of embedding modules" + ) + for emb in select_embeddings: + assert type(emb) in SUPPORTED_MODULES, ( + "the embedding_modules list must be an embedding or embedding bags" + ) + fqn_name = module_to_fqn(model, emb) + assert fqn_name is not None, ( + "the embedding modules must be part of input model" + ) + embedding_modules.append((fqn_name, emb)) + + if sparsify_first: + # sparsify + for name, emb_module in embedding_modules: + valid_name = name.replace(".", "_") + data_sparsifier.add_data(name=valid_name, data=emb_module) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + else: + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + # retrieve scale & zero_points + quantize_params: dict[str, dict] = { + "scales": {}, + "zero_points": {}, + "dequant_weights": {}, + "axis": {}, + "dtype": {}, + } + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + assert quantized_emb is not None # satisfy mypy + + quantized_weight = quantized_emb.weight() # type: ignore[operator] + quantize_params["scales"][name] = quantized_weight.q_per_channel_scales() + quantize_params["zero_points"][name] = ( + quantized_weight.q_per_channel_zero_points() + ) + quantize_params["dequant_weights"][name] = torch.dequantize( + quantized_weight + ) + quantize_params["axis"][name] = quantized_weight.q_per_channel_axis() + quantize_params["dtype"][name] = quantized_weight.dtype + + # attach data to sparsifier + data_sparsifier.add_data( + name=name.replace(".", "_"), + data=quantize_params["dequant_weights"][name], + ) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + assert quantized_emb is not None # satisfy mypy + requantized_vector = torch.quantize_per_channel( + quantize_params["dequant_weights"][name], + scales=quantize_params["scales"][name], + zero_points=quantize_params["zero_points"][name], + dtype=quantize_params["dtype"][name], + axis=quantize_params["axis"][name], + ) + + quantized_emb.set_weight(requantized_vector) # type: ignore[operator] diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..ef73d3451f0eb25ee774c0d073ca2d51194ae86f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +from typing import Callable, Optional, Union + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier + + +__all__ = ["FPGMPruner"] + + +class 
FPGMPruner(BaseStructuredSparsifier): + r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner + This sparsifier prune fliter (row) in a tensor according to distances among filters according to + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of filters (rows) that are zeroed-out. + 2. `dist` defines the distance measurement type. Default: 3 (L2 distance). + Available options are: [1, 2, (custom callable distance function)]. + + Note:: + Inputs should be a 4D convolutional tensor of shape (N, C, H, W). + - N: output channels size + - C: input channels size + - H: height of kernel + - W: width of kernel + """ + + def __init__( + self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None + ): + defaults = { + "sparsity_level": sparsity_level, + } + + if dist is None: + dist = 2 + + if callable(dist): + self.dist_fn = dist + elif dist == 1: + self.dist_fn = lambda x: torch.cdist(x, x, p=1) + elif dist == 2: + self.dist_fn = lambda x: torch.cdist(x, x, p=2) + else: + raise NotImplementedError("Distance function is not yet implemented.") + super().__init__(defaults=defaults) + + def _compute_distance(self, t): + r"""Compute distance across all entries in tensor `t` along all dimension + except for the one identified by dim. + Args: + t (torch.Tensor): tensor representing the parameter to prune + Returns: + distance (torch.Tensor): distance computed across filtters + """ + dim = 0 # prune filter (row) + + size = t.size(dim) + slc = [slice(None)] * t.dim() + + # flatten the tensor along the dimension + t_flatten = [ + t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) + for i in range(size) + ] + t_flatten = torch.stack(t_flatten) + + # distance measurement + dist_matrix = self.dist_fn(t_flatten) + + # more similar with other filter indicates large in the sum of row + distance = torch.sum(torch.abs(dist_matrix), 1) + + return distance + + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): + tensor_weight = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + if sparsity_level <= 0: + mask.data = torch.ones_like(mask).bool() + elif sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask).bool() + else: + distance = self._compute_distance(tensor_weight) + + tensor_size = tensor_weight.shape[0] # prune filter (row) + nparams_toprune = round(sparsity_level * tensor_size) + nparams_toprune = min( + max(nparams_toprune, 0), tensor_size + ) # clamp to [0, tensor_size] + topk = torch.topk(distance, k=nparams_toprune, largest=False) + mask[topk.indices] = False diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c67893adbc16ecc1d6f47e0e99298dd41dd9a43e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py @@ -0,0 +1,5 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier +from .FPGM_pruner import FPGMPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner +from .parametrization import BiasHook, FakeStructuredSparsity +from .saliency_pruner import SaliencyPruner diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc 
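A minimal usage sketch for FPGMPruner as exported by the pruner __init__ above. The prepare()/step()/prune() sequence and the `tensor_fqn` config format follow the BaseSparsifier and BaseStructuredSparsifier code appearing later in this diff; treat the exact call sequence as an assumption rather than a documented recipe.

import torch
import torch.nn as nn
from torch.ao.pruning._experimental.pruner import FPGMPruner

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
pruner = FPGMPruner(sparsity_level=0.5, dist=2)      # L2 distance between filters
pruner.prepare(model, [{"tensor_fqn": "0.weight"}])  # attach FakeStructuredSparsity to the first conv
pruner.step()                                        # mask the 4 filters closest to the geometric median
pruned = pruner.prune()                              # fx.GraphModule with conv 0 shrunk to 4 output channels
print(pruned(torch.randn(1, 3, 8, 8)).shape)         # expected: torch.Size([1, 4, 4, 4])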
b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4b18c0895a4fba84c5777b1cc5fa55307056f3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f06b6a0e5f7f7abf95fa1fd1c35d5c8ba2117d7d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..470c020419f3226d816ce6ed573fab61c235f368 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f1ca4b06ef0b93e8c5e440aceb2bda84590c792 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b9d6ee110dfb1321471f6f3564c3981caa692fc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d885ba8efb3b747ba01a0b8458e5e01d0656b617 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a2d178b7cdd6aed213f27a9df750dec6ec410f1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d9fcaa67e56d4f3bd947f90cd62412d34f11ff07 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfb93e312d582abc1c45c6b5e20dd0c0f0ea94d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -0,0 +1,312 @@ +# mypy: allow-untyped-defs +from itertools import chain +from operator import getitem +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize + +from .match_utils import apply_match, MatchAllNode +from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param +from .prune_functions import ( + prune_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, + prune_linear, + prune_linear_activation_linear, + prune_linear_linear, + prune_lstm_output_layernorm_linear, + prune_lstm_output_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + nn.LSTM, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + F.gelu, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + nn.GELU, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> dict[ + tuple[Union[type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: dict[ + tuple[Union[type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + # TODO LSTM Structured pruning does not support returned state currently. + # Should find a way to explicitly match getitem(0) instead of getitem. + # This will also require changing the pruning function. 
+ # lstm -> getitem(0) -> linear + (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, + # lstm -> getitem(0) -> layernorm -> linear + (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Optional[set[type]] = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. 
+ """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + # if linear / conv, we add in bias hooks + if isinstance(module, (nn.Linear, nn.Conv2d)): + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter( + "_bias", nn.Parameter(module.bias.detach()) + ) + module.bias = None + module.prune_bias = prune_bias + + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) # type: ignore[union-attr, index] + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). + + For each pattern, it will apply to corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has appropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and module_contains_param(first_module, FakeStructuredSparsity) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + for module in self.traced.modules(): + if module_contains_param(module, FakeStructuredSparsity): + raise Exception( # noqa: TRY002 + f"Error: {module} still contains FakeStructuredSparsity parametrizations!" + ) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced # type: ignore[return-value] diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..60ba386878a2e7930dadf611fbd60c60e2d6212d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs +from typing import cast + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity + + +class LSTMSaliencyPruner(BaseStructuredSparsifier): + """ + Prune packed LSTM weights based on saliency. + For each layer {k} inside a LSTM, we have two packed weight matrices + - weight_ih_l{k} + - weight_hh_l{k} + + These tensors pack the weights for the 4 linear layers together for efficiency. 
+ + [W_ii | W_if | W_ig | W_io] + + Pruning this tensor directly will lead to weights being misassigned when unpacked. + To ensure that each packed linear layer is pruned the same amount: + 1. We split the packed weight into the 4 constituent linear parts + 2. Update the mask for each individual piece using saliency individually + + This applies to both weight_ih_l{k} and weight_hh_l{k}. + """ + + def update_mask(self, module, tensor_name, **kwargs): + weights = getattr(module, tensor_name) + + for p in getattr(module.parametrizations, tensor_name): + if isinstance(p, FakeStructuredSparsity): + mask = cast(torch.Tensor, p.mask) + + # select weights based on magnitude + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + # take norm over all but first dim + dims = tuple(range(1, weights.dim())) + saliency = weights.norm(dim=dims, p=1) + + # handle weights in 4 groups + split_size = len(mask) // 4 + masks = torch.split(mask, split_size) + saliencies = torch.split(saliency, split_size) + + for keep_mask, sal in zip(masks, saliencies): + # mask smallest k values to be removed + k = int(len(keep_mask) * kwargs["sparsity_level"]) + prune = sal.topk(k, largest=False, sorted=False).indices + keep_mask.data[prune] = False # modifies underlying p.mask directly diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..824d0e9f780fb7cb0d3051f8fd99bcc24dfa4987 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py @@ -0,0 +1,65 @@ +""" +Contains utility functions to check if a pattern is in the graph and return the matching nodes +""" + +from typing import Any, Optional, Union + +import torch +from torch import nn +from torch.ao.quantization.utils import MatchAllNode +from torch.fx import Node +from torch.nn.utils import parametrize + + +def _match( + modules: dict[str, nn.ModuleDict], + node: Node, + current: Union[nn.Module, Any], +) -> bool: + r""" + checks to see if a single node of a pattern matches + """ + if isinstance(current, type) and issubclass(current, MatchAllNode): + return True + if not isinstance(node, Node): + return False + if isinstance(current, type) and issubclass(current, torch.nn.Module): + return ( + node.op == "call_module" + and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index] + == current + ) + elif callable(current): + return node.op == "call_function" and node.target is current + elif isinstance(current, str): + return node.target == current + return False + + +def apply_match( + modules: dict[str, nn.ModuleDict], + pattern: Union[tuple[Any], Any], + node: Node, + matched_node_pattern: list[Node], +) -> Optional[list[Node]]: + r""" + This function will return the matched nodes if the pattern matches the node given + If there is no match, it will return None + """ + if isinstance(pattern, tuple): + if len(pattern) == 1: + if _match(modules, node, pattern[0]): + return matched_node_pattern + [node] + + first, *rest = pattern + if _match(modules, node, first): + if rest is None: + return matched_node_pattern + [node] + + for user in node.users: + return apply_match( + modules, tuple(rest), user, matched_node_pattern + [node] + ) + elif _match(modules, node, pattern): + return [node] + return None diff --git 
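For illustration, a standalone sketch of the matching step that BaseStructuredSparsifier.prune() above performs with apply_match: trace a model with FX, then walk every graph node against a pattern tuple. Everything used here appears in this diff; the model and pattern are arbitrary examples.

import torch.nn as nn
from torch.fx import symbolic_trace
from torch.ao.pruning._experimental.pruner.match_utils import apply_match

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
traced = symbolic_trace(model)
modules = dict(traced.named_modules())

pattern = (nn.Linear, nn.ReLU, nn.Linear)
for node in traced.graph.nodes:
    matched = apply_match(modules, pattern, node, [])
    if matched is not None:
        print([n.target for n in matched])   # ['0', '1', '2'] for the sequential above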
a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py new file mode 100644 index 0000000000000000000000000000000000000000..d1748d4e89613be56ecff4dcec204d26ae62dcef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch +from torch import nn +from torch.nn.utils.parametrize import is_parametrized + + +def module_contains_param(module, parametrization): + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() + ) + return False + + +# Structured Pruning Parameterizations +class FakeStructuredSparsity(nn.Module): + r""" + Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to + the 'weight' or any other parameter that requires a mask. + + Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + assert isinstance(self.mask, torch.Tensor) + assert self.mask.shape[0] == x.shape[0] + shape = [1] * len(x.shape) + shape[0] = -1 + return self.mask.reshape(shape) * x + + def state_dict(self, *args, **kwargs): + # avoid double saving masks + return {} + + +class BiasHook: + def __init__(self, parametrization, prune_bias): + self.param = parametrization + self.prune_bias = prune_bias + + def __call__(self, module, input, output): + if getattr(module, "_bias", None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[~self.param.mask] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..ffdaec3c6a19a3d3f58bf77d5370eda773910861 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -0,0 +1,479 @@ +# mypy: allow-untyped-defs +""" +Collection of conversion functions for linear / conv2d structured pruning +Also contains utilities for bias propagation +""" + +from typing import Callable, cast, Optional + +import torch +from torch import nn, Tensor +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import ParametrizationList + +from .parametrization import BiasHook, FakeStructuredSparsity + + +# BIAS PROPAGATION +def _remove_bias_handles(module: nn.Module) -> None: + if hasattr(module, "_forward_hooks"): + bias_hooks: list[int] = [] + for key, hook in module._forward_hooks.items(): + if isinstance(hook, BiasHook): + bias_hooks.append(key) + + for key in bias_hooks: + del module._forward_hooks[key] + + +def _get_adjusted_next_layer_bias( + next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor +) -> nn.Parameter: + r"""Returns new adjusted bias for the second supported module""" + if parametrize.is_parametrized(next_layer): + # need to access original weight + parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations) + weight_parameterizations 
= cast( + ParametrizationList, parametrization_dict.weight + ) + next_weight = weight_parameterizations.original + else: + next_weight = cast(Tensor, next_layer.weight) + + scaling_weight = next_weight[:, ~mask] + if isinstance(next_layer, nn.Conv2d): # checking for Conv2d + # Propagating first layer pruned biases and calculating the new second layer bias + # involves more steps since the Conv2d scaling weight has extra dimensions, + # so adding bias involves broadcasting, logically: + # for each channel k in range(oC): + # scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T) + # new_next_bias[k] = old_next_bias[k] + scaled_biases + scaling_product = torch.matmul( + pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2) + ) + sum_range = list(range(len(scaling_product.shape)))[ + 1: + ] # all but the first dimension + scaled_biases = torch.sum(scaling_product, sum_range) + elif isinstance(next_layer, nn.Linear): # Linear + scaled_biases = torch.matmul( + pruned_biases, torch.transpose(scaling_weight, 0, 1) + ) # recall b2_new = b1 @ w2.T + b2 + else: + raise NotImplementedError(f"Type {type(next_layer)} not supported yet.") + + if ( + parametrize.is_parametrized(next_layer) + and getattr(next_layer, "_bias", None) is not None + ): # next_layer is parametrized & has original bias ._bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias) # type: ignore[operator] + elif ( + not parametrize.is_parametrized(next_layer) and next_layer.bias is not None + ): # next_layer not parametrized & has .bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias) # type: ignore[operator] + else: # next_layer has no bias + adjusted_bias = nn.Parameter(scaled_biases) + return adjusted_bias + + +def _prune_module_bias(module: nn.Module, mask: Tensor) -> None: + r"""Applies mask to given modules bias""" + # prune bias along with weights, discard pruned indices of bias + original_bias = cast(Tensor, getattr(module, "_bias", module.bias)) + if original_bias is not None: + module.bias = nn.Parameter(original_bias[mask]) + + # remove _bias parameter + if hasattr(module, "_bias"): + delattr(module, "_bias") + + +def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]: + r""" + In the case that we need to propagate biases, this function will return the biases we need + """ + # set current module bias + if module.bias is not None: + module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) + elif getattr(module, "_bias", None) is not None: + module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) + + # get pruned biases to propagate to subsequent layer + if getattr(module, "_bias", None) is not None: + pruned_biases = cast(Tensor, module._bias)[~mask] + else: + pruned_biases = None + + if hasattr(module, "_bias"): + delattr(module, "_bias") + + return pruned_biases + + +# LINEAR +def _prune_linear_helper(linear: nn.Linear) -> Tensor: + # expects linear to be a parameterized linear module + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True) + linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined] + linear.out_features = linear.weight.shape[0] + _remove_bias_handles(linear) + + return 
mask + + +def prune_linear(linear: nn.Linear) -> None: + mask = _prune_linear_helper(linear) + if getattr(linear, "prune_bias", False): + _prune_module_bias(linear, mask) + + +def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None: + prune_linear_activation_linear(linear1, None, linear2) + + +def prune_linear_activation_linear( + linear1: nn.Linear, + activation: Optional[Callable[[Tensor], Tensor]], + linear2: nn.Linear, +): + mask = _prune_linear_helper(linear1) + if getattr(linear1, "prune_bias", False): + _prune_module_bias(linear1, mask) + else: + pruned_biases = _propagate_module_bias(linear1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask) + + with torch.no_grad(): + if parametrize.is_parametrized(linear2): + parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + linear2.in_features = weight_parameterizations.original.shape[1] + else: + linear2.weight = nn.Parameter(linear2.weight[:, mask]) + linear2.in_features = linear2.weight.shape[1] + + +# CONV2D +def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: + parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True) + conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined] + conv2d.out_channels = conv2d.weight.shape[0] + + _remove_bias_handles(conv2d) + return mask + + +def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True) + + if getattr(conv2d_1, "_bias", None) is not None: + if ( + conv2d_1.bias is not None + ): # conv2d_1 has original bias and bias propagated from previous layer + new_bias = torch.zeros(conv2d_1.bias.shape) + new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined] + # adjusted bias that to keep in conv2d_1 + new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask] + # pruned biases that are kept instead of propagated + conv2d_1.bias = nn.Parameter(new_bias) + else: # conv2d_1 has only original bias + conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias)) + else: + # no original bias, only propagated bias + if ( + conv2d_1.bias is not None + ): # conv2d_1 has bias propagated from previous layer + conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined] + + if hasattr(conv2d_1, "_bias"): + delattr(conv2d_1, "_bias") + + +def prune_conv2d(conv2d: nn.Conv2d) -> None: + mask = _prune_conv2d_helper(conv2d) + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + + +def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None: + prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2) + + +def 
prune_conv2d_activation_conv2d( + conv2d_1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + conv2d_2: nn.Conv2d, +): + r""" + Fusion Pattern for conv2d -> some activation module / function -> conv2d layers + """ + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + prune_bias = getattr(conv2d_1, "prune_bias", False) + if ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None) + ): + prune_conv2d_padded(conv2d_1) + else: + mask = _prune_conv2d_helper(conv2d_1) + if prune_bias: + _prune_module_bias(conv2d_1, mask) + else: + pruned_biases = _propagate_module_bias(conv2d_1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + conv2d_2.bias = _get_adjusted_next_layer_bias( + conv2d_2, pruned_biases, mask + ) + + if ( + not ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + ) + or conv2d_1.bias is None + ): + with torch.no_grad(): + if parametrize.is_parametrized(conv2d_2): + parametrization_dict = cast( + nn.ModuleDict, conv2d_2.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + conv2d_2.in_channels = weight_parameterizations.original.shape[1] + else: + conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask]) + conv2d_2.in_channels = conv2d_2.weight.shape[1] + + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Optional[Callable[[Tensor], Tensor]], + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_activation_pool_conv2d( + c1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + pool: nn.Module, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_pool_flatten_linear( + conv2d: nn.Conv2d, + pool: nn.Module, + flatten: Optional[Callable[[Tensor], Tensor]], + linear: nn.Linear, +) -> None: + mask = _prune_conv2d_helper(conv2d) + + # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer. + # we determine the flattening scale (h * w), and readjust `first_pruned_indices` + # (each idx maps to range idx * h * w to (idx+1) * h * w), `first_valid_indices`, + # and `pruned_biases` (repeat each bias by h * w). 
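# Worked example of the mapping described above (illustrative numbers, not taken from the diff):
# with conv2d_oc = 8 pruned-layer output channels and AdaptiveAvgPool2d((2, 2)), the flattened
# width is linear_ic = 8 * 2 * 2 = 32, so flatten_scale = 32 // 8 = 4; pruning conv channel k
# removes flattened linear inputs [4*k, 4*k + 4), the row mask is repeated 4 times per channel,
# and any propagated bias value is likewise repeated 4 times before the next-layer bias adjustment.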
+ if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + linear_ic = weight_parameterizations.original.shape[1] + else: + linear_ic = linear.weight.shape[1] + + conv2d_oc = len(mask) + assert linear_ic % conv2d_oc == 0, ( + f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + ) + + flatten_scale = linear_ic // conv2d_oc + flattened_mask = torch.tensor( + [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device + ).flatten() + + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + else: + pruned_biases = cast(Tensor, _propagate_module_bias(conv2d, mask)) + flattened_pruned_biases = torch.tensor( + [[bias] * flatten_scale for bias in pruned_biases], device=mask.device + ).flatten() + linear.bias = _get_adjusted_next_layer_bias( + linear, flattened_pruned_biases, flattened_mask + ) + + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, flattened_mask] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, flattened_mask]) + linear.in_features = linear.weight.shape[1] + + +def prune_lstm_output_linear( + lstm: nn.LSTM, getitem: Callable, linear: nn.Linear +) -> None: + prune_lstm_output_layernorm_linear(lstm, getitem, None, linear) + + +def prune_lstm_output_layernorm_linear( + lstm: nn.LSTM, + getitem: Callable, + layernorm: Optional[nn.LayerNorm], + linear: nn.Linear, +) -> None: + for i in range(lstm.num_layers): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_ih_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_ih_l{i}", leave_parametrized=True + ) + setattr( + lstm, + f"weight_ih_l{i}", + nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]), + ) + setattr( + lstm, + f"bias_ih_l{i}", + nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]), + ) + + if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_hh_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_hh_l{i}", leave_parametrized=True + ) + # splitting out hidden-hidden masks + W_hi, W_hf, W_hg, W_ho = torch.split( + getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size + ) + M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size) # type: ignore[arg-type] + + # resize each individual weight separately + W_hi = W_hi[M_hi][:, M_hi] + W_hf = W_hf[M_hf][:, M_hf] + W_hg = W_hg[M_hg][:, M_hg] + W_ho = W_ho[M_ho][:, M_ho] + + # concat, use this as new weight + new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho)) + setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight)) + setattr( + lstm, + f"bias_hh_l{i}", + nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]), + ) + + # If this is the final layer, then we need to prune 
linear layer columns + if i + 1 == lstm.num_layers: + lstm.hidden_size = int(M_hi.sum()) + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast( + nn.ModuleDict, linear.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, M_ho]) + linear.in_features = linear.weight.shape[1] + + # if layernorm module, prune weight and bias + if layernorm is not None: + layernorm.normalized_shape = (linear.in_features,) + layernorm.weight = nn.Parameter(layernorm.weight[M_ho]) + layernorm.bias = nn.Parameter(layernorm.bias[M_ho]) + + # otherwise need to prune the columns of the input of the next LSTM layer + else: + with torch.no_grad(): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i + 1}"): + parametrization_dict = cast( + nn.ModuleDict, lstm.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, + getattr(parametrization_dict, f"weight_ih_l{i + 1}"), + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + else: + next_layer_weight = getattr(lstm, f"weight_ih_l{i + 1}") + setattr( + lstm, + f"weight_ih_l{i + 1}", + nn.Parameter(next_layer_weight[:, M_ho]), + ) diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..73fe8acdc2cd15418e8c7be2b148bfae1db6f066 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +from .base_structured_sparsifier import BaseStructuredSparsifier + + +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune rows based on the saliency (L1 norm) of each row. + + This pruner works on N-Dimensional weight tensors. + For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. + We expect that the resulting saliency vector has the same shape as our mask. + We then pick elements to remove until we reach the target sparsity_level. + """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" 
+ ) + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + assert saliency.shape == mask.shape + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/_mappings.py b/phivenv/Lib/site-packages/torch/ao/pruning/_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..0f85ce97ceefd5f10f89b3b1b4392c5fffc48f99 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/_mappings.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +__all__ = [ + "get_static_sparse_quantized_mapping", + "get_dynamic_sparse_quantized_mapping", +] + + +def get_static_sparse_quantized_mapping(): + import torch.ao.nn.sparse + + _static_sparse_quantized_mapping = { + torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear, + } + return _static_sparse_quantized_mapping + + +def get_dynamic_sparse_quantized_mapping(): + import torch.ao.nn.sparse + + _dynamic_sparse_quantized_mapping = { + torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear, + } + return _dynamic_sparse_quantized_mapping diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a08628858c4fa0da8423b064a1e95ad00eaba06b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2036e2a0b8a121e6829d747d4359fff466569d1a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4b46bed276f03c89f5c7c0d25cc2b7288bee3ae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d86d2cc22424e6d69e03903571ff194b2f78880e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/base_scheduler.py b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/base_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..675fe92063b2997ea35ebc48695712f79921f26a --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/base_scheduler.py @@ -0,0 +1,170 @@ +# mypy: allow-untyped-defs + +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier + + +__all__ = ["BaseScheduler"] + + +class BaseScheduler: + def __init__(self, sparsifier, last_epoch=-1, verbose=False): + # Attach sparsifier + if not isinstance(sparsifier, BaseSparsifier): + raise TypeError( + f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier" + ) + self.sparsifier = sparsifier + + # Initialize epoch and base sparsity levels + + self.base_sl = [group["sparsity_level"] for group in sparsifier.groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment] + self.sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sl_called_within_step: bool = False + + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + """ + return { + key: value for key, value in self.__dict__.items() if key != "sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_sl(self): + """Return last computed sparsity level by current scheduler.""" + return self._last_sl + + def get_sl(self): + # Compute sparsity level using chainable form of the scheduler + # Note: This method is not intended to be called directly, and is only + # used by the ".step" method. Use .get_last_sl() instead. + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`." + ) + raise NotImplementedError + + def print_sl(self, is_verbose, group, sl, epoch=None): + """Display the current sparsity level.""" + if is_verbose: + if epoch is None: + print(f"Adjusting sparsity level of group {group} to {sl:.4e}.") + else: + print( + f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}." 
+ ) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Sparsifier {self.sparsifier}\n" + format_string += f" base_sl: {self.base_sl}\n" + format_string += ")" + return format_string + + def step(self, epoch=None): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `sparsifier.step()`. " + "You have to make sure you run the sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + ) + self._step_count += 1 + + class _enable_get_sl_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sl_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sl_called_within_step = False + + with _enable_get_sl_call(self): + self.last_epoch += 1 + values = self.get_sl() + + for i, data in enumerate(zip(self.sparsifier.groups, values)): + param_group, sl = data + param_group["sparsity_level"] = sl + self.print_sl(self.verbose, i, sl, epoch) + + self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups] + self.sparsifier.enable_mask_update = True + + def _make_sure_a_list(self, var): + r"""Utility that extends it to the same length as the .groups, ensuring it is a list""" + n = len(self.sparsifier.groups) + if not isinstance(var, (list, tuple)): + return [var] * n + else: + if len(var) != n: + raise ValueError(f"Expected variable of length {n}, but got {len(var)}") + return list(var) # We want the result to be in a list, not tuple diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..0574d23ba24c02e9aafed340b05b6af05743fb15 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +import warnings + +from .base_scheduler import BaseScheduler + + +__all__ = ["CubicSL"] + + +def _clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +class CubicSL(BaseScheduler): + r"""Sets the sparsity level of each parameter group to the final sl + plus a given exponential function. + + .. math:: + + s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3 + + where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final + sparsity level, :math:`f(i)` is the function to be applied to the current epoch + :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`. + :math:`\Delta t` is used to control how often the update of the sparsity level + happens. By default, + + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. 
+ init_sl (int, list): Initial level of sparsity + init_t (int, list): Initial step, when pruning starts + delta_t (int, list): Pruning frequency + total_t (int, list): Total number of pruning steps + initially_zero (bool, list): If True, sets the level of sparsity to 0 + before init_t (:math:`t_0`). Otherwise, the sparsity level before + init_t (:math:`t_0`) is set to init_sl(:math:`s_0`) + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__( + self, + sparsifier, + init_sl=0.0, + init_t=0, + delta_t=10, + total_t=100, + initially_zero=False, + last_epoch=-1, + verbose=False, + ): + self.sparsifier = sparsifier + + self.init_sl = self._make_sure_a_list(init_sl) + self.init_t = self._make_sure_a_list(init_t) + self.delta_t = self._make_sure_a_list(delta_t) + self.total_t = self._make_sure_a_list(total_t) + + self.initially_zero = self._make_sure_a_list(initially_zero) + + super().__init__(sparsifier, last_epoch, verbose) + + @staticmethod + def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False): + r""" "Computes the current level of sparsity. + + Based on https://arxiv.org/pdf/1710.01878.pdf + + Args: + s_0: Initial level of sparsity, :math:`s_i` + s_f: Target level of sparsity, :math:`s_f` + t: Current step, :math:`t` + t_0: Initial step, :math:`t_0` + dt: Pruning frequency, :math:`\Delta T` + n: Pruning steps, :math:`n` + initially_zero: Sets the level of sparsity to 0 before t_0. + If False, sets to s_0 + + Returns: + The sparsity level :math:`s_t` at the current step :math:`t` + """ + if initially_zero and t < t_0: + return 0 + s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3 + s_t = _clamp(s_t, s_0, s_f) + return s_t + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`." + ) + return [ + self.sparsity_compute_fn( + s_0=initial_sparsity, + s_f=final_sparsity, + t=self.last_epoch, + t_0=initial_epoch, + dt=delta_epoch, + n=interval_epochs, + initially_zero=initially_zero, + ) + for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in zip( + self.init_sl, + self.base_sl, + self.init_t, + self.delta_t, + self.total_t, + self.initially_zero, + ) + ] diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9f75d738a0b1e69ee1f0b2a95e4921ebfdc0b6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs +import warnings + +from .base_scheduler import BaseScheduler + + +__all__ = ["LambdaSL"] + + +class LambdaSL(BaseScheduler): + """Sets the sparsity level of each parameter group to the final sl + times a given function. When last_epoch=-1, sets initial sl as zero. + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + sl_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in sparsifier.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + Example: + >>> # Assuming sparsifier has two groups. 
+ >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95**epoch + >>> # xdoctest: +SKIP + >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): + self.sparsifier = sparsifier + + if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): + self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) + else: + if len(sl_lambda) != len(sparsifier.groups): + raise ValueError( + f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}" + ) + self.sl_lambdas = list(sl_lambda) + super().__init__(sparsifier, last_epoch, verbose) + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`." + ) + return [ + base_sl * lmbda(self.last_epoch) + for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl) + ] diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__init__.py b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2341662c063ec3ea6ce8205d76729bf6c43a32 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f7b5eb63a015c1cf47da95c875188f7de21992 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b32b8ab50819d016a89c12bf7254d4e1d53f7c5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..446d6ecbe7ef786f18e2e2f44de4cdf5ce74eb5e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d1bd1306f79d154de3539007750ad4683e87450 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..f36d0e970e6e6f8e47eaccec4fd4c58abd2b2600 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -0,0 +1,353 @@ +# mypy: allow-untyped-defs +import abc +import copy +from collections import defaultdict +from typing import Any, Optional + +import torch +from torch import nn +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + FakeSparsity, + get_arg_info_from_tensor_fqn, + module_contains_param, + module_to_fqn, + swap_module, +) + + +__all__ = ["BaseSparsifier"] + +SUPPORTED_MODULES = {nn.Linear} + +KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] + + +# TODO update desc with new config args +class BaseSparsifier(abc.ABC): + r"""Base class for all sparsifiers. + + Abstract methods that need to be implemented: + + - update_mask: Function to compute a new mask for all keys in the + `groups`. + + Args: + - model [nn.Module]: model to configure. The model itself is not saved + but used for the state_dict saving / loading. + - config [list]: configuration elements should be a dict map that includes + `tensor_fqn` of tensors to sparsify + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + + Example:: + + >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask") + >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}] + >>> defaults = {'sparsity_level': 0.7} + >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) + >>> sparsifier = BaseSparsifier(config, defaults) + """ + + def __init__(self, defaults: Optional[dict[str, Any]] = None): + super().__init__() + self.defaults: dict[str, Any] = defaults or {} + + self.state: dict[str, dict] = defaultdict(dict) + self.groups: list[dict[str, Any]] = [] + self.enable_mask_update = True + + def __getstate__(self) -> dict[str, Any]: + return { + "defaults": self.defaults, + "state": self.state, + "groups": self.groups, + } + + def __setstate__(self, state: dict[str, dict[str, Any]]) -> None: + self.__dict__.update(state) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for i, sparse_args in enumerate(self.groups): + module = sparse_args["module"] + format_string += "\n" + format_string += f"\tGroup {i}\n" + format_string += f"\t module: {module}\n" + for key in sorted(sparse_args.keys()): + if key == "module": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - current state of the sparsification. 
+ * groups - a list containing all sparsity configuration groups + with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model + + TODO: Need a clean way of loading the state of the "prepared" module + """ + + groups: list[dict[str, Any]] = [ + dict( + filter( + lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT, + mg.items(), + ) + ) + for mg in self.groups + ] + + return { + "state": self.state, + "groups": groups, + } + + def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True): + groups = copy.deepcopy(state_dict["groups"]) + states = state_dict["state"] + for tensor_fqn, s in states.items(): + arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) + module = arg_info["module"] + tensor_name = arg_info["tensor_name"] + if strict and module is None: + raise RuntimeError(f"Error loading {tensor_fqn} into the model") + + found = False + for p in module.parametrizations[tensor_name]: + if isinstance(p, FakeSparsity): + found = True + break + if not found: + p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) + parametrize.register_parametrization(module, tensor_name, p) + if s.get("mask", None) is not None: + mask = s.pop("mask") + p.mask = mask + + for mg in groups: + if mg["tensor_fqn"] == tensor_fqn: + mg.update(arg_info) + self.__setstate__({"state": states, "groups": groups}) + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: set[type[nn.Linear]] = SUPPORTED_MODULES, + ) -> None: + self.config = [] + stack = [model] + while stack: + module = stack.pop() + for _name, child in module.named_children(): + if type(child) in SUPPORTED_MODULES: + module_fqn = module_to_fqn(model, child) + assert isinstance(module_fqn, str) # for mypy + self.config.append({"tensor_fqn": module_fqn + ".weight"}) + else: + stack.append(child) + + def prepare(self, model, config): + r"""Prepares a model, by adding the parametrizations. + + Note:: + + The model is modified inplace. If you need to preserve the original + model, use copy.deepcopy. + """ + self.model = model # TODO: Need to figure out how to load without this. + self.config = config + + # If no config -- try getting all the supported layers + if self.config is None: + self.make_config_from_model(model) + + # TODO: Remove the configuration by reference ('module') + for module_config in self.config: + assert isinstance(module_config, dict), ( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) + + assert isinstance(self.defaults, dict) # for mypy + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + + tensor_fqn = local_args.get("tensor_fqn", None) + assert tensor_fqn is not None, ( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) + + # populate all information from tensor_fqn + info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) + + # check that whatever was put into local_args agrees with what was obtained + # from tensor_fqn + for key in info_from_tensor_fqn.keys(): + if key in local_args: + assert ( + info_from_tensor_fqn[key] == local_args[key] + or ( + key == "tensor_fqn" + and "." + info_from_tensor_fqn[key] == local_args[key] + ) + # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that + ), ( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" 
+ ) + local_args.update(info_from_tensor_fqn) + self.groups.append(local_args) + self._prepare() + + def _prepare(self, *args, **kwargs): + r"""Adds mask parametrization to the layer weight""" + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeSparsity) + mask = config.get("mask", torch.ones_like(getattr(module, tensor_name))) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + def squash_mask( + self, + params_to_keep: Optional[tuple[str, ...]] = None, + params_to_keep_per_layer: Optional[dict[str, tuple[str, ...]]] = None, + *args, + **kwargs, + ): + r"""Squashes the sparse masks into the appropriate tensors. + + If either the `params_to_keep` or `params_to_keep_per_layer` is set, + the module will have a `sparse_params` dict attached to it. + + Args: + params_to_keep: List of keys to save in the module or a dict + representing the modules and keys that will have + sparsity parameters saved + params_to_keep_per_layer: Dict to specify the params that should be + saved for specific layers. The keys in the dict + should be the module fqn, while the values should + be a list of strings with the names of the variables + to save in the `sparse_params` + + Examples: + >>> # xdoctest: +SKIP("locals are undefined") + >>> # Don't save any sparse params + >>> sparsifier.squash_mask() + >>> hasattr(model.submodule1, "sparse_params") + False + + >>> # Keep sparse params per layer + >>> sparsifier.squash_mask( + ... params_to_keep_per_layer={ + ... "submodule1.linear1": ("foo", "bar"), + ... "submodule2.linear42": ("baz",), + ... } + ... ) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'baz': 0.1} + + >>> # Keep sparse params for all layers + >>> sparsifier.squash_mask(params_to_keep=("foo", "bar")) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24} + + >>> # Keep some sparse params for all layers, and specific ones for + >>> # some other layers + >>> sparsifier.squash_mask( + ... params_to_keep=("foo", "bar"), + ... params_to_keep_per_layer={"submodule2.linear42": ("baz",)}, + ... ) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24, 'baz': 0.1} + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True + ) + sparse_params = {} + if params_to_keep is not None: + global_params = {k: config[k] for k in params_to_keep} + sparse_params.update(global_params) + if params_to_keep_per_layer is not None: + params = params_to_keep_per_layer.get(config["module_fqn"], None) + if params is not None: + per_layer_params = {k: config[k] for k in params} + sparse_params.update(per_layer_params) + if sparse_params: + # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? 
+ module.sparse_params = sparse_params + + def convert( + self, + module: nn.Module, + mapping: Optional[dict[type[nn.Module], type[nn.Module]]] = None, + inplace: bool = False, + parameterization: type[nn.Module] = FakeSparsity, + ): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_dense` method on the target module class + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + """ + if mapping is None: + raise NotImplementedError("Need to auto generate mapping ") + if not inplace: + module = copy.deepcopy(module) + + reassign = {} + for name, mod in module.named_children(): + # leaf node + if ( + module_contains_param(mod, parameterization) + and type_before_parametrizations(mod) in mapping + ): + reassign[name] = swap_module(mod, mapping) + else: + # recurse + reassign[name] = self.convert( + mod, + mapping=mapping, + inplace=True, + parameterization=parameterization, + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + def step(self, use_path: bool = True) -> None: + if not self.enable_mask_update: + return + with torch.no_grad(): + for config in self.groups: + self.update_mask(**config) + + @abc.abstractmethod + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): + pass diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..dacc496937d884c0fdffa358a378d75cd5ce2074 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import torch + +from . import base_sparsifier + + +class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier): + r"""Nearly Diagonal Sparsifier + + This sparsifier creates a nearly diagonal mask to be applied to the weight matrix. + Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero. + An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively. + 1 1 0 0 1 1 1 0 + 1 1 1 0 1 1 1 1 + 0 1 1 1 1 1 1 1 + 0 0 1 1 0 1 1 1 + Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated + + This sparsifier is controlled by one variable: + 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal. 
+ Currently - supports only odd number + + Note: + This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix + feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy + + Args: + nearliness: The degree of nearliness (default = 1) + + """ + + def __init__(self, nearliness: int = 1): + defaults = {"nearliness": nearliness} + super().__init__(defaults=defaults) + + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): + mask = getattr(module.parametrizations, tensor_name)[0].mask + mask.data = torch.zeros_like(mask) + if nearliness <= 0: + return + + tensor = getattr(module, tensor_name) + height, width = tensor.shape + + if nearliness % 2 == 0: + raise ValueError("nearliness can only be an odd number") + dist_to_diagonal = nearliness // 2 + # check + if dist_to_diagonal >= min(height, width): + raise ValueError( + "nearliness cannot be larger than the dimensions of tensor." + ) + + for row in range(0, height): + # Bounds of entries that needs to be set to 1 + low = max(0, row - dist_to_diagonal) + high = min(width, row + dist_to_diagonal + 1) + mask[row, low:high].fill_(1) diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/utils.py b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0a9aacc063ac4f9e318a7f1e9b0655185d6758 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/utils.py @@ -0,0 +1,138 @@ +# mypy: allow-untyped-defs +from itertools import chain +from typing import Any, Optional + +from torch import nn +from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations + + +__all__ = [ + "module_contains_param", + "swap_module", + "module_to_fqn", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "FakeSparsity", +] + + +def module_contains_param(module: nn.Module, parametrization: type[nn.Module]) -> bool: + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator] + ) + return False + + +def swap_module( + mod: nn.Module, mapping: dict[type[nn.Module], type[nn.Module]] +) -> nn.Module: + r"""Swaps the module using from_dense according to the mapping passed in. + Args: + mod: input module + mapping: a dictionary that maps from nn module to sparse nn module + Return: + The corresponding sparse module of `mod` according to mapping, created using from_dense + """ + if type_before_parametrizations(mod) in mapping: + sparse_mod = mapping[type_before_parametrizations(mod)] + + # TODO Fix this typing, as Type[Module] has no attribute "from_dense" + new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined] + + # Preserve module's pre forward hooks. 
They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = {p.device for p in chain(mod.parameters(), mod.buffers())} + assert len(devices) <= 1, ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + + return new_mod + + else: + return mod + + +def module_to_fqn( + model: nn.Module, module: nn.Module, prefix: str = "" +) -> Optional[str]: + """ + Returns the fqn for a module or None if module not a descendent of model. + """ + if module is model: + return "" + for name, child in model.named_children(): + fqn = module_to_fqn(child, module, ".") + if isinstance(fqn, str): + return prefix + name + fqn + return None + + +def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]: + """ + Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` + doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. + """ + if path != "": + for name in path.split("."): + model = getattr(model, name, None) + return model + + +def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, Any]: + """ + Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name + """ + # string manip to split tensor_fqn into module_fqn and tensor_name + # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' + # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' + tensor_name = tensor_fqn.split(".")[-1] + module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] + + module = fqn_to_module(model, module_fqn) + + return { + "module_fqn": module_fqn, + "module": module, + "tensor_name": tensor_name, + "tensor_fqn": tensor_fqn, + } + + +# Parametrizations +class FakeSparsity(nn.Module): + r"""Parametrization for the weights. Should be attached to the 'weight' or + any other parameter that requires a mask applied to it. + + Note:: + + Once the mask is passed, the variable should not change the id. The + contents of the mask can change, but the mask reference itself should + not. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + assert self.mask.shape == x.shape + return self.mask * x + + def state_dict(self, *args, **kwargs): + # We don't want to let the parametrizations to save the mask. + # That way we make sure that the linear module doesn't store the masks + # alongside their parametrizations. 
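+ # (For reference: the masks are persisted by the sparsifier itself -- `BaseSparsifier._prepare` + # stores each mask under its `tensor_fqn` in `self.state`, and `state_dict` / `load_state_dict` + # save and restore it from there.)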
+ return {} diff --git a/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2057a22c08255982683b2ecc2dadb8776a67d920 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -0,0 +1,248 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F + +from .base_sparsifier import BaseSparsifier + + +__all__ = ["WeightNormSparsifier"] + + +def _flat_idx_to_2d(idx, shape): + rows = idx // shape[1] + cols = idx % shape[1] + return rows, cols + + +class WeightNormSparsifier(BaseSparsifier): + r"""Weight-Norm Sparsifier + + This sparsifier computes the norm of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + + Args: + + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block (see note below) + zeros_per_block: Number of zeros in a sparse block + norm: Norm to use. Could be either `int` or a callable. + If `int`, only L1 and L2 are implemented. + + Note:: + The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), + irrespective of what the rows / cols mean in the data tensor. That means, + if you were to sparsify a weight tensor in the nn.Linear, which has a + weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output + channels, while the `block_COLS` would refer to the input channels. + + Note:: + All arguments to the WeightNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `prepare` step. 
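+ + Example (an illustrative sketch of the typical prepare / step / squash_mask flow; + ``model`` and the config values below are hypothetical):: + + >>> # xdoctest: +SKIP("illustrative usage sketch") + >>> sparsifier = WeightNormSparsifier(sparsity_level=0.5, sparse_block_shape=(1, 4)) + >>> sparsifier.prepare(model, config=[{"tensor_fqn": "linear.weight"}]) + >>> sparsifier.step() # compute the masks for every configured tensor + >>> sparsifier.squash_mask() # fold the masks into the weights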
+ """ + + def __init__( + self, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: Optional[int] = None, + norm: Optional[Union[Callable, int]] = None, + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + if norm is None: + norm = 2 + if callable(norm): + self.norm_fn = norm + elif norm == 1: + self.norm_fn = lambda T: T.abs() + elif norm == 2: + self.norm_fn = lambda T: T * T + else: + raise NotImplementedError(f"L-{norm} is not yet implemented.") + super().__init__(defaults=defaults) + + def _scatter_fold_block_mask( + self, + output_shape, + dim, + indices, + block_shape, + mask=None, + input_shape=None, + device=None, + ): + r"""Creates patches of size `block_shape` after scattering the indices.""" + if mask is None: + assert input_shape is not None + mask = torch.ones(input_shape, device=device) + mask.scatter_(dim=dim, index=indices, value=0) + mask.data = F.fold( + mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape + ) + return mask + + def _make_tensor_mask( + self, data, input_shape, sparsity_level, sparse_block_shape, mask=None + ): + r"""Creates a tensor-level mask. + + Tensor-level mask is described as a mask, where the granularity of sparsification of the + smallest patch is the sparse_block_shape. That means, that for a given mask and a + sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. + + In this context, `sparsity_level` describes the fraction of sparse patches. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + + if mask is None: + mask = torch.ones(h + dh, w + dw, device=data.device) + + if sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask) + return mask + elif sparsity_level <= 0.0: + mask.data = torch.ones_like(mask) + return mask + + values_per_block = reduce(operator.mul, sparse_block_shape) + if values_per_block > 1: + # Reduce the data + data = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + data = data.flatten() + num_blocks = len(data) + + data = data.repeat(1, values_per_block, 1) + + threshold_idx = int(round(sparsity_level * num_blocks)) + threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check + _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) + + # Temp reshape for mask + mask_reshape = mask.reshape(data.shape) # data might be reshaped + self._scatter_fold_block_mask( + dim=2, + output_shape=(h + dh, w + dw), + indices=sorted_idx, + block_shape=sparse_block_shape, + mask=mask_reshape, + ) + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + + def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): + r"""Creates a block-level mask. + + Block-level mask is described as a mask, where the granularity of sparsification of the + largest patch is the sparse_block_shape. That means that for a given mask and a + sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. + + In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. 
+ """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + values_per_block = reduce(operator.mul, sparse_block_shape) + + if mask is None: + mask = torch.ones((h + dh, w + dw), device=data.device) + + if values_per_block == zeros_per_block: + # Everything should be sparsified + mask.data = torch.zeros_like(mask) + return mask + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) + padded_data.fill_(torch.nan) + padded_data[:h, :w] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + # Temp reshape for mask + mask_reshape = mask.reshape(unfolded_data.shape) + _, sorted_idx = torch.topk( + unfolded_data, k=zeros_per_block, dim=1, largest=False + ) + + self._scatter_fold_block_mask( + dim=1, + indices=sorted_idx, + output_shape=padded_data.shape, + block_shape=sparse_block_shape, + mask=mask_reshape, + ) + + mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() + return mask + + def update_mask( # type: ignore[call-override, override] + self, + module, + tensor_name, + sparsity_level, + sparse_block_shape, + zeros_per_block, + **kwargs, + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + mask = getattr(module.parametrizations, tensor_name)[0].mask + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + else: + ww = self.norm_fn(getattr(module, tensor_name)) + tensor_mask = self._make_tensor_mask( + data=ww, + input_shape=ww.shape, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + if values_per_block != zeros_per_block: + block_mask = self._make_block_mask( + data=ww, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + tensor_mask = torch.logical_or(tensor_mask, block_mask) + mask.data = tensor_mask diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__init__.py b/phivenv/Lib/site-packages/torch/ao/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c5f0bb2831957213d89805ffa6aa9058f71a077 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/__init__.py @@ -0,0 +1,234 @@ +# mypy: allow-untyped-defs + +from typing import Callable, Optional, Union + +import torch +from torch import Tensor + +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules, fuse_modules_qat # noqa: F403 +from .fuser_method_mappings import * # noqa: F403 +from .observer import * # noqa: F403 +from .pt2e._numeric_debugger import ( # noqa: F401 + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + prepare_for_propagation_comparison, +) +from .pt2e.export_utils import ( + _allow_exported_model_train_eval as allow_exported_model_train_eval, + _move_exported_model_to_eval as move_exported_model_to_eval, + _move_exported_model_to_train as move_exported_model_to_train, +) +from .qconfig import * # noqa: F403 +from 
.qconfig_mapping import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantization_mappings import * # noqa: F403 # type: ignore[no-redef] +from .quantize import * # noqa: F403 +from .quantize_jit import * # noqa: F403 +from .stubs import * # noqa: F403 + + +# ensure __module__ is set correctly for public APIs +ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] +ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +for _f in [ + compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +]: + _f.__module__ = "torch.ao.quantization" + +__all__ = [ + "DeQuantStub", + "FakeQuantize", + "FakeQuantizeBase", + "FixedQParamsFakeQuantize", + "FixedQParamsObserver", + "FusedMovingAvgObsFakeQuantize", + "HistogramObserver", + "MatchAllNode", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "ObserverOrFakeQuantize", + "Pattern", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "QConfig", + "QConfigAny", + "QConfigDynamic", + "QConfigMapping", + "QuantStub", + "QuantType", + "QuantWrapper", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "add_quant_dequant", + "convert", + "convert_dynamic_jit", + "convert_jit", + "default_affine_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_fake_quant", + "default_dynamic_quant_observer", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_eval_fn", + "default_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_fused_act_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "default_fused_wt_fake_quant", + "default_histogram_fake_quant", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_fake_quant", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_fake_quant", + "default_symmetric_fixed_qparams_observer", + "default_weight_fake_quant", + "default_weight_observer", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "fuse_conv_bn", + "fuse_conv_bn_jit", + "fuse_conv_bn_relu", + "fuse_convtranspose_bn", + "fuse_linear_bn", + "fuse_modules", + "fuse_modules_qat", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", + "fused_wt_fake_quant_range_neg_127_to_127", + "get_combined_dict", + "get_default_compare_output_module_list", + "get_default_custom_config_dict", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_float_to_quantized_operator_mappings", + "get_default_qat_module_mappings", + "get_default_qat_qconfig", + "get_default_qat_qconfig_dict", + "get_default_qat_qconfig_mapping", + "get_default_qconfig", + "get_default_qconfig_dict", + "get_default_qconfig_mapping", + "get_default_qconfig_propagation_list", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_dynamic_quant_module_class", + "get_embedding_qat_module_mappings", + 
"get_embedding_static_quant_module_mappings", + "get_fuser_method", + "get_fuser_method_new", + "get_observer_state_dict", + "get_quantized_operator", + "get_static_quant_module_class", + "load_observer_state_dict", + "move_exported_model_to_eval", + "move_exported_model_to_train", + "allow_exported_model_train_eval", + "no_observer_set", + "per_channel_weight_observer_range_neg_127_to_127", + "prepare", + "prepare_dynamic_jit", + "prepare_jit", + "prepare_qat", + "propagate_qconfig_", + "qconfig_equals", + "quantize", + "quantize_dynamic", + "quantize_dynamic_jit", + "quantize_jit", + "quantize_qat", + "script_qconfig", + "script_qconfig_dict", + "swap_module", + "weight_observer_range_neg_127_to_127", + "generate_numeric_debug_handle", + "CUSTOM_KEY", + "NUMERIC_DEBUG_HANDLE_KEY", + "prepare_for_propagation_comparison", + "extract_results_from_loggers", + "compare_results", + # from torchao, should be merged with torchao + # in the future + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +def default_eval_fn(model, calib_data): + r"""Define the default evaluation function. + + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, _target in calib_data: + model(data) + + +class _DerivedObserverOrFakeQuantize(ObserverBase): + r"""This observer is used to describe an observer whose quantization parameters + are derived from other observers + """ + + def __init__( + self, + dtype: torch.dtype, + obs_or_fqs: list[ObserverOrFakeQuantize], + derive_qparams_fn: Callable[ + [list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor] + ], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + qscheme: Optional[torch.qscheme] = None, + ch_axis: Optional[int] = None, + ): + super().__init__(dtype) + self.obs_or_fqs = obs_or_fqs + self.derive_qparams_fn = derive_qparams_fn + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + self.ch_axis = ch_axis + + from .utils import is_per_channel + + if is_per_channel(self.qscheme): + assert self.ch_axis is not None, ( + "Must provide a valid ch_axis if qscheme is per channel" + ) + + def forward(self, x: Tensor) -> Tensor: + return x + + def calculate_qparams(self): # type:ignore[override] + return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a0ce1485941fc6cfed0d20b4949a03eb044c2b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87cd8257dd084fe505c3d0c25f98e6bfa1f0e410 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6f20aa13394854aa063195c218c5754630838325 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..241950f65eba6b9b1c0ec59b11a0e3f637121a6c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7acfa22af20929737c9b784760a53de2a4c8e79 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ada8e0d0c64db6d70875cfeeb781f9170295caa2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d83e066581954c60d485bba0437786adf556972 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/observer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/observer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ad191fc9f1e81dfdcdfe258a1486cdcd92f23d5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/observer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..410442886a38116041f95ee2c7abc70025e6704c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7dedb4ee23ede67963e027f27480175921c9de2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fb40e9fb0241dbf0d3d7672b2b2c2523645f6a7 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bca13b41c577f86faafe69e3adb164d9c18bd592 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a36f8d4fd50c00ad60aecc2d7a55d30dc63d5ae5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c724c6eeb105a0435b93d4fca368de3f215eaaca Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a1767b4427fd894f573dadb9eddaba3bce90c6d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b337738371ae0377457bfc05dd8400e93dc0ed7d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..116403f23401e101025685f41cd8c637031478d5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6a37ec34539ab70168cc88439805ce2c7990548 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/_correct_bias.py b/phivenv/Lib/site-packages/torch/ao/quantization/_correct_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..3512464afc9ba8baa2e36ead7f7eec4cae76a531 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/_correct_bias.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.ns._numeric_suite as ns +import torch.ao.quantization 
+import torch.nn as nn + + +__all__ = [ + "get_module", + "parent_child_names", + "get_param", + "MeanShadowLogger", + "bias_correction", +] + +_supported_modules = {nn.Linear, nn.Conv2d} +_supported_modules_quantized = {nnq.Linear, nnq.Conv2d} + + +def get_module(model, name): + """Given name of submodule, this function grabs the submodule from given model.""" + return dict(model.named_modules())[name] + + +def parent_child_names(name): + """Split full name of submodule into parent submodule's full name and submodule's name.""" + split_name = name.rsplit(".", 1) + if len(split_name) == 1: + return "", split_name[0] + else: + return split_name[0], split_name[1] + + +def get_param(module, attr): + """Get the parameter given a module and attribute. + + Sometimes the weights/bias attribute gives you the raw tensor, but sometimes + gives a function that will give you the raw tensor, this function takes care of that logic + """ + param = getattr(module, attr, None) + if callable(param): + return param() + else: + return param + + +class MeanShadowLogger(ns.Logger): + """Mean Logger for a Shadow module. + + A logger for a Shadow module whose purpose is to record the rolling mean + of the data passed to the floating point and quantized models + """ + + def __init__(self): + """Set up initial values for float and quantized stats, count, float sum, and quant sum.""" + super().__init__() + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + def forward(self, x, y): # type: ignore[override] + """Compute the average of quantized and floating-point data from modules. + + The inputs x,y are output data from the quantized and floating-point modules. + x is for the quantized module, y is for the floating point module + """ + if x.is_quantized: + x = x.dequantize() + + self.count += 1 + if self.stats["quantized"] is None: + self.stats["quantized"] = x + self.quant_sum = x + else: + self.quant_sum += x + self.stats["quantized"] = self.quant_sum / self.count + + if self.stats["float"] is None: + self.stats["float"] = y + self.float_sum = y + else: + self.float_sum += y + self.stats["float"] = self.float_sum / self.count + + def clear(self): + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + +def bias_correction( + float_model, + quantized_model, + img_data, + target_modules=_supported_modules_quantized, + neval_batches=None, +): + """Perform bias correction on a module. + + Using numeric suite shadow module, the expected output of the floating point and quantized modules + is recorded. 
Using that data the bias of supported modules is shifted to compensate for the drift caused + by quantization + Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2) + + Args: + float_model: a trained model that serves as a reference to what bias correction should aim for + quantized_model: quantized form of float_model that bias correction is to applied to + img_data: calibration data to estimate the expected output (used to find quantization error) + target_modules: specifies what submodules in quantized_model need bias correction (can be extended to + unquantized submodules) + neval_batches: a cap to the number of batches you want to be used for estimating the expected output + """ + ns.prepare_model_with_stubs( + float_model, quantized_model, _supported_modules, MeanShadowLogger + ) + + uncorrected_modules = { + name: submodule + for name, submodule in quantized_model.named_modules() + if type(submodule) in target_modules + } + + for uncorrected_module in uncorrected_modules: + quantized_submodule = get_module(quantized_model, uncorrected_module) + bias = get_param(quantized_submodule, "bias") + if bias is not None: + for count, data in enumerate(img_data, start=1): + quantized_model(data[0]) + if count == neval_batches: + break + ob_dict = ns.get_logger_dict(quantized_model) + parent_name, _ = parent_child_names(uncorrected_module) + + float_data = ob_dict[parent_name + ".stats"]["float"] + quant_data = ob_dict[parent_name + ".stats"]["quantized"] + + # math for expected_error + quantization_error = quant_data - float_data + dims = list(range(quantization_error.dim())) + # Note: we don't want to take the mean over the output channel dimension + dims.remove(1) + expected_error = torch.mean(quantization_error, dims) + + updated_bias = bias.data - expected_error + + bias.data = updated_bias + + # Resets the data contained in the loggers + for name, submodule in quantized_model.named_modules(): + if isinstance(submodule, MeanShadowLogger): + submodule.clear() diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/_equalize.py b/phivenv/Lib/site-packages/torch/ao/quantization/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc91a69bd4476f1427f71075bd2d9ca5c6a2415 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/_equalize.py @@ -0,0 +1,278 @@ +# mypy: allow-untyped-defs +import copy +from itertools import chain +from typing import Any + +import torch + + +__all__ = [ + "set_module_weight", + "set_module_bias", + "has_bias", + "get_module_weight", + "get_module_bias", + "max_over_ndim", + "min_over_ndim", + "channel_range", + "get_name_by_module", + "cross_layer_equalization", + "process_paired_modules_list_to_name", + "expand_groups_in_paired_modules_list", + "equalize", + "converged", +] + +_supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d} +_supported_intrinsic_types = { + torch.ao.nn.intrinsic.ConvReLU2d, + torch.ao.nn.intrinsic.LinearReLU, + torch.ao.nn.intrinsic.ConvReLU1d, +} +_all_supported_types = _supported_types.union(_supported_intrinsic_types) + + +def set_module_weight(module, weight) -> None: + if type(module) in _supported_types: + module.weight = torch.nn.Parameter(weight) + else: + module[0].weight = torch.nn.Parameter(weight) + + +def set_module_bias(module, bias) -> None: + if type(module) in _supported_types: + module.bias = torch.nn.Parameter(bias) + else: + module[0].bias = torch.nn.Parameter(bias) + + +def has_bias(module) -> bool: + if type(module) in _supported_types: + 
return module.bias is not None + else: + return module[0].bias is not None + + +def get_module_weight(module): + if type(module) in _supported_types: + return module.weight + else: + return module[0].weight + + +def get_module_bias(module): + if type(module) in _supported_types: + return module.bias + else: + return module[0].bias + + +def max_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.max' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.max(axis, keepdim) + return input + + +def min_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.min' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.min(axis, keepdim) + return input + + +def channel_range(input, axis=0): + """Find the range of weights associated with a specific channel.""" + size_of_tensor_dim = input.ndim + axis_list = list(range(size_of_tensor_dim)) + axis_list.remove(axis) + + mins = min_over_ndim(input, axis_list) + maxs = max_over_ndim(input, axis_list) + + assert mins.size(0) == input.size(axis), ( + "Dimensions of resultant channel range does not match size of requested axis" + ) + return maxs - mins + + +def get_name_by_module(model, module): + """Get the name of a module within a model. + + Args: + model: a model (nn.module) that equalization is to be applied on + module: a module within the model + + Returns: + name: the name of the module within the model + """ + for name, m in model.named_modules(): + if m is module: + return name + raise ValueError("module is not in the model") + + +def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1): + """Scale the range of Tensor1.output to equal Tensor2.input. + + Given two adjacent tensors', the weights are scaled such that + the ranges of the first tensors' output channel are equal to the + ranges of the second tensors' input channel + """ + if ( + type(module1) not in _all_supported_types + or type(module2) not in _all_supported_types + ): + raise ValueError( + "module type not supported:", type(module1), " ", type(module2) + ) + + bias = get_module_bias(module1) if has_bias(module1) else None + + weight1 = get_module_weight(module1) + weight2 = get_module_weight(module2) + + if weight1.size(output_axis) != weight2.size(input_axis): + raise TypeError( + "Number of output channels of first arg do not match \ + number input channels of second arg" + ) + + weight1_range = channel_range(weight1, output_axis) + weight2_range = channel_range(weight2, input_axis) + + # producing scaling factors to applied + weight2_range += 1e-9 + scaling_factors = torch.sqrt(weight1_range / weight2_range) + inverse_scaling_factors = torch.reciprocal(scaling_factors) + + if bias is not None: + bias = bias * inverse_scaling_factors + + # formatting the scaling (1D) tensors to be applied on the given argument tensors + # pads axis to (1D) tensors to then be broadcasted + size1 = [1] * weight1.ndim + size1[output_axis] = weight1.size(output_axis) + size2 = [1] * weight2.ndim + size2[input_axis] = weight2.size(input_axis) + + scaling_factors = torch.reshape(scaling_factors, size2) + inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1) + + weight1 = weight1 * inverse_scaling_factors + weight2 = weight2 * scaling_factors + + set_module_weight(module1, weight1) + if bias is not None: + set_module_bias(module1, bias) + set_module_weight(module2, weight2) + + +def process_paired_modules_list_to_name(model, paired_modules_list): + """Processes a 
list of paired modules to a list of names of paired modules.""" + + for group in paired_modules_list: + for i, item in enumerate(group): + if isinstance(item, torch.nn.Module): + group[i] = get_name_by_module(model, item) + elif not isinstance(item, str): + raise TypeError("item must be a nn.Module or a string") + return paired_modules_list + + +def expand_groups_in_paired_modules_list(paired_modules_list): + """Expands module pair groups larger than two into groups of two modules.""" + new_list = [] + + for group in paired_modules_list: + if len(group) == 1: + raise ValueError("Group must have at least two modules") + elif len(group) == 2: + new_list.append(group) + elif len(group) > 2: + new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1)) + + return new_list + + +def equalize(model, paired_modules_list, threshold=1e-4, inplace=True): + """Equalize modules until convergence is achieved. + + Given a list of adjacent modules within a model, equalization will + be applied between each pair, this will repeated until convergence is achieved + + Keeps a copy of the changing modules from the previous iteration, if the copies + are not that different than the current modules (determined by converged_test), + then the modules have converged enough that further equalizing is not necessary + + Reference is section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf + + Args: + model: a model (nn.Module) that equalization is to be applied on + paired_modules_list (List(List[nn.module || str])): a list of lists + where each sublist is a pair of two submodules found in the model, + for each pair the two modules have to be adjacent in the model, + with only piece-wise-linear functions like a (P)ReLU or LeakyReLU in between + to get expected results. + The list can contain either modules, or names of modules in the model. + If you pass multiple modules in the same list, they will all be equalized together. + threshold (float): a number used by the converged function to determine what degree + of similarity between models is necessary for them to be called equivalent + inplace (bool): determines if function is inplace or not + """ + + paired_modules_list = process_paired_modules_list_to_name( + model, paired_modules_list + ) + + if not inplace: + model = copy.deepcopy(model) + + paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list) + + name_to_module: dict[str, torch.nn.Module] = {} + previous_name_to_module: dict[str, Any] = {} + name_set = set(chain.from_iterable(paired_modules_list)) + + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + previous_name_to_module[name] = None + while not converged(name_to_module, previous_name_to_module, threshold): + for pair in paired_modules_list: + previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]]) + previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]]) + + cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]]) + + return model + + +def converged(curr_modules, prev_modules, threshold=1e-4): + """Test whether modules are converged to a specified threshold. 
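+ + (Illustrative restatement of the check implemented below: convergence is declared once + ``sum(torch.norm(W_curr[name] - W_prev[name]) for name in names) < threshold``, where ``W_*`` + are the weights of the tracked modules.)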
+ + Tests for the summed norm of the differences between each set of modules + being less than the given threshold + + Takes two dictionaries mapping names to modules, the set of names for each dictionary + should be the same, looping over the set of names, for each name take the difference + between the associated modules in each dictionary + + """ + if curr_modules.keys() != prev_modules.keys(): + raise ValueError( + "The keys to the given mappings must have the same set of names of modules" + ) + + summed_norms = torch.tensor(0.0) + if None in prev_modules.values(): + return False + for name in curr_modules.keys(): + curr_weight = get_module_weight(curr_modules[name]) + prev_weight = get_module_weight(prev_modules[name]) + + difference = curr_weight.sub(prev_weight) + summed_norms += torch.norm(difference) + return bool(summed_norms < threshold) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/_learnable_fake_quantize.py b/phivenv/Lib/site-packages/torch/ao/quantization/_learnable_fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..91591ddbe9228b3eee9e83857418517a75a19e94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/_learnable_fake_quantize.py @@ -0,0 +1,201 @@ +# mypy: allow-untyped-defs + +import torch +from torch.nn.parameter import Parameter + + +__all__: list[str] = [] + + +class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): + r"""Generalized extension of the FakeQuantize module in fake_quantize.py. + + This is an extension of the FakeQuantize module in fake_quantize.py, which + supports more generalized lower-bit quantization and supports learning of the scale + and zero point parameters through backpropagation. + + In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + + * :attr:`channel_len` defines the length of the channel when initializing scale and zero point + for the per channel case. + + * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are + normalized by the constant, which is proportional to the square root of the number of + elements in the tensor. The related literature justifying the use of this particular constant + can be found here: https://openreview.net/pdf?id=rkgO66VKDS. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output. + + * :attr:`static_enabled` defines the flag for using observer's static estimation for + scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. + """ + + def __init__( + self, + observer, + quant_min=0, + quant_max=255, + scale=1.0, + zero_point=0.0, + channel_len=-1, + use_grad_scaling=False, + **observer_kwargs, + ): + super().__init__() + assert quant_min < quant_max, "quant_min must be strictly less than quant_max." + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + self.use_grad_scaling = use_grad_scaling + if channel_len == -1: + self.scale = Parameter(torch.tensor([scale])) + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + assert isinstance(channel_len, int) and channel_len > 0, ( + "Channel size must be a positive integer." 
+ ) + self.scale = Parameter(torch.tensor([scale] * channel_len)) + self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) + + self.activation_post_process = observer(**observer_kwargs) + assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, ( + "quant_min out of bound" + ) + assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, ( + "quant_max out of bound" + ) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + r"""Enable parameter learning over static observer estimates. + + Enables learning of quantization parameters and + disables static observer estimates. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True).toggle_fake_quant( + enabled=True + ).toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enable static estimates of quantization parameters. + + Enables static observer estimates and disables learning of + quantization parameters. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False).toggle_fake_quant( + enabled=True + ).toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_static_observation(self): + """Enable accumulation of data without updating quantization parameters. + + Enables static observer accumulating data from input but doesn't + update the quantization parameters. Forward path returns the original X. 
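+
+        A rough usage sequence for the three modes (the observer class is only
+        illustrative, not prescribed by this module)::
+
+            from torch.ao.quantization.observer import MovingAverageMinMaxObserver
+
+            fq = _LearnableFakeQuantize(MovingAverageMinMaxObserver)
+            fq.enable_static_estimate()   # observer drives scale / zero_point
+            # ... run calibration batches through fq ...
+            fq.enable_param_learning()    # scale / zero_point trained via backprop
+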
+ """ + self.toggle_qparam_learning(enabled=False).toggle_fake_quant( + enabled=False + ).toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + self.static_enabled[0] = int(enabled) # type: ignore[operator] + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + self.learning_enabled[0] = int(enabled) # type: ignore[operator] + self.scale.requires_grad = enabled + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}") + print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}") + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + scale = self.scale.detach() + zero_point = ( + self.zero_point.detach() + .round() + .clamp(self.quant_min, self.quant_max) + .long() + ) + return scale, zero_point + + def forward(self, X): + if self.static_enabled[0] == 1: # type: ignore[index] + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + + if self.fake_quant_enabled[0] == 1: + if self.qscheme in ( + torch.per_channel_symmetric, + torch.per_tensor_symmetric, + ): + self.zero_point.data.zero_() + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.quant_min, + self.quant_max, + grad_factor, + ) + else: + X = torch._fake_quantize_learnable_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.quant_min, + self.quant_max, + grad_factor, + ) + + return X diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__init__.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b62f0a9ef3d581a6dee7d4a065df598711d67d2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__init__.py @@ -0,0 +1,30 @@ +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .executorch import get_executorch_backend_config +from .fbgemm import get_fbgemm_backend_config +from .native import get_native_backend_config, get_native_backend_config_dict +from .onednn import get_onednn_backend_config +from .qnnpack import get_qnnpack_backend_config +from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict + + +__all__ = [ + "get_fbgemm_backend_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_qnnpack_backend_config", + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", + "get_executorch_backend_config", + "BackendConfig", + 
"BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", + "get_onednn_backend_config", +] diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..186921f7d63e017477d3e5d70ce4f5497fc84745 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..076f1507b75977f17bee5cfde64292c43ed55e6f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94db15ccf1fb75039501b3c93d4ff130f31e6030 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aefc7e5590bde0e88591734712d50351f3935013 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d7f67b547a5c0ab0b121e3a65868e6605a5d8c0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af6fbb6753a50985b05b791f1d077e3d278391c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63597dba507a9a799684b224d4e909e43c99a446 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..457c665abf533ee7865b0580537e7db3ff038b72 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21e009f29fbe55ff37966f0107d598b025ae6efc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0e210acc4c9ef2d2e4d8729b7673ca0448b49b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b683838f78cb04a06634d0206588331f66f7d2db Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e94b1e0c843127587da2fab0f5389f025fc0658 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..539334b9be046864b2518a6d5728c4671c3b4512 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b548ab9788aee5d8a54d82906ef21508d95f07b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -0,0 +1,782 @@ +# mypy: allow-untyped-defs +import copy +import operator +from collections import namedtuple +from typing import Callable, Union + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, 
+ fuse_convtranspose_bn, + fuse_linear_bn, +) + +from .backend_config import ( + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) + + +__all__: list[str] = [] + +# TODO: rename to be more explicit, e.g. qat_conv_relu +_ConvMetadata = namedtuple( + "_ConvMetadata", + [ + "root", + "transpose", + "bn", + "reference", + "transpose_reference", + "fused_conv_relu", + "fused_conv_bn", + "fused_conv_bn_relu", + "qat", + "relu_qat", + "bn_qat", + "bn_relu_qat", + "func", + "func_transpose", + ], +) +_Conv1dMetadata = _ConvMetadata( + nn.Conv1d, + nn.ConvTranspose1d, + nn.BatchNorm1d, + nnqr.Conv1d, + nnqr.ConvTranspose1d, + nni.ConvReLU1d, + nni.ConvBn1d, + nni.ConvBnReLU1d, + nnqat.Conv1d, + nniqat.ConvReLU1d, + nniqat.ConvBn1d, + nniqat.ConvBnReLU1d, + F.conv1d, + F.conv_transpose1d, +) +_Conv2dMetadata = _ConvMetadata( + nn.Conv2d, + nn.ConvTranspose2d, + nn.BatchNorm2d, + nnqr.Conv2d, + nnqr.ConvTranspose2d, + nni.ConvReLU2d, + nni.ConvBn2d, + nni.ConvBnReLU2d, + nnqat.Conv2d, + nniqat.ConvReLU2d, + nniqat.ConvBn2d, + nniqat.ConvBnReLU2d, + F.conv2d, + F.conv_transpose2d, +) +_Conv3dMetadata = _ConvMetadata( + nn.Conv3d, + nn.ConvTranspose3d, + nn.BatchNorm3d, + nnqr.Conv3d, + nnqr.ConvTranspose3d, + nni.ConvReLU3d, + nni.ConvBn3d, + nni.ConvBnReLU3d, + nnqat.Conv3d, + nniqat.ConvReLU3d, + nniqat.ConvBn3d, + nniqat.ConvBnReLU3d, + F.conv3d, + F.conv_transpose3d, +) + +# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values +# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh +_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=1.0 / 256.0, + zero_point_exact_match=0, +) +_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=2.0 / 256.0, + zero_point_exact_match=128, +) +_FIXED_QPARAMS_OP_TO_CONSTRAINTS: dict[Union[Callable, str], DTypeWithConstraints] = { + torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, +} + + +def _get_binary_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + binary_op_configs: list[BackendPatternConfig] = [] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + operator.add, + torch.add, + operator.mul, + torch.mul, + ]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, nn.ReLU), + 
(op_with_quantized_bop_scalar_variant, F.relu), + (op_with_quantized_bop_scalar_variant, torch.relu), + op_with_quantized_bop_scalar_variant, + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + # matmul + binary_op_configs.append( + BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs) # noqa: E131 + ) + return binary_op_configs + + +def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + linear_configs: list[BackendPatternConfig] = [] + + # (1) Single linear modules/functions + # ------------------------------------- + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Linear + relu + # ------------------- + # 2.1 linear module + relu fusion config + # linear relu, linear module + relu module + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU) + ) + # linear relu, linear module + functional relu + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU) + ) + + # 2.2 linear module + relu, fused module configs + # linear relu, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearReLU) + ) + # linear relu, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # 2.3 functional linear + relu configs + # linear relu, functional linear + relu module + linear_configs.append( + BackendPatternConfig((F.linear, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # linear relu, functional linear + functional relu + linear_configs.append( + BackendPatternConfig((F.linear, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # (3) Linear + batchnorm + # ------------------------ + # 
3.1 linear bn fusion + linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_linear_bn) + .set_fused_module(nni.LinearBn1d) + ) + + # 3.2 linear bn fused + # linear bn, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearBn1d) + ) + # linear bn, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + return linear_configs + + +def _get_conv_configs(dtype_configs): + """ + Return all configs related to conv modules and ops. + """ + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------- + # 2.1 conv module + relu fusion configs + # conv relu fusion, conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv relu fusion, conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # 2.2 conv module + relu fused module configs + # conv relu, fused module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # 2.3 functional conv + relu configs + # conv relu, functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # conv relu, functional conv + functional relu + 
conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # 3.1 conv bn fusion configs + # conv + bn fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (4) conv transpose and its fusion + # 4.1 conv transpose config + conv_configs.append( + BackendPatternConfig(convs.transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference) + ) + + # 4.2 conv transpose + bn fusion + conv_configs.append( + BackendPatternConfig((convs.transpose, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_convtranspose_bn) + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference) + ) + + # 4.3 functional conv transpose + conv_configs.append( + BackendPatternConfig(convs.func_transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + return conv_configs + + +def _get_cat_config(dtype_configs: list[DTypeConfig]) -> BackendPatternConfig: + return ( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + + +def _get_ln_configs(dtype_configs: list[DTypeConfig]) -> 
list[BackendPatternConfig]: + ln_configs = [] + ln_configs.append( + BackendPatternConfig(torch.nn.LayerNorm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + ln_configs.append( + BackendPatternConfig(torch.nn.functional.layer_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + return ln_configs + + +def _get_default_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + default_ops = [ + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + configs = [ + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + for op in default_ops + ] + + configs.append( + BackendPatternConfig(torch.nn.functional.group_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + + configs.append( + BackendPatternConfig(torch.nn.functional.instance_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 3, "bias": 4}) + ) + return configs + + +def _add_fixed_qparams_to_dtype_configs( + dtype_configs: list[DTypeConfig], + constraints: DTypeWithConstraints, +) -> list[DTypeConfig]: + """ + Return a copy of the list of DTypeConfigs where activations are subject to the specified + constraints required for fixed qparams ops. + + If the data type doesn't match the one in the constraints, simply leave the corresponding + DTypeConfig unchanged. + + If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations, + throw an exception since these settings are incompatible with fixed qparams ops. 
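+
+    For example, applying the 0-to-1 constraints used for sigmoid/hardsigmoid to a
+    quint8 activation DTypeConfig pins ``quant_min_lower_bound=0``,
+    ``quant_max_upper_bound=255``, ``scale_exact_match=1/256`` and
+    ``zero_point_exact_match=0`` on the input and output activations, while a
+    qint8 weight entry is left unchanged because its dtype does not match.
+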
+ """ + new_dtype_configs = [] + for dtype_config in dtype_configs: + dc = copy.deepcopy(dtype_config) + for orig_constraints in [ + dc.input_dtype_with_constraints, + dc.output_dtype_with_constraints, + ]: + if orig_constraints.dtype != constraints.dtype: + continue + if orig_constraints.scale_min_lower_bound is not None: + raise ValueError( + f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}" + ) + if orig_constraints.scale_max_upper_bound is not None: + raise ValueError( + f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}" + ) + orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound + orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound + orig_constraints.scale_exact_match = constraints.scale_exact_match + orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match + new_dtype_configs.append(dc) + return new_dtype_configs + + +def _get_fixed_qparams_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + fixed_qparams_op_configs = [] + for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items(): + new_dtype_configs = _add_fixed_qparams_to_dtype_configs( + dtype_configs, constraints + ) + fixed_qparams_op_configs.append( + BackendPatternConfig(fixed_qparam_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(new_dtype_configs) + ) + return fixed_qparams_op_configs + + +def _get_share_qparams_op_configs(dtype_configs): + """Get the operator config for the operators that works for both float and quantized input + if input is quantized, the output Tensor shares the same quantization parameter + with input. + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + + def _get_share_qprams_op_backend_config(op): + return ( + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + + share_qparams_ops = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.Identity, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + torch.narrow, + torch.repeat_interleave, + torch.transpose, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.floordiv, + "contiguous", + "clamp", + "detach", + "detach_", + "mean", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "relu", + "relu_", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + return [_get_share_qprams_op_backend_config(op) for op in share_qparams_ops] + + +def 
_get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + """Get configs related to batchnorm.""" + bn_configs = [] + bn_to_fused_bn = { + torch.nn.BatchNorm2d: nni.BNReLU2d, + torch.nn.BatchNorm3d: nni.BNReLU3d, + } + for bn in bn_to_fused_bn.keys(): + fused_bn = bn_to_fused_bn[bn] + # bn module + relu module fusion config + bn_configs.append( + BackendPatternConfig((bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn) + ) + # bn module + F.relu fusion config + bn_configs.append( + BackendPatternConfig((bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn) + ) + bn_configs.append( + BackendPatternConfig(bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # fused bn configs + for fused_bn in bn_to_fused_bn.values(): + bn_configs.append( + BackendPatternConfig(fused_bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_rnn_op_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + rnn_op_configs = [] + for rnn_op, ref_rnn_op in [ + (nn.GRUCell, nnqr.GRUCell), + (nn.LSTMCell, nnqr.LSTMCell), + (nn.RNNCell, nnqr.RNNCell), + (nn.LSTM, nnqr.LSTM), + (nn.GRU, nnqr.GRU), + ]: + rnn_op_configs.append( + BackendPatternConfig(rnn_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(rnn_op) + .set_reference_quantized_module(ref_rnn_op) + ) + return rnn_op_configs + + +def _get_embedding_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + return embedding_op_configs + + +def _get_tensor_info_op_configs(dtype_configs): + """ + These ops work on tensors of different dtypes but return non-tensors + containing information about the input tensor. 
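+
+    For example, ``x.shape`` and ``x.size()`` return a ``torch.Size`` rather than a
+    tensor, so these patterns use ``ObservationType.INPUT_OUTPUT_NOT_OBSERVED`` and
+    no observers are inserted around them.
+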
+ """ + + def _get_config(op): + return ( + BackendPatternConfig(op) + .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED) + .set_dtype_configs(dtype_configs) + ) + + return [_get_config(op) for op in ("shape", "size")] diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..39d54886023bcb7379f43515778635179cd68331 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + + +def get_linear_configs(): + linear_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + # TODO: need to fix the way we insert observers for this pattern + # should be solved in the new fusion API + # reason that this doesn't work: the pattern is a bit complicated and we don't + # have a way to specify which input of the pattern we would like to observe + # pattern: + # bias input weight + # \ | / + # \ | t + # \ | / + # addmm + # we want to observe "weight" as weight, but there is not way to convey this + # information with current pattern language + # + # right now: + # original: + # weight - t \ + # input - addmm + # observed (no hack): + # weight - t - observer \ + # input - observer - addmm + # target: + # weight - observer - t \ + # input - observer - addmm + + # def root_node_getter(node_pattern): + # addmm, bias, act, weight = node_pattern + # return addmm + + # linear_configs.append( + # BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default)) + # .set_observation_type(observation_type) # noqa: E131 + # .set_dtype_configs(dtype_configs) + # ._set_root_node_getter(root_node_getter)) + + linear_configs.append( + BackendPatternConfig(torch.ops.aten.addmm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 0}) + ) + # linear is decomposed to `t - mm` if bias is not present + linear_configs.append( + BackendPatternConfig(torch.ops.aten.mm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return linear_configs + + +def get_conv_configs(): + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + conv_configs.append( + BackendPatternConfig(torch.ops.aten.convolution.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + conv_configs.append( + BackendPatternConfig( + (torch.ops.aten.convolution.default, torch.ops.aten.relu.default) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + # TODO: remove when functionalization is supported in PT2 mode + conv_configs.append( + 
BackendPatternConfig( + (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return conv_configs + + +def get_pooling_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + def root_node_getter(node_pattern): + _getitem, maxpool, _index = node_pattern + return maxpool + + backend_pattern_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_root_node_getter(root_node_getter) + ) + + return backend_pattern_configs + + +def get_relu_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + backend_pattern_configs.append( + BackendPatternConfig(torch.ops.aten.relu.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return backend_pattern_configs + + +def get_binary_op_configs(): + binary_op_configs: list[BackendPatternConfig] = [] + dtype_configs = [weighted_op_quint8_dtype_config] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default), + op_with_quantized_bop_scalar_variant, + # TODO: remove when functionalization is supported in pt2_mode + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default), + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + + return binary_op_configs + + +def get_qnnpack_pt2e_backend_config(): + return ( + BackendConfig("qnnpack_pytorch_2.0_export") + .set_backend_pattern_configs(get_linear_configs()) + .set_backend_pattern_configs(get_binary_op_configs()) + .set_backend_pattern_configs(get_conv_configs()) + .set_backend_pattern_configs(get_pooling_configs()) + .set_backend_pattern_configs(get_relu_configs()) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/backend_config.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/backend_config.py new file mode 100644 index 0000000000000000000000000000000000000000..71cabd08c2a96882695ffa6d12e4d7a4ff7d6201 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/backend_config.py @@ -0,0 +1,749 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch + + +if TYPE_CHECKING: + from torch.ao.quantization.utils import Pattern + 
+ +__all__ = [ + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", +] + + +# DTypeConfig dict keys +INPUT_DTYPE_DICT_KEY = "input_dtype" +OUTPUT_DTYPE_DICT_KEY = "output_dtype" +WEIGHT_DTYPE_DICT_KEY = "weight_dtype" +BIAS_DTYPE_DICT_KEY = "bias_dtype" +IS_DYNAMIC_DICT_KEY = "is_dynamic" + +# BackendConfig dict keys +NAME_DICT_KEY = "name" +CONFIGS_DICT_KEY = "configs" + +# BackendPatternConfig dict keys +PATTERN_DICT_KEY = "pattern" +PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format" +OBSERVATION_TYPE_DICT_KEY = "observation_type" +DTYPE_CONFIGS_DICT_KEY = "dtype_configs" +ROOT_MODULE_DICT_KEY = "root_module" +QAT_MODULE_DICT_KEY = "qat_module" +REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root" +FUSED_MODULE_DICT_KEY = "fused_module" +FUSER_METHOD_DICT_KEY = "fuser_method" +ROOT_NODE_GETTER_DICT_KEY = "root_node_getter" +EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter" +NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" +INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" + + +# TODO: maybe rename this to something that's not related to observer +# e.g. QParamsType +class ObservationType(Enum): + """An enum that represents different ways of how an operator/operator pattern + should be observed + """ + + OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0 + """this means input and output are observed with different observers, based + on qconfig.activation + example: conv, linear, softmax + """ + + OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1 + """this means the output will use the same observer instance as input, based + on qconfig.activation + example: torch.cat, maxpool + """ + + INPUT_OUTPUT_NOT_OBSERVED = 2 + """this means the input and output are never observed + example: x.shape, x.size + """ + + +@dataclass +class DTypeWithConstraints: + """ + Config for specifying additional constraints for a given dtype, such as quantization + value ranges, scale value ranges, and fixed quantization params, to be used in + :class:`~torch.ao.quantization.backend_config.DTypeConfig`. + + The constraints currently supported are: + + * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper + bounds for the minimum and maximum quantized values respectively. If + the QConfig's `quant_min` and `quant_max` fall outside this range, + then the QConfig will be ignored. + + * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper + bounds for the minimum and maximum scale values respectively. If the + QConfig's minimum scale value (currently exposed as `eps`) falls below + the lower bound, then the QConfig will be ignored. Note that the upper + bound is currently not enforced. + + * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements + for scale and zero point, to be used for operators with fixed quantization + parameters such as sigmoid and tanh. If the observer specified in the QConfig + is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if + the quantization parameters don't match, then the QConfig will be ignored. 
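+
+    A short construction sketch, mirroring the 0-to-1 fixed-qparams constraints
+    used for sigmoid-like ops elsewhere in this package::
+
+        act_quint8 = DTypeWithConstraints(
+            dtype=torch.quint8,
+            quant_min_lower_bound=0,
+            quant_max_upper_bound=255,
+            scale_exact_match=1.0 / 256.0,
+            zero_point_exact_match=0,
+        )
+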
+ """ + + dtype: Optional[torch.dtype] = None + quant_min_lower_bound: Union[int, float, None] = None + quant_max_upper_bound: Union[int, float, None] = None + scale_min_lower_bound: Union[int, float, None] = None + scale_max_upper_bound: Union[int, float, None] = None + scale_exact_match: Optional[float] = None + zero_point_exact_match: Optional[int] = None + + +@dataclass +class DTypeConfig: + """ + Config object that specifies the supported data types passed as arguments to + quantize ops in the reference model spec, for input and output activations, + weights, and biases. + + For example, consider the following reference model: + + quant1 - [dequant1 - fp32_linear - quant2] - dequant2 + + The pattern in the square brackets refers to the reference pattern of + statically quantized linear. Setting the input dtype as `torch.quint8` + in the DTypeConfig means we pass in `torch.quint8` as the dtype argument + to the first quantize op (quant1). Similarly, setting the output dtype as + `torch.quint8` means we pass in `torch.quint8` as the dtype argument to + the second quantize op (quant2). + + Note that the dtype here does not refer to the interface dtypes of the + op. For example, the "input dtype" here is not the dtype of the input + tensor passed to the quantized linear op. Though it can still be the + same as the interface dtype, this is not always the case, e.g. the + interface dtype is fp32 in dynamic quantization but the "input dtype" + specified in the DTypeConfig would still be quint8. The semantics of + dtypes here are the same as the semantics of the dtypes specified in + the observers. + + These dtypes are matched against the ones specified in the user's + QConfig. If there is a match, and the QConfig satisfies the constraints + specified in the DTypeConfig (if any), then we will quantize the given + pattern using this DTypeConfig. Otherwise, the QConfig is ignored and + the pattern will not be quantized. + + Example usage:: + + >>> # xdoctest: +SKIP(failing) + >>> dtype_config1 = DTypeConfig( + ... input_dtype=torch.quint8, + ... output_dtype=torch.quint8, + ... weight_dtype=torch.qint8, + ... bias_dtype=torch.float) + + >>> dtype_config2 = DTypeConfig( + ... input_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... output_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... weight_dtype=DTypeWithConstraints( + ... dtype=torch.qint8, + ... quant_min_lower_bound=-128, + ... quant_max_upper_bound=127, + ... ), + ... 
bias_dtype=torch.float) + + >>> dtype_config1.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype_with_constraints + DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \ +scale_min_lower_bound=None, scale_max_upper_bound=None) + """ + + input_dtype_with_constraints: DTypeWithConstraints + output_dtype_with_constraints: DTypeWithConstraints + weight_dtype_with_constraints: DTypeWithConstraints + bias_dtype: Optional[torch.dtype] + is_dynamic: Optional[bool] + + def __init__( + self, + input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, + output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, + weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, + bias_dtype: Optional[torch.dtype] = None, + is_dynamic: Optional[bool] = None, + ): + if isinstance(input_dtype, DTypeWithConstraints): + self.input_dtype_with_constraints = input_dtype + else: + self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype) + + if isinstance(output_dtype, DTypeWithConstraints): + self.output_dtype_with_constraints = output_dtype + else: + self.output_dtype_with_constraints = DTypeWithConstraints( + dtype=output_dtype + ) + + if isinstance(weight_dtype, DTypeWithConstraints): + self.weight_dtype_with_constraints = weight_dtype + else: + self.weight_dtype_with_constraints = DTypeWithConstraints( + dtype=weight_dtype + ) + + self.bias_dtype = bias_dtype + self.is_dynamic = is_dynamic + + @property + def input_dtype(self) -> Optional[torch.dtype]: + return self.input_dtype_with_constraints.dtype + + @property + def output_dtype(self) -> Optional[torch.dtype]: + return self.output_dtype_with_constraints.dtype + + @property + def weight_dtype(self) -> Optional[torch.dtype]: + return self.weight_dtype_with_constraints.dtype + + @classmethod + def from_dict(cls, dtype_config_dict: dict[str, Any]) -> DTypeConfig: + """ + Create a ``DTypeConfig`` from a dictionary with the following items (all optional): + "input_dtype": torch.dtype or ``DTypeWithConstraints`` + "output_dtype": torch.dtype or ``DTypeWithConstraints`` + "weight_dtype": torch.dtype or ``DTypeWithConstraints`` + "bias_type": torch.dtype + "is_dynamic": bool + """ + input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None) + if input_dtype is not None and not isinstance( + input_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected input_dtype to be a torch.dtype or DTypeWithConstraints" + ) + output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None) + if output_dtype is not None and not isinstance( + output_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected output_dtype to be a torch.dtype or DTypeWithConstraints" + ) + weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None) + if weight_dtype is not None and not isinstance( + weight_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints" + ) + bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None) + is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None) + return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic) + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``DTypeConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. 
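+
+        For example, ``DTypeConfig(input_dtype=torch.quint8, weight_dtype=torch.qint8,
+        bias_dtype=torch.float).to_dict()`` yields ``DTypeWithConstraints`` entries
+        under "input_dtype" and "weight_dtype" plus ``torch.float`` under
+        "bias_dtype"; fields left unset are omitted from the dictionary.
+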
+ """ + dtype_config_dict: dict[str, Any] = {} + if self.input_dtype is not None: + dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints + if self.output_dtype is not None: + dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = ( + self.output_dtype_with_constraints + ) + if self.weight_dtype is not None: + dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = ( + self.weight_dtype_with_constraints + ) + if self.bias_dtype is not None: + dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype + if self.is_dynamic is not None: + dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic + return dtype_config_dict + + +class BackendConfig: + # TODO: refer to NativeBackendConfig once that is implemented + """Config that defines the set of patterns that can be quantized on a given backend, and how reference + quantized models can be produced from these patterns. + + A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph + of the above. Each pattern supported on the target backend can be individually configured through + :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: + + (1) The supported input/output activation, weight, and bias data types + + (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and + + (3) (Optionally) Fusion, QAT, and reference module mappings. + + The format of the patterns is described in: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md + + Example usage:: + + import torch + from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, + ) + + weighted_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float) + + def fuse_conv2d_relu(is_qat, conv, relu): + return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) + + # For quantizing Linear + linear_config = BackendPatternConfig(torch.nn.Linear) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(torch.ao.nn.qat.Linear) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) + + # For fusing Conv2d + ReLU into ConvReLU2d + conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(fuse_conv2d_relu) + + # For quantizing ConvReLU2d + fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Conv2d) \ + .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) + + backend_config = BackendConfig("my_backend") \ + .set_backend_pattern_config(linear_config) \ + .set_backend_pattern_config(conv_relu_config) \ + .set_backend_pattern_config(fused_conv_relu_config) + + """ + + def __init__(self, name: str = ""): + self.name = name + # Store all BackendPatternConfigs in a map to handle duplicates + # Note: the key in this map uses the complex reversed tuple format. 
+ # This is intended only for internal use; users who wish to access + # the original patterns should go through `self.configs` instead. + self._pattern_complex_format_to_config: dict[Pattern, BackendPatternConfig] = {} + + def __repr__(self): + return f"BackendConfig({self.__dict__})" + + def set_name(self, name: str) -> BackendConfig: + """ + Set the name of the target backend. + """ + self.name = name + return self + + def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig: + """ + Set the config for an pattern that can be run on the target backend. + This overrides any existing config for the given pattern. + """ + # Avoid circular dependencies + pattern_complex_format = torch.ao.quantization.backend_config.utils._get_pattern_in_reversed_nested_tuple_format( + config + ) # type: ignore[attr-defined] + self._pattern_complex_format_to_config[pattern_complex_format] = config + return self + + def set_backend_pattern_configs( + self, configs: list[BackendPatternConfig] + ) -> BackendConfig: + """ + Set the configs for patterns that can be run on the target backend. + This overrides any existing config for a given pattern if it was previously registered already. + """ + for conf in configs: + self.set_backend_pattern_config(conf) + return self + + @property + def configs(self) -> list[BackendPatternConfig]: + """ + Return a copy of the list of configs set in this `BackendConfig`. + """ + return list(self._pattern_complex_format_to_config.values()) + + @classmethod + def from_dict(cls, backend_config_dict: dict[str, Any]) -> BackendConfig: + """ + Create a ``BackendConfig`` from a dictionary with the following items: + + "name": the name of the target backend + + "configs": a list of dictionaries that each represents a `BackendPatternConfig` + + """ + conf = cls(backend_config_dict.get(NAME_DICT_KEY, "")) + for d in backend_config_dict.get(CONFIGS_DICT_KEY, []): + if isinstance(d, BackendPatternConfig): + conf.set_backend_pattern_config(d) + elif isinstance(d, dict): + conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d)) + else: + raise ValueError( + f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary" + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``BackendConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. + """ + return { + NAME_DICT_KEY: self.name, + CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs], + } + + +class BackendPatternConfig: + """ + Config object that specifies quantization behavior for a given operator pattern. + For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. 
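+
+    Unless configured otherwise, a new instance defaults to
+    ``ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT`` with an empty list of
+    dtype configs, as set up in ``__init__`` below.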
+ """ + + def __init__(self, pattern: Optional[Pattern] = None): + self.pattern: Optional[Pattern] = pattern + self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + self.dtype_configs: list[DTypeConfig] = [] + self.root_module: Optional[type[torch.nn.Module]] = None + self.qat_module: Optional[type[torch.nn.Module]] = None + self.reference_quantized_module: Optional[type[torch.nn.Module]] = None + self.fused_module: Optional[type[torch.nn.Module]] = None + self.fuser_method: Optional[Callable] = None + + # Temporary/internal configs + self._root_node_getter: Optional[Callable] = None + self._extra_inputs_getter: Optional[Callable] = None + self._num_tensor_args_to_observation_type: dict[int, ObservationType] = {} + self._input_type_to_index: dict[str, int] = {} + self._pattern_complex_format: Optional[Pattern] = None + + def __repr__(self): + dict_nonempty = { + k: v + for k, v in self.__dict__.items() + if ( + (not isinstance(v, (list, dict)) and v is not None) + or (isinstance(v, (list, dict)) and len(v) > 0) + ) + } + return f"BackendPatternConfig({dict_nonempty})" + + def set_pattern(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure. + + The pattern can be a float module, functional operator, pytorch operator, or a tuple + combination of the above. Tuple patterns are treated as sequential patterns, and + currently only tuples of 2 or 3 elements are supported. + """ + if self._pattern_complex_format is not None: + raise ValueError( + "Only one of 'pattern' or 'pattern_complex_format' can be set" + ) + self.pattern = pattern + return self + + def set_observation_type( + self, observation_type: ObservationType + ) -> BackendPatternConfig: + """ + Set how observers should be inserted in the graph for this pattern. + + Observation type here refers to how observers (or quant-dequant ops) will be placed + in the graph. This is used to produce the desired reference patterns understood by + the backend. Weighted ops such as linear and conv require different observers + (or quantization parameters passed to quantize ops in the reference model) for the + input and the output. + + There are two observation types: + + `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance + will be different from the input. This is the most common observation type. + + `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the + same as the input. This is useful for operators like `cat`. + + Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs + with observers (and fake quantizes) attached instead of observers themselves. + """ + self.observation_type = observation_type + return self + + def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: + """ + Add a set of supported data types passed as arguments to quantize ops in the + reference model spec. + """ + self.dtype_configs.append(dtype_config) + return self + + def set_dtype_configs( + self, dtype_configs: list[DTypeConfig] + ) -> BackendPatternConfig: + """ + Set the supported data types passed as arguments to quantize ops in the + reference model spec, overriding all previously registered data types. + """ + self.dtype_configs = dtype_configs + return self + + def set_root_module( + self, root_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the root for this pattern. 
+ + When we construct the reference quantized model during the convert phase, + the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU) + will be swapped to the corresponding reference quantized modules (e.g. + torch.ao.nn.reference.quantized.Linear). This allows custom backends to + specify custom reference quantized module implementations to match the + numerics of their lowered operators. Since this is a one-to-one mapping, + both the root module and the reference quantized module must be specified + in the same BackendPatternConfig in order for the conversion to take place. + """ + self.root_module = root_module + return self + + def set_qat_module(self, qat_module: type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the QAT implementation for this pattern. + """ + self.qat_module = qat_module + return self + + def set_reference_quantized_module( + self, reference_quantized_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the reference quantized implementation for + this pattern's root module. + + For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`. + """ + self.reference_quantized_module = reference_quantized_module + return self + + def set_fused_module( + self, fused_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the fused implementation for this pattern. + """ + self.fused_module = fused_module + return self + + def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: + """ + Set the function that specifies how to fuse this BackendPatternConfig's pattern. + + The first argument of this function should be `is_qat`, and the rest of the arguments + should be the items in the tuple pattern. The return value of this function should be + the resulting fused module. + + For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be: + + def fuse_linear_relu(is_qat, linear, relu): + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) + + For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6. + """ + self.fuser_method = fuser_method + return self + + def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig: + self._root_node_getter = root_node_getter + return self + + def _set_extra_inputs_getter( + self, extra_inputs_getter: Callable + ) -> BackendPatternConfig: + self._extra_inputs_getter = extra_inputs_getter + return self + + def _set_num_tensor_args_to_observation_type( + self, num_tensor_args_to_observation_type: dict[int, ObservationType] + ) -> BackendPatternConfig: + self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type + return self + + def _set_input_type_to_index( + self, input_type_to_index: dict[str, int] + ) -> BackendPatternConfig: + self._input_type_to_index = input_type_to_index + return self + + def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure, using the reversed nested tuple format. 
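+
+        For example (illustrative), the sequential pattern
+        ``(torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)`` is written as
+        ``(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))`` in this format.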
+ + See the BackendConfig README for more detail: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification + """ + if self.pattern is not None: + raise ValueError( + "Only one of 'pattern' or 'pattern_complex_format' can be set" + ) + self._pattern_complex_format = pattern + return self + + @classmethod + def from_dict( + cls, backend_pattern_config_dict: dict[str, Any] + ) -> BackendPatternConfig: + """ + Create a ``BackendPatternConfig`` from a dictionary with the following items: + + "pattern": the pattern being configured + "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how + observers should be inserted for this pattern + "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s + "root_module": a :class:`torch.nn.Module` that represents the root for this pattern + "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern + "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized + implementation for this pattern's root module. + "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern + "fuser_method": a function that specifies how to fuse the pattern for this pattern + "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated) + + """ + + def _get_dtype_config(obj: Any) -> DTypeConfig: + """ + Convert the given object into a ``DTypeConfig`` if possible, else throw an exception. + """ + if isinstance(obj, DTypeConfig): + return obj + if isinstance(obj, dict): + return DTypeConfig.from_dict(obj) + raise ValueError( + f"Expected a list of DTypeConfigs in " + f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'" + ) + + conf = cls() + if PATTERN_DICT_KEY in backend_pattern_config_dict: + conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY]) + if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict: + conf.set_observation_type( + backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY] + ) + for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): + conf.add_dtype_config(_get_dtype_config(d)) + conf.set_root_module( + backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) # type: ignore[arg-type] + ) + conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) # type: ignore[arg-type] + conf.set_reference_quantized_module( + backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) # type: ignore[arg-type] + ) + conf.set_fused_module( + backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) # type: ignore[arg-type] + ) + conf.set_fuser_method( + backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) # type: ignore[arg-type] + ) + conf._set_root_node_getter( + backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) # type: ignore[arg-type] + ) + conf._set_extra_inputs_getter( + backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) # type: ignore[arg-type] + ) + conf._set_num_tensor_args_to_observation_type( + backend_pattern_config_dict.get( + NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {} + ) + ) + conf._set_input_type_to_index( + backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}) + ) + if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict: + conf._set_pattern_complex_format( + 
backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``BackendPatternConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. + """ + backend_pattern_config_dict: dict[str, Any] = { + OBSERVATION_TYPE_DICT_KEY: self.observation_type, + DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs], + } + if self.pattern is not None: + backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern + if self.root_module is not None: + backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module + if self.qat_module is not None: + backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module + if self.reference_quantized_module is not None: + backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = ( + self.reference_quantized_module + ) + if self.fused_module is not None: + backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module + if self.fuser_method is not None: + backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method + if self._root_node_getter is not None: + backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = ( + self._root_node_getter + ) + if self._extra_inputs_getter is not None: + backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = ( + self._extra_inputs_getter + ) + if len(self._num_tensor_args_to_observation_type) > 0: + backend_pattern_config_dict[ + NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY + ] = self._num_tensor_args_to_observation_type + if len(self._input_type_to_index) > 0: + backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = ( + self._input_type_to_index + ) + if self._pattern_complex_format is not None: + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = ( + self._pattern_complex_format + ) + return backend_pattern_config_dict diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/executorch.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/executorch.py new file mode 100644 index 0000000000000000000000000000000000000000..8ecf58b489ea2c783f27e2e06a67e8377d464ea7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/executorch.py @@ -0,0 +1,498 @@ +# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime +# not a specific backend + +import operator + +import torch +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, +) + +from ._common_operator_config_utils import _Conv2dMetadata +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .qnnpack import ( + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_qint8_symmetric_dtype_config, +) + + +__all__ = [ + "get_executorch_backend_config", +] + + +# =================== +# | DTYPE CONFIGS | +# =================== + +executorch_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +executorch_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +executorch_default_dynamic_quint8_dtype_config = DTypeConfig( + 
input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +executorch_default_dynamic_qint8_dtype_config = DTypeConfig( + input_dtype=executorch_act_qint8_scale_min_2_neg_12, + output_dtype=torch.float, + weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + + +# ============================= +# | BACKEND PATTERN CONFIGS | +# ============================= + + +def _get_linear_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + executorch_default_dynamic_quint8_dtype_config, + executorch_default_dynamic_qint8_dtype_config, + executorch_default_dynamic_float16_dtype_config, + ] + linear_configs: list[BackendPatternConfig] = [] + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return linear_configs + + +def _get_conv_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to conv modules and ops. 
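+
+    Note: only 2D convolutions are covered here; the loop below iterates over
+    ``_Conv2dMetadata`` alone.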
+ """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + conv_configs = [] + for convs in [_Conv2dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------------------------- + # conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # fused conv relu module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # conv + batchnorm (+ relu) + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + 
.set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + return conv_configs + + +def _get_binary_ops_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to binary ops. + """ + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + binary_op_configs: list[BackendPatternConfig] = [] + for op in [ + operator.add, + torch.add, + operator.sub, + torch.sub, + operator.mul, + torch.mul, + ]: + bop_patterns = [ + (op, torch.nn.ReLU), + (op, torch.nn.functional.relu), + (op, torch.relu), + op, + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + return binary_op_configs + + +def _get_share_qparams_ops_configs() -> list[BackendPatternConfig]: + """ + Return the operator configs for the operators that works for both float and quantized + input if input is quantized, the output Tensor shares the same quantization parameter + with input. 
+ + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + share_qparams_ops = [ + torch.nn.Flatten, + F.adaptive_avg_pool2d, + F.elu, + F.hardtanh, + F.max_pool2d, + F.pad, + F.relu, + F.relu6, + F.leaky_relu, + F.leaky_relu_, + torch.nn.AdaptiveAvgPool2d, + torch.nn.ConstantPad2d, + torch.nn.ELU, + torch.nn.MaxPool2d, + torch.nn.ReLU6, + torch.nn.Hardtanh, + torch.nn.LeakyReLU, + torch.clamp, + torch.flatten, + torch.mean, + torch.permute, + torch.permute_copy, + torch.squeeze, + "clamp", + "mean", + "permute", + "reshape", + "relu", + "relu_", + "squeeze", + "squeeze_", + "leaky_relu", + ] + share_qparams_op_configs: list[BackendPatternConfig] = [ + BackendPatternConfig(op) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + for op in share_qparams_ops + ] + return share_qparams_op_configs + + +def _get_bn_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to batchnorm. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + bn_configs = [] + bn_configs.append( + BackendPatternConfig(nn.BatchNorm2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_cat_configs() -> list[BackendPatternConfig]: + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + cat_configs = [] + cat_configs.append( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concatenate) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + return cat_configs + + +def _get_embedding_op_configs() -> list[BackendPatternConfig]: + dtype_configs = [ + executorch_weight_only_quint8_dtype_config, + ] + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for functional embedding + embedding_op_configs.append( + BackendPatternConfig(torch.nn.functional.embedding) + .set_observation_type( + 
ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return embedding_op_configs + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_executorch_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack. + """ + return ( + BackendConfig("executorch") + .set_backend_pattern_configs(_get_linear_configs()) + .set_backend_pattern_configs(_get_conv_configs()) + .set_backend_pattern_configs(_get_binary_ops_configs()) + .set_backend_pattern_configs(_get_share_qparams_ops_configs()) + .set_backend_pattern_configs(_get_bn_configs()) + .set_backend_pattern_configs(_get_cat_configs()) + .set_backend_pattern_configs(_get_embedding_op_configs()) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/fbgemm.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..625313ab945018a69d6c3e059262e00465969c99 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/fbgemm.py @@ -0,0 +1,129 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_fbgemm_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py +# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max, +# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized +# values to within [0, 127]. + +fbgemm_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +fbgemm_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +fbgemm_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +fbgemm_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_fbgemm_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native FBGEMM backend. 
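+
+    Usage sketch (``model``, ``qconfig_mapping`` and ``example_inputs`` are placeholders,
+    not defined here)::
+
+        from torch.ao.quantization.quantize_fx import prepare_fx
+
+        backend_config = get_fbgemm_backend_config()
+        prepared = prepare_fx(model, qconfig_mapping, example_inputs,
+                              backend_config=backend_config)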
+ """ + conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + fbgemm_weighted_op_quint8_dtype_config, + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + fbgemm_weight_only_quint8_dtype_config, + fbgemm_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("fbgemm") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/native.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/native.py new file mode 100644 index 0000000000000000000000000000000000000000..54650ef96db1bc23fb610287341174276941972d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/native.py @@ -0,0 +1,231 @@ +# mypy: allow-untyped-defs +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_test_only_legacy_native_backend_config", + "default_op_quint8_dtype_config", + "default_op_fp16_dtype_config", + "default_dynamic_int8_dtype_config", + "default_dynamic_float16_dtype_config", + "input_output_only_quint8_dtype_config", + "weight_only_quint8_dtype_config", + "weight_only_quint4x2_dtype_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_test_only_legacy_native_backend_config_dict", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# weighted op int8 dtype config +# this is config for ops that has quantized weights, like linear, conv +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + 
+default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights +input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_test_only_legacy_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops. + """ + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + default_op_fp16_dtype_config, + ] + binary_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + share_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + tensor_info_op_dtype_configs = [ + default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return ( + BackendConfig("_native_and_fp16") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + 
.set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +def get_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). + """ + # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [default_op_quint8_dtype_config] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return ( + BackendConfig("native") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +def get_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form. + """ + return get_native_backend_config().to_dict() + + +def get_test_only_legacy_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional + fp16 ops in dictionary form. 
+ """ + return get_test_only_legacy_native_backend_config().to_dict() diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/observation_type.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/observation_type.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/onednn.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/onednn.py new file mode 100644 index 0000000000000000000000000000000000000000..778203bbd0068fef8801f55dbc1e5dce31396b16 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/onednn.py @@ -0,0 +1,640 @@ +# mypy: allow-untyped-defs +import itertools +import operator + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2 +from torch.ao.quantization.utils import MatchAllNode + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +# =================== +# | DTYPE CONFIGS | +# =================== + +onednn_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +onednn_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +onednn_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +onednn_weight_only_qint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.qint8, +) + +onednn_input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +# =================== +# | FUSER METHODS | +# =================== + + +def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): + r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + leaky_relu: LeakyReLU instance that needs to be fused with the linear layer + Examples:: + >>> # xdoctest: +SKIP(failing) + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> lr = nn.LeakyReLU(0.01) + >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) + """ + assert linear.training == bn.training and bn.training == leaky_relu.training, ( + "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." 
+ ) + + if is_qat: + raise NotImplementedError( + f"Cannot fuse train modules: {(linear, bn, leaky_relu)}" + ) + else: + map_to_fused_module_eval = { + nn.Linear: nni.LinearLeakyReLU, + } + fused_module = map_to_fused_module_eval.get(type(linear), None) + if fused_module is not None: + fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + fm = fused_module(fused_linear, leaky_relu) + return fm + else: + raise NotImplementedError( + f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}" + ) + + +# ====================== +# | CONFIGS FOR CONV | +# ====================== +observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + +conv_dtype_configs = [onednn_weighted_op_int8_dtype_config] +conv_configs = _get_conv_configs(conv_dtype_configs) + +# (1) Conv2d + Add + +# conv2d Y +# \ / +# add + +# include: +# conv2d conv2d +# \ / +# add + + +def _fuse_conv_add_left(is_qat, add, conv, _): + return nni.ConvAdd2d(conv, add) + + +def _conv_add_root_node_getter_left(pattern): + _, conv, _ = pattern + return conv + + +def _conv_add_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, _conv, extra_input = pattern + return [extra_input] + + +# conv2d +# \ +# bn Y +# \ / +# add + + +def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + + +def _conv_bn_add_root_node_getter_left(add_pattern): + _, bn_conv, _ = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_extra_inputs_getter_left(add_pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, _bn_conv, extra_input = add_pattern + return [extra_input] + + +conv_add_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_left) + ._set_root_node_getter(_conv_bn_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_left) + ._set_root_node_getter(_conv_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d) + ) + +# Y conv2d +# \ / +# add + + +def _fuse_conv_add_right(is_qat, add, _, conv): + return nni.ConvAdd2d(conv, add) + + +def _conv_add_root_node_getter_right(pattern): + _add, _, conv = pattern + return conv + + +def _conv_add_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, _conv = pattern + return [extra_input] + + +# conv2d +# / +# Y bn +# \ / +# add + + +def 
_fuse_conv_bn_add_right(is_qat, add, _, bn_conv): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + + +def _conv_bn_add_root_node_getter_right(pattern): + _add, _, bn_conv = pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, _bn_conv = pattern + return [extra_input] + + +conv_add_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_right) + ._set_root_node_getter(_conv_bn_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_right) + ._set_root_node_getter(_conv_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d) + ) + +conv_configs.append( + BackendPatternConfig(nni.ConvAdd2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d) +) + +# (2) Conv2d + Add + Relu + +# conv2d Y +# \ / +# add +# \ +# relu + + +def _fuse_conv_add_relu_left(is_qat, relu, add_pattern): + add, conv, _ = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + + +def _conv_add_relu_root_node_getter_left(pattern): + _relu, add_pattern = pattern + _, conv, _ = add_pattern + return conv + + +def _conv_add_relu_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, _conv, extra_input = add_pattern + return [extra_input] + + +# conv2d +# \ +# bn Y +# \ / +# add +# \ +# relu + + +def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern): + add, bn_conv, _ = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + + +def _conv_bn_add_relu_root_node_getter_left(pattern): + _relu, add_pattern = pattern + _, bn_conv, _ = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_relu_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, _bn_conv, extra_input = add_pattern + return [extra_input] + + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if 
with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_left) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_left) + ._set_root_node_getter(_conv_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d) + ) + +# Y conv2d +# \ / +# add +# \ +# relu + + +def _fuse_conv_add_relu_right(is_qat, relu, add_pattern): + add, _, conv = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + + +def _conv_add_relu_root_node_getter_right(pattern): + _relu, add_pattern = pattern + _, _extra_input, conv = add_pattern + return conv + + +def _conv_add_relu_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, extra_input, _conv = add_pattern + return [extra_input] + + +# conv2d +# / +# Y bn +# \ / +# add +# \ +# relu + + +def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern): + add, _, bn_conv = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + + +def _conv_bn_add_relu_root_node_getter_right(pattern): + _relu, add_pattern = pattern + _, _, bn_conv = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_relu_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, extra_input, _bn_conv = add_pattern + return [extra_input] + + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_right) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_right) + ._set_root_node_getter(_conv_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d) + ) + +conv_configs.append( + 
BackendPatternConfig(nni.ConvAddReLU2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d) +) + +# ======================== +# | CONFIGS FOR LINEAR | +# ======================== + +linear_dtype_configs = [ + onednn_weighted_op_int8_dtype_config, + onednn_dynamic_int8_dtype_config, +] +linear_configs = _get_linear_configs(linear_dtype_configs) + + +def _add_eltwise_fusion_configs( + configs, + root_module, + root_op, + post_module, + post_op, + dtype_configs, + fuser_method, + fused_module, + observation_type, + ref_quant_module, +): + # 1 base module + op module fusion config + configs.append( + BackendPatternConfig((root_module, post_module)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module) + ) + # base module + functional post op + configs.append( + BackendPatternConfig((root_module, post_op)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module) + ) + + # 2 fused module configs + configs.append( + BackendPatternConfig(fused_module) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(root_module) + .set_reference_quantized_module(ref_quant_module) + ) + + # 3 functional base op + post op configs + configs.append( + BackendPatternConfig((root_op, post_module)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + configs.append( + BackendPatternConfig((root_op, post_op)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + +# Configs for linear + leaky_relu fusion +_add_eltwise_fusion_configs( + linear_configs, + nn.Linear, + F.linear, + nn.LeakyReLU, + F.leaky_relu, + linear_dtype_configs, + _sequential_wrapper2(nni.LinearLeakyReLU), + nni.LinearLeakyReLU, + observation_type, + nnqr.Linear, +) + +# Configs for linear module + batchnorm + leaky_relu +linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU)) + .set_dtype_configs(linear_dtype_configs) # noqa: E131 + .set_fuser_method(_fuse_linear_bn_leaky_relu) + .set_fused_module(nni.LinearLeakyReLU) +) + +# Configs for linear + tanh fusion +_add_eltwise_fusion_configs( + linear_configs, + nn.Linear, + F.linear, + nn.Tanh, + torch.tanh, + linear_dtype_configs, + _sequential_wrapper2(nni.LinearTanh), + nni.LinearTanh, + observation_type, + nnqr.Linear, +) + +# =========================== +# | CONFIGS FOR OTHER OPS | +# =========================== + +binary_op_dtype_configs = [onednn_op_quint8_dtype_config] +default_op_dtype_configs = [onednn_op_quint8_dtype_config] +fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config] +embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config] +layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config] + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_onednn_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native ONEDNN backend. 
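+
+    Note (added for clarity): besides the common operator configs, this backend also
+    registers the conv2d + add (+ relu) and linear + leaky_relu / tanh fusion patterns
+    defined earlier in this module.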
+ """ + return ( + BackendConfig("onednn") + .set_backend_pattern_configs(conv_configs) + .set_backend_pattern_configs(linear_configs) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +__all__ = [ + "get_onednn_backend_config", +] diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/qnnpack.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..04d9c19d5398935220b9bfc0c6d7fd31b9a1b99c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/qnnpack.py @@ -0,0 +1,171 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints + + +__all__ = [ + "get_qnnpack_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +qnnpack_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +qnnpack_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +qnnpack_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +qnnpack_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + +# xnnpack compatible dtype configs + +# We restrict scale values to be 2 ** -12 to ensure the +# requantization scale never falls below the xnnpack lower +# threshold. Additionally, for qint8 weight, we restrict +# the quantization values to [-127, +127], excluding -128. +# For more detail, refer to the description of +# `default_symmetric_qnnpack_qconfig`. 
+ +# TODO: add additional restriction on qscheme to ensure it +# is either per_tensor_symmetric or per_channel_symmetric + +qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, +) + +qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_qnnpack_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native QNNPACK backend. + """ + conv_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + ] + linear_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + default_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + fixed_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + qnnpack_weight_only_quint8_dtype_config, + qnnpack_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("qnnpack") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/tensorrt.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeb3390d9cb712a0cb19192ed212a5e17243ef2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/tensorrt.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import torch + +from ._common_operator_config_utils import ( + 
_get_binary_op_configs, + _get_conv_configs, + _get_linear_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +__all__ = [ + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", +] + + +def get_tensorrt_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for the TensorRT backend. + NOTE: Current api will change in the future, it's just to unblock experimentation for + new backends, please don't use it right now. + TODO: add a README when it's more stable + """ + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + ) + + addmm_config = ( + BackendPatternConfig(torch.addmm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) + .add_dtype_config(weighted_op_qint8_dtype_config) + ._set_input_type_to_index( + { + "bias": 0, + "input": 1, + "weight": 2, + } + ) + ) + cat_config = ( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .add_dtype_config(non_weighted_op_qint8_dtype_config) + ) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + tensor_info_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return ( + BackendConfig("tensorrt") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_config(addmm_config) + .set_backend_pattern_config(cat_config) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + ) + + +def get_tensorrt_backend_config_dict(): + """ + Return the `BackendConfig` for the TensorRT backend in dictionary form. 
+ """ + return get_tensorrt_backend_config().to_dict() diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d452ebdc962e11addf694685e133ac49e5c23d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/utils.py @@ -0,0 +1,314 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import _reverse2, _reverse3 +from torch.ao.quantization.utils import Pattern + +from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig + + +__all__ = [ + "get_pattern_to_dtype_configs", + "get_qat_module_classes", + "get_fused_module_classes", + "get_pattern_to_input_type_to_index", + "get_root_module_to_quantized_reference_module", + "get_fuser_method_mapping", + "get_module_to_qat_module", + "get_fusion_pattern_to_root_node_getter", + "get_fusion_pattern_to_extra_inputs_getter", + "remove_boolean_dispatch_from_name", + "pattern_to_human_readable", + "entry_to_pretty_str", +] + + +def get_pattern_to_dtype_configs( + backend_config: BackendConfig, +) -> dict[Pattern, list[DTypeConfig]]: + pattern_to_dtype_configs: dict[Pattern, list[DTypeConfig]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_dtype_configs[pattern] = config.dtype_configs + return pattern_to_dtype_configs + + +def get_qat_module_classes(backend_config: BackendConfig) -> tuple[type, ...]: + qat_module_classes = [ + config.qat_module + for config in backend_config.configs + if config.qat_module is not None + ] + return tuple(set(qat_module_classes)) + + +def get_fused_module_classes(backend_config: BackendConfig) -> tuple[type, ...]: + fused_module_classes = [ + config.fused_module + for config in backend_config.configs + if config.fused_module is not None + ] + return tuple(set(fused_module_classes)) + + +def get_pattern_to_input_type_to_index( + backend_config: BackendConfig, +) -> dict[Pattern, dict[str, int]]: + pattern_to_input_type_to_index: dict[Pattern, dict[str, int]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_input_type_to_index[pattern] = config._input_type_to_index + return pattern_to_input_type_to_index + + +def get_root_module_to_quantized_reference_module( + backend_config: BackendConfig, +) -> dict[type[torch.nn.Module], type[torch.nn.Module]]: + mapping: dict[type[torch.nn.Module], type[torch.nn.Module]] = {} + for config in backend_config.configs: + if ( + config.root_module is not None + and config.reference_quantized_module is not None + ): + mapping[config.root_module] = config.reference_quantized_module + return mapping + + +def get_fuser_method_mapping( + backend_config: BackendConfig, +) -> dict[Pattern, Union[nn.Sequential, Callable]]: + fuser_method_mapping: dict[Pattern, Union[nn.Sequential, Callable]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # Note: both the fuser method and the pattern are specified in forward order in the + # BackendConfig, but the internal pattern matching code uses the reversed nested tuple + # format, so we need to convert both to the internal format + fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config) + 
fuser_method_mapping[pattern] = fuser_method + return fuser_method_mapping + + +def get_module_to_qat_module( + backend_config: BackendConfig, +) -> dict[Pattern, type[torch.nn.Module]]: + module_to_qat_module: dict[Pattern, type[torch.nn.Module]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.qat_module is not None: + module_to_qat_module[pattern] = config.qat_module + return module_to_qat_module + + +def get_fusion_pattern_to_root_node_getter( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + """Get a map from fusion pattern to a function that returns the root node + from the fusion pattern, e.g. the most common one is: + def get_root_node(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + This can work for all patterns whose root node is the "last node" in the pattern, + e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d)) + """ + root_node_getter_mapping: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._root_node_getter is not None: + root_node_getter_mapping[pattern] = config._root_node_getter + return root_node_getter_mapping + + +def get_fusion_pattern_to_extra_inputs_getter( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + """Get a map from fusion pattern to a function that returns extra input nodes + from the fusion pattern, in the order required by the root node. This is optional, + if not specified, we will not copy over any extra inputs for the root node. + Example: + # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) + # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra + # argument to the fused module, we can unpack the pattern and return the node at + # MatchAllNode here + # we can implement extra_inputs_getter as follows: + def extra_inputs_getter(pattern) -> List[Any]: + add, extra_input, conv_pattern = pattern + return [extra_input] + """ + extra_inputs_getter_mapping: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._extra_inputs_getter is not None: + extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter + return extra_inputs_getter_mapping + + +def remove_boolean_dispatch_from_name(p) -> Any: + """ + Some ops have a default string representation such as + '.fn at 0x7ff1106bf280>', + this function replaces them with the hardcoded function names. 
+ """ + if p is F.fractional_max_pool2d: + return "torch.nn.functional.fractional_max_pool2d" + elif p is F.fractional_max_pool3d: + return "torch.nn.functional.fractional_max_pool3d" + elif p is F.max_pool1d: + return "torch.nn.functional.max_pool1d" + elif p is F.max_pool2d: + return "torch.nn.functional.max_pool2d" + elif p is F.max_pool3d: + return "torch.nn.functional.max_pool3d" + elif p is F.adaptive_max_pool1d: + return "torch.nn.functional.adaptive_max_pool1d" + elif p is F.adaptive_max_pool2d: + return "torch.nn.functional.adaptive_max_pool2d" + elif p is F.adaptive_max_pool3d: + return "torch.nn.functional.adaptive_max_pool3d" + assert "boolean_dispatch" not in str(p), ( + f"{p} does not have a human readable representation in " + + "quantization documentation" + ) + return p + + +def pattern_to_human_readable(p) -> Any: + if isinstance(p, tuple): + # nested patterns, recurse + return tuple(pattern_to_human_readable(inner_p) for inner_p in p) + elif isinstance(p, str): + # method names are already human readable + return p + else: + p = remove_boolean_dispatch_from_name(p) + return p + + +# TODO(future PR): move backend_config_dict to use dataclass and move this logic to +# the corresponding __str__ function +def entry_to_pretty_str(entry) -> str: + """ + Given a backend_config_dict entry, returns a string with the human readable + representation of it. + """ + s = "{\n" + + # always output the pattern first + if "pattern" in entry: + pattern_str = pattern_to_human_readable(entry["pattern"]) + + s += f" 'pattern': {pattern_str},\n" + + # custom output for dtype_configs to make it look nice + if "dtype_configs" in entry: + s += " 'dtype_configs': [\n" + for dtype_config in entry["dtype_configs"]: + s += " {\n" + for k, v in dtype_config.items(): + s += f" '{k}': {v},\n" + s += " },\n" + s += " ],\n" + + # custom output for num_tensor_args_to_observation_type to make it look nice + if "num_tensor_args_to_observation_type" in entry: + s += " 'num_tensor_args_to_observation_type': {\n" + for k, v in entry["num_tensor_args_to_observation_type"].items(): + s += f" {k}: {v},\n" + s += " },\n" + + # output all the other fields + custom_handled_fields = [ + "pattern", + "dtype_configs", + "num_tensor_args_to_observation_type", + ] + for field_name in entry: + if field_name in custom_handled_fields: + continue + s += f" '{field_name}': {entry[field_name]},\n" + + s += "}" + return s + + +def _get_pattern_in_reversed_nested_tuple_format( + config: BackendPatternConfig, +) -> Pattern: + """ + Return the pattern specified in the given config in the reversed nested tuple format + used internally in the quantization pattern matching code. + + If the pattern is not a tuple, or the pattern is already specified in the reversed + nested tuple format, return the pattern as is. Otherwise: + + For 2-tuples (a, b), return (b, a). + For 3-tuples (a, b, c), return (c, (b, a)). + + For example: + * Given nn.Linear, return nn.Linear + * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear) + * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return + (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + + For context, the reason why this is needed is the user-facing BackendConfig + API accepts the flat 2-or-3-tuple format in forward order. 
While this simple + format handles the vast majority of use cases, it does not handle the more + complex ones, and so the internal pattern matching code for quantization uses + the following, more general reversed nested tuple format instead: + + operator = module_type | functional | torch op | native op | MatchAllNode + Pattern = (operator, Pattern, Pattern, ...) | operator + + In the future, we expect to replace the above complex format with the one used + by the subgraph rewriter in torch.fx, so we don't have to maintain our own + complex pattern matching code. Then we won't need this helper function anymore. + """ + if config._pattern_complex_format is not None: + return config._pattern_complex_format + if config.pattern is None: + raise ValueError( + "Either 'pattern' or 'pattern_complex_format' must be specified" + ) + if not isinstance(config.pattern, tuple): + return config.pattern + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + (a, b) = config.pattern + return (b, a) + elif len(config.pattern) == 3: + (a, b, c) = config.pattern + return (c, (b, a)) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) + + +def _get_fuser_method_in_reversed_nested_tuple_format( + config: BackendPatternConfig, +) -> Callable: + """ + Return the fuser method specified in the given config in the reversed nested + tuple format used internally in the quantization pattern matching code. + + If pattern is specified in the reversed nested tuple format, we assume the + fuser method is also specified in this format and simply return it as is. + Otherwise, we convert the fuser method as follows: + + * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv) + * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv), + where bn_conv is a 2-tuple (bn, conv) + + The first argument of a fuser method is always `is_qat` and is not affected + in the conversion. We currently only support functions with 3 or 4 arguments. 
+ """ + assert config.fuser_method is not None + if config._pattern_complex_format is not None: + return config.fuser_method + if not isinstance(config.pattern, tuple): + raise ValueError("Expected pattern to be a tuple, got: ", config.pattern) + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + return _reverse2(config.fuser_method) + elif len(config.pattern) == 3: + return _reverse3(config.fuser_method) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/x86.py b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/x86.py new file mode 100644 index 0000000000000000000000000000000000000000..3f59e346ede4bd568210eee469488e1ba88dd19e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/backend_config/x86.py @@ -0,0 +1,126 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_x86_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# X86 aligns with FBGEMM for now + +x86_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +x86_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +x86_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +x86_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +x86_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_x86_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native x86 backend. 
+ """ + conv_dtype_configs = [x86_weighted_op_int8_dtype_config] + linear_dtype_configs = [ + x86_weighted_op_int8_dtype_config, + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + default_op_dtype_configs = [x86_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + x86_weight_only_quint8_dtype_config, + x86_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("x86") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fake_quantize.py b/phivenv/Lib/site-packages/torch/ao/quantization/fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c3e0fc0aa399dfae251ce6607787e318931bbd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fake_quantize.py @@ -0,0 +1,650 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +"""Implements modules used to perform fake quantization.""" + +import re +from abc import ABC, abstractmethod +from typing import Any + +import torch +from torch.ao.quantization.observer import ( + _with_args, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + FixedQParamsObserver, + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, +) +from torch.nn import Module + + +__all__ = [ + "FakeQuantizeBase", + "FakeQuantize", + "FixedQParamsFakeQuantize", + "FusedMovingAvgObsFakeQuantize", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "default_fake_quant", + "default_weight_fake_quant", + "default_dynamic_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_symmetric_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_fake_quant", + "default_per_channel_weight_fake_quant", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_histogram_fake_quant", + "default_fused_act_fake_quant", + "default_fused_wt_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "fused_wt_fake_quant_range_neg_127_to_127", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", +] + + +def _is_per_channel(qscheme: "torch.qscheme") -> bool: + return 
qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ] + + +def _is_per_tensor(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + + +def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +def _is_float_qparams(qscheme: "torch.qscheme") -> bool: + return qscheme in [ + torch.per_channel_affine_float_qparams, + ] + + +class FakeQuantizeBase(ABC, Module): + r"""Base fake quantize module. + + Base fake quantize module + Any fake quantize implementation should derive from this class. + + Concrete fake quantize module should follow the same API. In forward, they will update + the statistics of the observed Tensor and fake quantize the input. They should also provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + """ + + fake_quant_enabled: torch.Tensor + observer_enabled: torch.Tensor + + def __init__(self) -> None: + """Set fake_quant_enabled and observer_enabled.""" + super().__init__() + # fake_quant_enabled and observer_enabled are buffers to support their + # replication in DDP. Data type is uint8 because NCCL does not support + # bool tensors. + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8)) + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + @torch.jit.export + def enable_fake_quant(self, enabled: bool = True) -> None: + self.fake_quant_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_fake_quant(self): + self.enable_fake_quant(False) + + @torch.jit.export + def enable_observer(self, enabled: bool = True) -> None: + self.observer_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_observer(self): + self.enable_observer(False) + + @classmethod + def with_args(cls, **kwargs): + fake_quant_constructor = _with_args(cls, **kwargs) + # need to assign the correct module to fake_quantize + # constructors to satisfy public v private requirements + fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize" + return fake_quant_constructor + + +class FakeQuantize(FakeQuantizeBase): + r"""Simulate the quantize and dequantize operations in training time. + + The output of this module is given by:: + + x_out = ( + clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point + ) * scale + + * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization + operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq) + + * :attr:`scale` defines the scale factor used for quantization. + + * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to + + * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that + statistics can still be updated. + + * :attr:`observer_enabled` controls statistics collection on tensors + + * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization, + allowable values are torch.qint8 and torch.quint8. + + Args: + + observer (module): Module for observing statistics on input tensors and calculating scale + and zero-point. 
+ observer_kwargs (optional): Arguments for the observer module + + Attributes: + activation_post_process (Module): User provided module that collects statistics on the input tensor and + provides a method to calculate scale and zero-point. + + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + observer=MovingAverageMinMaxObserver, + quant_min=None, + quant_max=None, + is_dynamic=False, + **observer_kwargs, + ): + super().__init__() + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + assert quant_min <= quant_max, ( + "quant_min must be less than or equal to quant_max" + ) + dtype = observer_kwargs.get("dtype", torch.quint8) + if hasattr(observer, "p"): + # In case observer is _PartialWrapper, dtype can be stored in + # observer.p.keywords["dtype"] + dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get( + "dtype", dtype + ) + assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound" + assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound" + observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) + observer_kwargs["is_dynamic"] = is_dynamic + self.activation_post_process = observer(**observer_kwargs) + # TODO: keeping self.quant_min/max for BC; remove after a couple releases + # Users should use self.activation_post_process.quant_min + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + self.is_dynamic = self.activation_post_process.is_dynamic + if _is_float_qparams(self.activation_post_process.qscheme): + zero_point_dtype = torch.float + else: + zero_point_dtype = torch.int + self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype)) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), ( + "Only per channel and per tensor quantization are supported in fake quantize" + + " got qscheme: " + + str(self.qscheme) + ) + self.is_per_channel = _is_per_channel(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.activation_post_process.calculate_qparams() + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = ( + _scale.to(self.scale.device), + _zero_point.to(self.zero_point.device), + ) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + else: + X = torch.fake_quantize_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + return X + + @torch.jit.export + def extra_repr(self): + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + 
f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"scale={self.scale}, zero_point={self.zero_point}" + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # We cannot currently register scalar values as buffers, so need to manually + # specify serialization here. + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = self.scale + destination[prefix + "zero_point"] = self.zero_point + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Removing this function throws an error that the size of the loaded tensor does not match the original size + # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass. + local_state = ["scale", "zero_point"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == "scale": + self.scale.resize_(val.shape) + else: + assert name == "zero_point" + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == "scale": + self.scale.copy_(val) + else: + assert name == "zero_point" + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class FixedQParamsFakeQuantize(FakeQuantize): + """Simulate quantize and dequantize in training time. + + Simulate quantize and dequantize with fixed quantization + parameters in training time. Only per tensor quantization + is supported. + """ + + # TODO: rename observer to observer_ctr + def __init__(self, observer): + super().__init__(observer=observer) + assert type(self.activation_post_process) == FixedQParamsObserver, ( + f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + ) + self._observer_ctr = observer + self.scale = self.activation_post_process.scale + self.zero_point = self.activation_post_process.zero_point + assert _is_per_tensor(self.qscheme), ( + "Only per tensor quantization is supported" + + " FixedQParamsFakeQuantize module, got qscheme:" + + str(self.qscheme) + ) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self): + """Define a string representation of the object's attributes.""" + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, " + f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, " + f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}" + ) + + +class FusedMovingAvgObsFakeQuantize(FakeQuantize): + r"""Define a fused module to observe the tensor. + + Fused module that is used to observe the input tensor (compute min/max), compute + scale/zero_point and fake_quantize the tensor. 
+ This module uses calculation similar MovingAverageMinMaxObserver for the inputs, + to compute the min/max values in order to compute the scale/zero_point. + The qscheme input in the observer is used to differentiate between symmetric/affine + quantization scheme. + + The output of this module is given by + x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale + + Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the + base class. + + """ + + def __init__( + self, + observer: Any = MovingAverageMinMaxObserver, + quant_min: int = 0, + quant_max: int = 255, + **observer_kwargs: Any, + ) -> None: + super().__init__(observer, quant_min, quant_max, **observer_kwargs) + assert isinstance( + self.activation_post_process, + (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver), + ), ( + "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + ) + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) + self.is_symmetric_quant = _is_symmetric_quant( + self.activation_post_process.qscheme + ) + + @torch.jit.export + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override] + return self.activation_post_process.calculate_qparams() + + @torch.jit.export + def extra_repr(self) -> str: + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}" + ) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.fused_moving_avg_obs_fake_quant( + X, + self.observer_enabled, + self.fake_quant_enabled, + self.activation_post_process.min_val, + self.activation_post_process.max_val, + self.scale, + self.zero_point, + self.activation_post_process.averaging_constant, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + self.ch_axis, + self.is_per_channel, + self.is_symmetric_quant, + ) + + +default_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Default fake_quant for activations. +""" + +default_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, +) +""" +Default fake_quant for weights. +Observer is memoryless since averaging_constant is 1. +""" + +default_dynamic_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + is_dynamic=True, + dtype=torch.quint8, + averaging_constant=1, +) +""" +Default dynamic fake_quant for activations. 
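+Observer is memoryless since averaging_constant is 1.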
+""" + +default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_neg1to1_observer +) +default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_0to1_observer +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_fake_quant = ( + default_fixed_qparams_range_neg1to1_fake_quant +) +default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant + +default_per_channel_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + ch_axis=0, +) +""" +Default fake_quant for per-channel weights. +Observer is memoryless since averaging_constant is 1. +""" +default_embedding_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + dtype=torch.quint8, + quant_min=0, + quant_max=255, + ch_axis=0, + averaging_constant=1, +) +""" +Default fake_quant for embeddings. +Observer is memoryless since averaging_constant is 1. +""" + +default_embedding_fake_quant_4bit = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + dtype=torch.quint4x2, + averaging_constant=1, +) + +default_histogram_fake_quant = FakeQuantize.with_args( + observer=HistogramObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Fake_quant for activations using a histogram.. +""" + + +default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, +) + +""" +Fused version of `default_fake_quant`, with improved performance. +""" + + +default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, +) +""" +Fused version of `default_weight_fake_quant`, with improved performance. +""" + +default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, +) +""" +Fused version of `default_per_channel_weight_fake_quant`, with improved performance. +""" + +fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + eps=2**-12, +) +""" +Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +fused_per_channel_wt_fake_quant_range_neg_127_to_127 = ( + FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + eps=2**-12, + ) +) + +""" +Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. 
+""" + + +def _is_fake_quant_script_module(mod): + """Return true if given mod is an instance of FakeQuantize script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return ( + name == "torch.ao.quantization.fake_quantize.FakeQuantize" + or name + == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize" + ) + return False + + +def disable_fake_quant(mod): + """Disable fake quantization for the module. + + Disable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_fake_quant() + + +def enable_fake_quant(mod): + """Enable fake quantization for the module. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_fake_quant() + + +def disable_observer(mod): + """Disable observation for this module. + + Disable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_observer() + + +def enable_observer(mod): + """Enable observation for this module. + + Enable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_observer() diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fuse_modules.py b/phivenv/Lib/site-packages/torch/ao/quantization/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..46662e330aad5483c5b39ae0427f6f88188218b8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fuse_modules.py @@ -0,0 +1,216 @@ +# mypy: allow-untyped-defs +import copy +from typing import Optional + +import torch.nn as nn + +# for backward compatibility +from torch.ao.quantization.fuser_method_mappings import ( # noqa: F401 # noqa: F401 + fuse_conv_bn, + fuse_conv_bn_relu, + get_fuser_method, +) +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "fuse_known_modules", + "fuse_modules", + "fuse_modules_qat", +] + + +# Generalization of getattr +def _get_module(model, submodule_key): + tokens = submodule_key.split(".") + cur_mod = model + for s in tokens: + cur_mod = getattr(cur_mod, s) + return cur_mod + + +# Generalization of setattr +def _set_module(model, submodule_key, module): + tokens = submodule_key.split(".") + sub_tokens = tokens[:-1] + cur_mod = model + for s in sub_tokens: + cur_mod = getattr(cur_mod, s) + + setattr(cur_mod, tokens[-1], module) + + +def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None): + r"""Return a list of known fuse modules. + + Returns a list of modules that fuses the operations specified + in the input module list. 
+ + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, bn + linear, relu + For these sequences, the first element in the output module list performs + the fused operation. The rest of the elements are set to nn.Identity() + """ + types = tuple(type_before_parametrizations(m) for m in mod_list) + fuser_method = get_fuser_method(types, additional_fuser_method_mapping) + if fuser_method is None: + raise NotImplementedError(f"Cannot fuse modules: {types}") + new_mod: list[Optional[nn.Module]] = [None] * len(mod_list) + fused = fuser_method(is_qat, *mod_list) + # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion + # Move pre forward hooks of the base module to resulting fused module + for pre_hook_fn in mod_list[0]._forward_pre_hooks.values(): + fused.register_forward_pre_hook(pre_hook_fn) + mod_list[0]._forward_pre_hooks.clear() + # Move post forward hooks of the last module to resulting fused module + for hook_fn in mod_list[-1]._forward_hooks.values(): + fused.register_forward_hook(hook_fn) + mod_list[-1]._forward_hooks.clear() + new_mod[0] = fused + + for i in range(1, len(mod_list)): + identity = nn.Identity() + identity.training = mod_list[0].training + new_mod[i] = identity + + return new_mod + + +def _fuse_modules_helper( + model, + modules_to_fuse, + is_qat, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get( + "additional_fuser_method_mapping", {} + ) + mod_list = [_get_module(model, item) for item in modules_to_fuse] + + # Fuse list of modules + new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping) + + # Replace original module list with fused module list + for i, item in enumerate(modules_to_fuse): + _set_module(model, item, new_mod_list[i]) + + +def _fuse_modules( + model, + modules_to_fuse, + is_qat, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + if not inplace: + model = copy.deepcopy(model) + + if all(isinstance(module_element, str) for module_element in modules_to_fuse): + # Handle case of modules_to_fuse being a list + _fuse_modules_helper( + model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict + ) + else: + # Handle case of modules_to_fuse being a list of lists + for module_list in modules_to_fuse: + _fuse_modules_helper( + model, module_list, is_qat, fuser_func, fuse_custom_config_dict + ) + return model + + +def fuse_modules( + model, + modules_to_fuse, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + r"""Fuse a list of modules into a single module. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, relu + bn, relu + All other sequences are left unchanged. + For these sequences, replaces the first item in the list + with the fused module, replacing the rest of the modules + with identity. + + Args: + model: Model containing the modules to be fused + modules_to_fuse: list of list of module names to fuse. Can also be a list + of strings if there is only a single list of modules to fuse. + inplace: bool specifying if fusion happens in place on the model, by default + a new model is returned + fuser_func: Function that takes in a list of modules and outputs a list of fused modules + of the same length. 
For example, + fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()] + Defaults to torch.ao.quantization.fuse_known_modules + `fuse_custom_config_dict`: custom configuration for fusion + + .. code-block:: python + + # Example of fuse_custom_config_dict + fuse_custom_config_dict = { + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + } + + Returns: + model with fused modules. A new copy is created if inplace=True. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = M().eval() + >>> # m is a module containing the sub-modules below + >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + >>> m = M().eval() + >>> # Alternately provide a single list of modules to fuse + >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + """ + return _fuse_modules( + model, + modules_to_fuse, + is_qat=False, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict, + ) + + +def fuse_modules_qat( + model, + modules_to_fuse, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + """QAT version for `fuse_modules`.""" + return _fuse_modules( + model, + modules_to_fuse, + is_qat=True, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict, + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fuser_method_mappings.py b/phivenv/Lib/site-packages/torch/ao/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb4d4a61bf61b309722c433666ca0d5372a0226 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fuser_method_mappings.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-defs +import itertools +from typing import Any, Callable, Optional, Union + +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +from torch.ao.quantization.utils import get_combined_dict, MatchAllNode, Pattern + + +__all__ = [ + "fuse_conv_bn", + "fuse_conv_bn_relu", + "fuse_linear_bn", + "fuse_convtranspose_bn", + "get_fuser_method", + "get_fuser_method_new", +] + + +def fuse_conv_bn(is_qat, conv, bn): + r"""Return the fused the conv and bn modules. + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn(m1, b1) + """ + assert conv.training == bn.training, ( + "Conv and BN both must be in the same mode (train or eval)." 
+ ) + + fused_module_class_map = { + nn.Conv1d: nni.ConvBn1d, + nn.Conv2d: nni.ConvBn2d, + nn.Conv3d: nni.ConvBn3d, + } + + if is_qat: + assert bn.num_features == conv.out_channels, ( + "Output channel of Conv2d must match num_features of BatchNorm2d" + ) + assert bn.affine, "Only support fusing BatchNorm2d with affine set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) + fused_module_class = fused_module_class_map.get((type(conv)), None) + if fused_module_class is not None: + return fused_module_class(conv, bn) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}") + else: + return nn.utils.fuse_conv_bn_eval(conv, bn) + + +def fuse_conv_bn_relu(is_qat, conv, bn, relu): + r"""Return the fused conv and bv modules. + + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> r1 = nn.ReLU(inplace=False) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn_relu(m1, b1, r1) + """ + assert conv.training == bn.training == relu.training, ( + "Conv and BN both must be in the same mode (train or eval)." + ) + fused_module: Optional[type[nn.Sequential]] = None + if is_qat: + map_to_fused_module_train = { + nn.Conv1d: nni.ConvBnReLU1d, + nn.Conv2d: nni.ConvBnReLU2d, + nn.Conv3d: nni.ConvBnReLU3d, + } + assert bn.num_features == conv.out_channels, ( + "Output channel of Conv must match num_features of BatchNorm" + ) + assert bn.affine, "Only support fusing BatchNorm with affine set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm with tracking_running_stats set to True" + ) + fused_module = map_to_fused_module_train.get(type(conv), None) + if fused_module is not None: + return fused_module(conv, bn, relu) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}") + else: + map_to_fused_module_eval = { + nn.Conv1d: nni.ConvReLU1d, + nn.Conv2d: nni.ConvReLU2d, + nn.Conv3d: nni.ConvReLU3d, + } + fused_module = map_to_fused_module_eval.get(type(conv), None) + if fused_module is not None: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return fused_module(fused_conv, relu) + else: + raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}") + + +def fuse_linear_bn(is_qat, linear, bn): + r"""Return the fused linear and bn modules. + Given the linear and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + + Examples:: + + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> # xdoctest: +SKIP + >>> m2 = fuse_linear_bn(m1, b1) + """ + assert linear.training == bn.training, ( + "Linear and BN both must be in the same mode (train or eval)." 
+ ) + + if is_qat: + assert bn.num_features == linear.out_features, ( + "Output features of Linear must match num_features of BatchNorm1d" + ) + assert bn.affine, "Only support fusing BatchNorm1d with affine set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm1d with tracking_running_stats set to True" + ) + return nni.LinearBn1d(linear, bn) + else: + return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + + +def fuse_convtranspose_bn(is_qat, convt, bn): + r"""Return the fused ConvTranspose and bn modules. + Given ConvTranspose and bn modules, fuses them and returns the fused module + + Args: + convt: Module instance of type ConvTransposeNd + bn: BatchNormNd instance that needs to be fused with the linear layer. + batch norm N should match the ConvTranspose N + + Examples:: + + >>> m1 = nn.ConvTranspose2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_convtranspose_bn(m1, b1) + """ + assert convt.training == bn.training, ( + "ConvTranspose and BN both must be in the same mode (train or eval)." + ) + + if is_qat: + raise Exception( # noqa: TRY002 + "Fusing ConvTranspose+BatchNorm not yet supported in QAT." + ) + else: + return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True) + + +def _sequential_wrapper2(sequential): + """Return a sequential wrapped that for is_qat and two modules. + Given a sequential class for two modules, return a function that takes + is_qat, and then two modules as argument, that ignores the is_qat flag + and always returns the sequential that combines the two input modules + """ + + def fuser_method(is_qat, m1, m2): + return sequential(m1, m2) + + return fuser_method + + +_DEFAULT_OP_LIST_TO_FUSER_METHOD: dict[tuple, Union[nn.Sequential, Callable]] = { + (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, + (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, + (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, + (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d), + (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d), + (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d), + (nn.Linear, nn.BatchNorm1d): fuse_linear_bn, + (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU), + (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d), + (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d), + (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn, + (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn, + (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn, +} + + +def get_fuser_method(op_list, additional_fuser_method_mapping=None): + """Get fuser method for the given list of module types. 
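+
+    Example (illustrative only; assumes the default mapping above and eval-mode modules)::
+
+        >>> # xdoctest: +SKIP
+        >>> conv, bn, relu = nn.Conv2d(3, 8, 3).eval(), nn.BatchNorm2d(8).eval(), nn.ReLU().eval()
+        >>> fuser = get_fuser_method((nn.Conv2d, nn.BatchNorm2d, nn.ReLU))
+        >>> fused = fuser(False, conv, bn, relu)  # nni.ConvReLU2d wrapping the BN-folded conv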
+ + Get fuser method for the given list of module types, + return None if fuser method does not exist + """ + if additional_fuser_method_mapping is None: + additional_fuser_method_mapping = {} + all_mappings = get_combined_dict( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping + ) + fuser_method = all_mappings.get(op_list, None) + assert fuser_method is not None, f"did not find fuser method for: {op_list} " + return fuser_method + + +def _reverse2(f): + def reversed(is_qat, x, y): + return f(is_qat, y, x) + + return reversed + + +def _reverse3(f): + def reversed(is_qat, x, w): + y, z = w + return f(is_qat, z, y, x) + + return reversed + + +def _get_valid_patterns(op_pattern): + """Return a list of valid patterns generated from the op_pattern. + + Returns a list of valid patterns generated from the op_pattern, + since MatchAllNode can match all types of nodes, + e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like + (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode) + + Example Input: + (torch.add, (torch.nn.ReLU, torch.nn.Conv2d)) + + Example Output: + [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)), + (torch.add, (torch.nn.ReLU, MatchAllNode)), + (torch.add, (MatchAllNode, torch.nn.Conv2d)), + (torch.add, (MatchAllNode, MatchAllNode)), + (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)), + (MatchAllNode, (torch.nn.ReLU, MatchAllNode)), + (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)), + (MatchAllNode, (MatchAllNode, MatchAllNode)), + ] + """ + result: list[Any] + if isinstance(op_pattern, (tuple, list)): + sub_combs = [_get_valid_patterns(sub_pattern) for sub_pattern in op_pattern] + result = list(itertools.product(*sub_combs)) + else: + result = [op_pattern, MatchAllNode] + return result + + +def get_fuser_method_new( + op_pattern: Pattern, + fuser_method_mapping: dict[Pattern, Union[nn.Sequential, Callable]], +): + """Get fuser method. 
+ + This will be made default after we deprecate the get_fuser_method + Would like to implement this first and have a separate PR for deprecation + """ + op_patterns = _get_valid_patterns(op_pattern) + fuser_method = None + for op_pattern in op_patterns: + fuser_method = fuser_method_mapping.get(op_pattern, None) + if fuser_method is not None: + break + assert fuser_method is not None, f"did not find fuser method for: {op_pattern} " + return fuser_method diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__init__.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7bab98f39d0a4692ac020831e543d14f6ad3f16a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__init__.py @@ -0,0 +1,3 @@ +from .convert import convert +from .fuse import fuse +from .prepare import prepare diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bec64711473f990a2992e62e0eab70d421e734d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d81bdafa37a0fc3dae5fb441aaa441f7ec8ccdbf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f51dc1dfe7fe83e791baab6d4f073f91d681d627 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7d48eaf30619c6e122e643a0233327f203146e8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4c780e6dd4217606c463dacbfd0e7a7bc548916 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61721665ae36aab47ecec423d56a997251a35d1d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfc701f4d5d50219234e281b5dbe476dd829128b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c2b5c82757b9a7f699ad4e85f1726f39879be4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddf52c0700b11e4e009f19b8db2b608238e7fdfb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de51e69c0876e480f8bf89250e10eca9577cd55 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21da465dbd5625c2fbfb88e62b08bb67b7a8be33 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11cfb1c1d733cb2b6bff6caf6ad92a89e3f3c8af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb237c807a8491472d0a866417aef4ea3d5f1c17 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd7bc57c07374355fdacd2d9dfae8fdad7b68e12 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed80743f18ee9ee0026fbe87ec18eb49e10347c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..130071377a07bde4bb7c7c8e7a0b8f22ea92d448 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..836ecf4d06088a80cb88b3010968c8d218a9551c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab33c16d5f620b162781ff648bb68ba8aa48822 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd8dc5d2efe2c54cd7090e0a469eafce175df9d1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_decomposed.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_decomposed.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0260f4dc280e879c2cf266cf15df2455b3625f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_decomposed.py @@ -0,0 +1,1221 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional + +import torch +from torch._refs import _unsqueeze_multiple +from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax +from torch.library import impl, Library + + +# Note: decomposed means decomposed quantized tensor, using decomposed so that the +# name is not too long +quantized_decomposed_lib = Library("quantized_decomposed", "DEF") + +_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.uint16, torch.int16, torch.int32] +_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn] + +_DTYPE_TO_QVALUE_BOUNDS = { + k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES +} +_DTYPE_TO_QVALUE_BOUNDS.update( + {k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES} +) + + +# Helper to check the passed in quant min and max are valid for the dtype +def _quant_min_max_bounds_check(quant_min, quant_max, dtype): + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + + assert quant_min >= 
quant_min_lower_bound, ( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + assert quant_max <= quant_max_upper_bound, ( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + + +quantized_decomposed_lib.define( + "quantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") +def quantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta") +def quantize_per_tensor_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd" +) +def quantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return quantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, + ) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") +def 
quantize_per_tensor_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + return torch.empty_like(input, dtype=dtype) + + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd" +) +def quantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return quantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min.item(), # type: ignore[arg-type] + quant_max.item(), # type: ignore[arg-type] + dtype, + ) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta") +def quantize_per_tensor_tensor2_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return quantize_per_tensor_tensor_meta( + input, + scale, + zero_point, # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, + ) + + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. 
(`torch.uint8`), it is a per tensor quantized Tensor if combined with + quantization parameters in the argument of this function (scale/zero_point) + + scale (float): quantization parameter for affine quantization + + zero_point (int): quantization parameter for affine quantization + + quant_min (int): minimum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): dtype for input Tensor (not used in computation, + reserved for pattern matching) + + out_dtype (torch.dtype?): optional dtype for output Tensor + + Returns: + dequantized float32 Tensor + """ + assert input.dtype == dtype, ( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + # TODO: investigate why + # (input - zero_point).to(torch.float32) * scale + # failed the test + return (input.to(out_dtype) - zero_point) * scale + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta") +def dequantize_per_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.float32 + return torch.empty_like(input, dtype=out_dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_tensor.tensor", + "CompositeExplicitAutograd", +) +def dequantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return dequantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, + quant_max, + dtype, + out_dtype=out_dtype, + ) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") +def dequantize_per_tensor_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.float32 + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + return torch.empty_like(input, 
dtype=out_dtype) + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_tensor.tensor2", + "CompositeExplicitAutograd", +) +def dequantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return dequantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min.item(), # type: ignore[arg-type] + quant_max.item(), # type: ignore[arg-type] + dtype, + out_dtype=out_dtype, + ) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta") +def dequantize_per_tensor_tensor2_meta( + input, + scale, + zero_point, + quant_min, + quant_max, + dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + return dequantize_per_tensor_tensor_meta( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype + ) + + +quantized_decomposed_lib.define( + "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") +def choose_qparams_tensor( + input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + assert input.dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], ( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + assert dtype in _DTYPE_TO_QVALUE_BOUNDS, ( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + ) + + +quantized_decomposed_lib.define( + "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + 
"choose_qparams_symmetric.tensor", + "CompositeExplicitAutograd", +) +def choose_qparams_symmetric_tensor( + input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + assert input.dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], ( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + assert dtype in _DTYPE_TO_QVALUE_BOUNDS, ( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + qscheme=torch.per_tensor_symmetric, + ) + + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta") +def choose_qparams_tensor_meta( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + assert input.dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], ( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + assert quant_min < quant_max, ( + f"Expecting quant_min to be smaller than quant_max but received min: \ + {quant_min} max: {quant_max}" + ) + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( + 1, dtype=torch.int64, device=input.device + ) + + +@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta") +def choose_qparams_symmetric_tensor_meta( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( + 1, dtype=torch.int64, device=input.device + ) + + +# Helper function used to implement per-channel quantization against any axis +def _permute_to_axis_zero(x, axis): + new_axis_list = list(range(x.dim())) + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = x.permute(tuple(new_axis_list)) + return y, new_axis_list + + +quantized_decomposed_lib.define( + "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd") +def quantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + zero_point (torch.Tensor): a list of zero_point 
quantization parameter for + affine quantization, one per channel + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + + new_shape = [1] * input.dim() + new_shape[0] = scales.shape[0] + scales = scales.view(new_shape) + zero_points = zero_points.view(new_shape) + + res = torch.clamp( + torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max + ) + out = res.permute(tuple(permute_axis_list)) + return out.to(dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta") +def quantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd") +def dequantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: Optional[torch.Tensor], + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. 
(`torch.uint8`), it is a per channel quantized Tensor if combined with + quantization parameter in the argument of this function (scales/zero_points/axis) + + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + + zero_points (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + + quant_min (int): minimum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): requested dtype for output Tensor (not used in computation, + reserved for pattern matching) + + out_dtype (torch.dtype?): optional dtype for output Tensor + + Returns: + dequantized float32 Tensor + """ + assert input.dtype == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + + new_shape = [1] * input.dim() + new_shape[0] = scales.shape[0] + scales = scales.view(new_shape) + if zero_points is not None: + res = (input - zero_points.view(new_shape)) * scales + else: + res = input * scales + + res = res.to(out_dtype) + + out = res.permute(tuple(permute_axis_list)) + return out + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta") +def dequantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: Optional[torch.Tensor], + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + assert input.dtype == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=out_dtype) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. 
torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + + scales = input.abs().amax(dim=-1, keepdim=True) + if scales.dtype == torch.float16: + scales = ( + scales.float() + ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) + if dtype == torch.int8: + n_bits = 8 + quant_max = 2 ** (n_bits - 1) - 1 + else: + raise Exception( # noqa: TRY002 + f"unsupported dtype in choose_qparams_per_token: {dtype}" + ) + + scales = scales.clamp(min=1e-5).div(quant_max) + zero_points = torch.zeros_like(scales) + return scales, zero_points + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "Meta", +) +def choose_qparams_per_token_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + size = list(input.shape[:-1]) + [1] + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +quantized_decomposed_lib.define( + "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "_choose_qparams_per_token_asymmetric_impl", + "CompositeImplicitAutograd", +) +def _choose_qparams_per_token_asymmetric_impl( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 + qmin, qmax = -128, 127 + min_val = torch.amin(input, dim=-1, keepdim=True) + max_val = torch.amax(input, dim=-1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + eps = torch.finfo(torch.float32).eps # use xnnpack eps? 
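+
+    # For illustration: with qmin/qmax of -128/127 as set above, a token whose values
+    # span [-1.0, 3.0] gets scale = (3.0 - (-1.0)) / (127 - (-128)) = 4 / 255 ~= 0.0157,
+    # and the min-error branch below gives zero_point = round(clamp(-128 - (-1.0 / scale))) = -64.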
+ + # scale + scale = (max_val_pos - min_val_neg) / float(qmax - qmin) + scale = scale.clamp(min=eps) + + # zero point + descaled_min = min_val_neg / scale + descaled_max = max_val_pos / scale + zero_point_from_min_error = qmin + descaled_min + zero_point_from_max_error = qmax + descaled_max + zero_point = torch.where( + zero_point_from_min_error + zero_point_from_max_error > 0, + qmin - descaled_min, + qmax - descaled_max, + ) + zero_point = torch.clamp(zero_point, qmin, qmax).round() + + return scale.to(torch.float64), zero_point.to(torch.int64) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token_asymmetric( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + return _choose_qparams_per_token_asymmetric_impl(input, dtype) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "Meta", +) +def choose_qparams_per_token_asymmetric_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + size = list(input.shape[:-1]) + [1] + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +def _per_token_quant_qparam_dim_check(input, scales, zero_points): + num_tokens = math.prod(list(input.size())[:-1]) + assert num_tokens == scales.numel(), ( + f"num_tokens: {num_tokens} scales: {scales.size()}" + ) + assert num_tokens == zero_points.numel(), ( + f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + ) + + +quantized_decomposed_lib.define( + "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd") +def quantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + """Per token quantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. 
torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + _per_token_quant_qparam_dim_check(input, scales, zero_points) + input = ( + input.mul(1.0 / scales) + .add(zero_points) + .round() + .clamp(quant_min, quant_max) + .to(dtype) + ) + return input + + +@impl(quantized_decomposed_lib, "quantize_per_token", "Meta") +def quantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd") +def dequantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + """Per token dequantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): quantized Tensor (uint8, int8 etc.) + scales (float64 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int64 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. 
torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + input = input - zero_points + input = input * scales + # Since scales are of float64 type, we need to cast it to output dtype requested + return input.to(output_dtype) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta") +def dequantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + # TODO: support fp16 + return torch.empty_like(input, dtype=output_dtype) + + +quantized_decomposed_lib.define( + "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size) -> Tensor" +) + + +# TODO: dtype is ignored for now +@impl( + quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd" +) +def quantize_per_channel_group( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + assert group_size > 1 + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + assert input.shape[-1] % group_size == 0 + assert input.dim() == 2 + + # TODO: check for dtype, currently we can't express torch.int4 so it's omitted + to_quant = input.reshape(-1, group_size) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zero_points = zero_points.reshape(-1, 1) + + input_int8 = ( + to_quant.mul(1.0 / scales) + .add(zero_points) + .round() + .clamp_(quant_min, quant_max) + .to(dtype) + .reshape_as(input) + ) + + return input_int8 + + +@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta") +def quantize_per_channel_group_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + """Groupwise quantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. 
torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + assert group_size > 1 + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + assert input.shape[-1] % group_size == 0 + assert input.dim() == 2 + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_channel_group", + "CompositeExplicitAutograd", +) +def dequantize_per_channel_group( + w_int8: torch.Tensor, + scales: torch.Tensor, + zero_points: Optional[torch.Tensor], + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size: int = 128, + output_dtype: torch.dtype = torch.float32, +): + """Groupwise dequantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): quantized Tensor (uint8/int8 etc.) + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. 
torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + + assert group_size > 1 + # needed for GPTQ single column dequantize + if group_size > w_int8.shape[-1] and scales.shape[-1] == 1: + group_size = w_int8.shape[-1] + assert w_int8.shape[-1] % group_size == 0 + assert w_int8.dim() == 2 + + w_int8_grouped = w_int8.reshape(-1, group_size) + scales = scales.reshape(-1, 1) + if zero_points is not None: + zp = zero_points.reshape(-1, 1) + else: + zp = torch.zeros([], dtype=torch.int32, device=scales.device) + w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype) + return w_dq + + +quantized_decomposed_lib.define( + "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max) -> Tensor" +) + + +class FakeQuantPerChannel(torch.autograd.Function): + @staticmethod + def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): + if scales.dtype != torch.float32: + scales = scales.to(torch.float32) + if zero_points.dtype != torch.int32: + zero_points = zero_points.to(torch.int32) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim)) + unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims) + unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims) + temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points + out = ( + torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points + ) * unsqueeze_scales + mask = torch.logical_and((temp >= quant_min), (temp <= quant_max)) + + ctx.save_for_backward(mask) + return out + + @staticmethod + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd") +def fake_quant_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return FakeQuantPerChannel.apply( + input, scales, zero_points, axis, quant_min, quant_max + ) + + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta") +def fake_quant_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return torch.empty_like(input) + + +quantized_decomposed_lib.define( + "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "convert_element_type.no_fuse", + "CompositeExplicitAutograd", +) +def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.prims.convert_element_type.default(input, dtype) + + +@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta") +def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.empty_like(input, dtype=dtype) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_equalize.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..c464c4cbae8a15dc6fa58e85ccfd47c57b8f904e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_equalize.py @@ -0,0 +1,955 @@ +# mypy: allow-untyped-defs 
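+# This module implements input-weight equalization for FX graph mode quantization:
+# the observers below record per-column ranges of activations and weights, and
+# calculate_equalization_scale() combines them as sqrt(weight_range / input_range)
+# so that scaled inputs and weights span comparable ranges before quantization.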
+import operator +import warnings +from collections import namedtuple +from typing import Any, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.observer import ( + _with_args, + ObserverBase, + PerChannelMinMaxObserver, +) +from torch.ao.quantization.utils import _parent_name, check_min_max_valid +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .utils import ( + get_new_attr_name_with_prefix, + maybe_get_next_module, + node_arg_is_weight, +) + + +CUSTOM_MODULE_SUPP_LIST: list[Any] = [] + + +def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor: + """Reshapes the scale so that we can multiply it to the input by the given axis.""" + new_shape = [1] * input.ndim + new_shape[axis] = input.size(axis) + return scale.view(new_shape) + + +qsheme_mapping_per_tensor_to_per_channel = { + torch.per_tensor_affine: torch.per_channel_affine, + torch.per_tensor_symmetric: torch.per_channel_symmetric, +} + + +class _InputEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of input columns, and + computing the quantization parameters for the overall min/max input values. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + The running minimum/maximum :math:`x_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`, + with the difference that the running min/max values are stored per column. + This observer is intended to be used along with a WeightEqualizationObserver + to calculate the equalization scale. 
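+
+    Example (an illustrative sketch; the tensor shape is hypothetical)::
+
+        >>> # xdoctest: +SKIP
+        >>> input_obs = _InputEqualizationObserver()
+        >>> _ = input_obs(torch.randn(4, 16))  # records per-column (axis 1) min/max
+        >>> min_cols, max_cols = input_obs.get_input_minmax()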
+ """ + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=None, + quant_max=None, + factory_kwargs=None, + ) -> None: + super().__init__() + + if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + raise TypeError("Input qscheme must be per-tensor") + + self.dtype = dtype + self.qscheme = qscheme + + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.input_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + ) + + self.equalization_scale = torch.tensor(1) + self.equalization_shape: list[int] = [] + + def forward(self, x_orig): + if not (x_orig.ndim >= 2 and x_orig.ndim <= 5): + raise ValueError( + "InputEqualizationObserver only supports Linear and Conv layers" + ) + + # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers) + self.equalization_shape = [1] * x_orig.ndim + self.equalization_shape[1] = x_orig.size(1) + + return self.input_obs(x_orig) + + def get_input_minmax(self): + return (self.input_obs.min_val, self.input_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + # Reshape the equalization scale along axis=1 so that it can be + # multiplied with the input along axis=1 + if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): + return + self.equalization_scale = torch.reshape( + equalization_scale, self.equalization_shape + ) + + def calculate_scaled_minmax(self): + r"""Returns the scaled min/max inputs""" + if ( + self.equalization_scale.nelement() == 1 + and self.equalization_scale == torch.tensor(1) + ): + warnings.warn( + "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " + + "Will not scale the next quantization observer." + ) + return None, None + + # Calculate qparams for the scaled min/max inputs + # Scale the input by the equalization scale located at the same column + # index + (min_inputs, max_inputs) = self.get_input_minmax() + equalization_scale_reshaped = reshape_scale( + self.equalization_scale, 0, min_inputs + ) + min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped)) + max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped)) + + return min_input_scaled, max_input_scaled + + with_args = classmethod(_with_args) + + +class _WeightEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of weight columns and + rows, and computing the quantization parameters for the weight rows. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used + to record the running minimum and maximum of columns of incoming weight + tensors. This observer is intended to be used along with an + InputEqualizationObserver to calculate the equalization scale. + + The running minimum/maximum :math:`w_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`. 
+ """ + + def __init__( + self, + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=None, + quant_max=None, + factory_kwargs=None, + ) -> None: + super().__init__() + + self.dtype = dtype + self.qscheme = qscheme + self.ch_axis = 1 + + per_channel_qscheme = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.weight_col_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + ) + + self.equalization_scale = torch.tensor(1) + + def forward(self, w_orig): + if not (w_orig.ndim >= 2 and w_orig.ndim <= 5): + raise ValueError( + "InputEqualizationObserver only supports Linear and Conv layers" + ) + + return self.weight_col_obs(w_orig) + + def get_weight_col_minmax(self): + return (self.weight_col_obs.min_val, self.weight_col_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + self.equalization_scale = equalization_scale + + with_args = classmethod(_with_args) + + +def calculate_equalization_scale( + input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver +) -> torch.Tensor: + r"""Calculates the equalization scale and sets the equalization_scale value + in the observers. + + Args: + input_obs: Observer that tracks the ranges for the input columns + weight_obs: Observer that tracks the ranges for the weight columns + """ + + (min_inputs, max_inputs) = input_obs.get_input_minmax() + (min_weights, max_weights) = weight_obs.get_weight_col_minmax() + + if not ( + check_min_max_valid(min_inputs, max_inputs) + and check_min_max_valid(min_weights, max_weights) + ): + warnings.warn( + "Must run observer before calling calculate_equalization_scale. " + + "Returning default equalization scale torch.tensor(1)." + ) + return torch.tensor(1) + + if not (min_inputs.shape == min_weights.shape): + raise ValueError( + "Input and Weight must have the same column dimension. " + + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead." + ) + + equalization_scale = torch.sqrt( + (max_weights - min_weights) / (max_inputs - min_inputs) + ) + # Replace all 'inf', 'nan', 0's with 1s to prevent errors + equalization_scale[equalization_scale == 0.0] = 1 + equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1) + return equalization_scale + + +class EqualizationQConfig( + namedtuple("EqualizationQConfig", ["input_activation", "weight"]) +): + """ + Describes how to quantize a layer or a part of the network specifically for + input-weight equalization by providing settings (observer classes) for + inputs, outputs, and weights. + + Note that EqualizationQConfig needs to contain observer **classes** (like + MinMaxObserver) or a callable that returns instances on invocation, not the + concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of + the layers. 
+ + Observer classes have usually reasonable default arguments, but they can be + overwritten with `with_args` method (that behaves like functools.partial): + + my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8), + weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): + if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "EqualizationQConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + self = super().__new__(cls, input_activation, weight) + return self + + +input_equalization_observer = _InputEqualizationObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_tensor_symmetric +) +weight_equalization_observer = _WeightEqualizationObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +default_equalization_qconfig = EqualizationQConfig( + input_activation=input_equalization_observer, weight=weight_equalization_observer +) + + +def fused_module_supports_equalization(module) -> bool: + """Checks if the fused node supports equalization.""" + return type(module) in [ + nni.LinearReLU, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + ] + + +def nn_module_supports_equalization(module) -> bool: + """Checks if the torch.nn node supports equalization.""" + return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d] + + +def custom_module_supports_equalization(module) -> bool: + """Checks if the custom node supports equalization.""" + return type(module) in CUSTOM_MODULE_SUPP_LIST + + +def node_supports_equalization(node: Node, modules) -> bool: + """Checks if the current node supports equalization + Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers + """ + if node.op == "call_module": + return ( + nn_module_supports_equalization(modules[str(node.target)]) + or fused_module_supports_equalization(modules[str(node.target)]) + or custom_module_supports_equalization(modules[str(node.target)]) + ) + elif node.op == "call_function": + return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d] + return False + + +def is_equalization_observer(observer: nn.Module) -> bool: + return isinstance( + observer, (_InputEqualizationObserver, _WeightEqualizationObserver) + ) + + +############################################################################### +# Functions for equalization during convert # +############################################################################### + + +def get_op_node_and_weight_eq_obs( + input_eq_obs_node: Node, model: GraphModule, modules: dict[str, nn.Module] +) -> tuple[Optional[Node], Optional[_WeightEqualizationObserver]]: + """Gets the following weight equalization observer. There should always + exist a weight equalization observer after an input equalization observer. 
+ + Returns the operation node that follows the input equalization observer node + and the weight equalization observer + """ + + # Find the op node that comes directly after the input equalization observer + op_node = None + for user in input_eq_obs_node.users.keys(): + if node_supports_equalization(user, modules): + op_node = user + break + + assert op_node is not None + if op_node.op == "call_module": + # If the op_node is a nn.Linear layer, then it must have a + # WeightEqualizationObserver configuration + maybe_equalization_node_name_to_config = _get_observed_graph_module_attr( + model, "equalization_node_name_to_qconfig" + ) + assert maybe_equalization_node_name_to_config is not None + equalization_node_name_to_qconfig: dict[str, Any] = ( + maybe_equalization_node_name_to_config # type: ignore[assignment] + ) + assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None + weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr] + op_node.name, None + ).weight() + + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + return op_node, weight_eq_obs + + elif op_node.op == "call_function": + weight_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_node is not None: + weight_eq_obs = modules[str(weight_node.target)] + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + return op_node, weight_eq_obs + + return None, None + + +def maybe_get_weight_eq_obs_node( + op_node: Node, modules: dict[str, nn.Module] +) -> Optional[Node]: + """Gets the weight equalization observer node if it exists.""" + assert op_node.op == "call_function" + for node_arg in op_node.args: + if node_arg_is_weight(op_node, node_arg): + assert ( + isinstance(node_arg, Node) + and node_arg.op == "call_module" + and isinstance( + modules[str(node_arg.target)], _WeightEqualizationObserver + ) + ) + return node_arg + return None + + +def maybe_get_next_input_eq_obs( + node: Node, modules: dict[str, nn.Module] +) -> Optional[_InputEqualizationObserver]: + """Gets the following input equalization observer if it exists. + + For example, in the case of connecting linear layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + If the node being passed in is the linear1 node, then we want to return eq_obs2, + the following equalization observer for linear2. + + However, if there are no connecting layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add + Then we want to return None. + + In the case of an unfused linear-relu layer with a connecting linear layer: + linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + Since it is unfused, we want to skip over the relu layer and return eq_obs2, + the following equalization observer for linear2. + """ + + assert node_supports_equalization(node, modules) + + # Locate the following nn.ReLU or F.relu node if it exists + maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) + if maybe_relu_node is None: + maybe_relu_node = maybe_get_next_module( + node, modules, target_functional_type=F.relu + ) + + # Locate the following output observer if it exists. + # We will skip the relu node if it exists. 
+ maybe_obs_node = ( + maybe_get_next_module(node, modules, ObserverBase) + if maybe_relu_node is None + else maybe_get_next_module(maybe_relu_node, modules, ObserverBase) + ) + if maybe_obs_node is None: + return None + + maybe_eq_obs_node = maybe_get_next_module( + maybe_obs_node, modules, _InputEqualizationObserver + ) + if maybe_eq_obs_node is None: + return None + + maybe_eq_obs = modules[str(maybe_eq_obs_node)] + assert isinstance(maybe_eq_obs, _InputEqualizationObserver) + return maybe_eq_obs + + +def maybe_get_next_equalization_scale( + node: Node, modules: dict[str, nn.Module] +) -> Optional[torch.Tensor]: + """If the next next node is an InputEqualizationObserver then we want to + return its equalization scale, else we return 1 + + This is used in the case where there are two connecting linear layers: + linear1 -> LinearOutObs -> InputEqObs -> linear2 + In this case, the node given is linear1 and we want to locate the InputEqObs. + """ + next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules) + if next_inp_eq_obs: + if ( + next_inp_eq_obs.equalization_scale.nelement() == 1 + and next_inp_eq_obs.equalization_scale == torch.tensor(1) + ): + return None + return next_inp_eq_obs.equalization_scale + return None + + +def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None: + """Scales the following input quantization observer's min/max values by + updating the values with the scaled min/max values calculated by the input + equalization observer + """ + input_eq_obs = modules[str(node.target)] + assert isinstance(input_eq_obs, _InputEqualizationObserver) + + input_quant_obs_node = node.args[0] + assert isinstance(input_quant_obs_node, Node) + + input_quant_obs = modules[str(input_quant_obs_node.target)] + if not isinstance(input_quant_obs, ObserverBase): + return + + min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax() + if min_input_scaled is None and max_input_scaled is None: + return + input_quant_obs.min_val = min_input_scaled + input_quant_obs.max_val = max_input_scaled + + +def scale_weight_node( + node: Node, + modules: dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: Optional[torch.Tensor], +) -> None: + """Scale the weights for input-weight equalization by multiplying the + weight by 1/equalization_scale and next_equalization_scale + + Args: + node: Current node whose weights we want to scale + equalization_scale: Current node's calculated equalization scale + next_equalization_scale: Next node's calculated equalization scale if + the following node needs to be equalized, 1 otherwise + """ + if equalization_scale is None: + return + + if fused_module_supports_equalization(modules[str(node.target)]): + op_module = modules[str(node.target)][0] # type: ignore[index] + else: + op_module = modules[str(node.target)] + assert nn_module_supports_equalization( + op_module + ) or custom_module_supports_equalization(op_module) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + weight = op_module.weight + assert isinstance(weight, torch.Tensor) + + # Scale the weights by the reciprocal of the equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + op_module.weight = 
nn.Parameter(scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=0 + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + op_module.weight = nn.Parameter(scaled_weight) + + # Multiply the bias element wise by the next equalization scale + bias = op_module.bias + if bias is None: + return + assert isinstance(bias, torch.Tensor) + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + op_module.bias = nn.Parameter(scaled_bias) + + +def scale_weight_functional( + op_node: Node, + model: GraphModule, + modules: dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: Optional[torch.Tensor], +) -> None: + """Scales the weight value for functional layers""" + if equalization_scale is None: + return + + # From the given op_node, the path looks like: + # get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node + # So we want to trace back from the op_node to get the equalization observer + # node, then the quantization observer node, and then finally the weight + # node which contains the weight values. + + # Get the equalization observer node + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + # Get the quantization observer node + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + assert isinstance(weight_quant_obs_node, Node) and isinstance( + modules[str(weight_quant_obs_node.target)], ObserverBase + ) + + # Get the get_attr(weight) node + weight_node = weight_quant_obs_node.args[0] + if weight_node is None: + return + assert isinstance(weight_node, Node) and weight_node.op == "get_attr" + + weight_parent_name, weight_name = _parent_name(weight_node.target) + weight = getattr(modules[weight_parent_name], weight_name) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + setattr(modules[weight_parent_name], weight_name, scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + next_equalization_scale_reshaped = reshape_scale( + next_equalization_scale, 0, scaled_weight + ) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + setattr(modules[weight_parent_name], weight_name, scaled_weight) + assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight) + + # Multiply the bias element wise by the next equalization scale + bias_node = None + for node in op_node.args: + # Find the node containing the weight values + if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name: + bias_node = node + break + if bias_node is None: + return + + bias_parent_name, bias_name = 
_parent_name(bias_node.target)
+    bias = getattr(modules[bias_parent_name], bias_name)
+
+    # Reshape the equalization scale so that we can multiply it element-wise to the bias
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
+    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
+    setattr(modules[bias_parent_name], bias_name, scaled_bias)
+
+
+def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) -> None:
+    """Given the operation node, find the corresponding weight quantization
+    observer and reset its min/max values
+    """
+    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
+    if weight_eq_obs_node is None:
+        return
+
+    weight_quant_obs_node = weight_eq_obs_node.args[0]
+    if weight_quant_obs_node is None:
+        return
+    assert isinstance(weight_quant_obs_node, Node)
+
+    weight_quant_obs = modules[str(weight_quant_obs_node.target)]
+    assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    weight_quant_obs.reset_min_max_vals()  # type: ignore[operator]
+
+
+def remove_node(model: GraphModule, node: Node, prev_node: Node):
+    """Removes the given node from the model by replacing all of its users with
+    the given previous node
+    """
+    # For all of the current node's users, replace the current node with
+    # the input quantization observer node
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        user_node.replace_input_with(node, prev_node)
+
+    # Erase the InputEqualizationObserver node
+    model.graph.erase_node(node)
+
+
+def update_obs_for_equalization(
+    model: GraphModule, modules: dict[str, nn.Module]
+) -> dict[str, _WeightEqualizationObserver]:
+    """Update the equalization scale in all of the observers. For each
+    InputEqualizationObserver, we will find the location of the next
+    WeightEqualizationObserver, create it, and calculate the equalization scale
+    based on the two observers.
+
+    We will then return a dictionary mapping operation node names to
+    the corresponding WeightEqualizationObservers for that operation.
+ """ + weight_eq_obs_dict = {} + for node in model.graph.nodes: + if node.op == "call_module" and isinstance( + modules[node.target], _InputEqualizationObserver + ): + input_eq_obs = modules[node.target] + assert isinstance(input_eq_obs, _InputEqualizationObserver) + op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) + + if op_node is None or weight_eq_obs is None: + continue + + if op_node.op == "call_module": + # Calibrate the weight equalization observer since it has just + # been created + if fused_module_supports_equalization(modules[str(op_node.target)]): + module = modules[str(op_node.target)][0] # type: ignore[index] + assert nn_module_supports_equalization(module) + weight_eq_obs(module.weight) + else: + weight_eq_obs(modules[str(op_node.target)].weight) + + # Calculate and set the equalization scale values + equalization_scale = calculate_equalization_scale( + input_eq_obs, weight_eq_obs + ) + input_eq_obs.set_equalization_scale(equalization_scale) + weight_eq_obs.set_equalization_scale(equalization_scale) + + weight_eq_obs_dict[op_node.name] = weight_eq_obs + + return weight_eq_obs_dict + + +def convert_eq_obs( + model: GraphModule, + modules: dict[str, nn.Module], + weight_eq_obs_dict: dict[str, _WeightEqualizationObserver], +) -> None: + """Converts the equalization operations and updates the other nodes in the + following way: + - Removes the input equalization observers and inserts a mul operator + along with an equalization scale node wherever applicable (we do not + want to insert a mul operator between connecting linear layers). + - Updates the input quantization observers with the scaled input min/max + values. + - Scales the weights by the current and next equalization scales. + - Removes the weight equalization observer node if it exists. + + Before (after prepare): + weight values + | + WeightQuantObs + | + WeightEqObs + | + x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs + + After this function: + scaled weight values + | + equalization scale WeightQuantObs + | | + x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs + + After convert: + equalization scale scaled weight values + | | + x -> mul -> quantize_per_tensor -> quantized::linear + + Note that although the equalization observer appeared after the quantization + observer after prepare_fx, the mul node appears before the quantization node + after convert_fx. This is because placing the equalization observer after + the quantization observer in prepare_fx would allow us to keep the invariant + that the graph before the current node inserts its observers is not + modified. + + Having the equalization observer before the quantization observer would also + cause some inconsistences between the ordering of the quantization and + equalization observers. 
+ For example, a single linear layer would look like: + x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1 + But between two connected linear layers, it would look like: + linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2 + """ + for node in model.graph.nodes: + if node.op == "call_module" and isinstance( + modules[node.target], _InputEqualizationObserver + ): + inp_quant_obs_node = node.args[0] + prev_node = inp_quant_obs_node.args[0] + + # If the previous node is a layer that needs to be equalized, then + # we will remove the current node because we do not need to add any + # equalization nodes between two layers that need to be equalized + + # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2 + # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2 + if ( + node_supports_equalization(prev_node, modules) + or "relu" in prev_node.name + ): + remove_node(model, node, inp_quant_obs_node) + continue + + # Update the following input quantization observer's min/max values + scale_input_observer(node, modules) + + # Remove the InputEqualization node and add a mul operator before + # the quantization observer node that appears before the equalization node + # Before: x -> input_quant_obs -> input_eq_obs -> linear + # After: x -> mul -> input_quant_obs -> linear + + # Create a node containing the equalization scale + with model.graph.inserting_before(inp_quant_obs_node): + get_new_eq_scale_name = get_new_attr_name_with_prefix( + prev_node.name + "_equalization_scale" + ) + name = get_new_eq_scale_name(modules) + setattr(model, name, modules[node.target].equalization_scale) + eq_scale_node = model.graph.create_node("get_attr", name) + + # Create a node multiplying the input with the equalization scale + with model.graph.inserting_after(eq_scale_node): + inputs = (prev_node, eq_scale_node) + mul_node = model.graph.create_node("call_function", torch.mul, inputs) + + # Set the mul nod to be the input_quant_obs_node's input instead of + # the previous node + inp_quant_obs_node.replace_input_with(prev_node, mul_node) + remove_node(model, node, inp_quant_obs_node) + + elif weight_eq_obs_dict.get(node.name, None) is not None: + weight_eq_obs = weight_eq_obs_dict.get(node.name) + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + equalization_scale = weight_eq_obs.equalization_scale + + if ( + equalization_scale.nelement() == 1 + and equalization_scale == torch.tensor(1) + ): + equalization_scale = None # type: ignore[assignment] + maybe_next_equalization_scale = maybe_get_next_equalization_scale( + node, modules + ) + + # Scale the weight nodes + if node.op == "call_module": + scale_weight_node( + node, modules, equalization_scale, maybe_next_equalization_scale + ) + elif node.op == "call_function": + scale_weight_functional( + node, + model, + modules, + equalization_scale, + maybe_next_equalization_scale, + ) + + weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) + if weight_eq_obs_node is None: + return + assert isinstance( + modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver + ) + + # Clear the quantization observer's min/max values so that they + # can get updated later based on the new scale values + clear_weight_quant_obs_node(node, modules) + + # Erase the weight equalization observer node + prev_node = weight_eq_obs_node.args[0] + remove_node(model, weight_eq_obs_node, prev_node) # type: ignore[arg-type] + else: + raise ValueError( + 
"Expected operation node to be 'call_module' or 'call_function" + + f"Instead got node {node.name} as '{node.op}'." + ) + + +def _convert_equalization_ref(model: GraphModule): + """Reference function which applies changes needed for equalization, but + does not quantize the nodes + """ + modules = dict(model.named_modules(remove_duplicate=False)) + + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + return GraphModule(model, model.graph) + + +############################################################################### +# Functions for running the equalized model on the Numeric Suite # +############################################################################### + + +def get_layer_sqnr_dict( + model_a: nn.Module, model_b: nn.Module, x: torch.Tensor +) -> dict[str, float]: + """Runs the Numeric Suite on model_a and model_b and returns a dictionary + containing the SQNR between layers in model_a and model_b. + + Note: In order to support equalized models, this function has a hacky fix in + which we do not match any torch.mul operators. This is because equalized + models contain extra mul operators to scale the input by the equalization + scale, but this edge case has not been resolved yet within the numeric suite code. + + Args: + model_a: A float model + model_b: A quantized model + x: Inputs to use during calibration + """ + import torch.ao.ns._numeric_suite_fx as ns + from torch.ao.ns.fx.mappings import get_unmatchable_types_map + + unmatchable_types_map = get_unmatchable_types_map() + unmatchable_types_map["funs_unmatchable"].add(torch.mul) + + model_a_ns, model_b_ns = ns.add_loggers( + "fp32", + model_a, + "int8", + model_b, + ns.OutputLogger, + unmatchable_types_map=unmatchable_types_map, + ) + + model_a_ns(x) + model_b_ns(x) + + activation_comparison_dict = ns.extract_logger_info( + model_a_ns, model_b_ns, ns.OutputLogger, "int8" + ) + ns.extend_logger_results_with_comparison( + activation_comparison_dict, + "fp32", + "int8", + torch.ao.ns.fx.utils.compute_sqnr, + "sqnr", + ) + + # Construct a dictionary mapping layer names to the SQNR values + layer_sqnr_dict = {} + for key in activation_comparison_dict: + layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"] + sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0] + layer_sqnr_dict[layer] = sqnr + + return layer_sqnr_dict + + +def get_equalization_qconfig_dict( + layer_sqnr_dict: dict[str, float], num_layers_to_equalize: int +) -> Any: + """Given the layer to SQNR dictionary, find the layers with the highest + quantization errors, and return an equalization_qconfig_dict + specifying to only equalize those top layers. 
+ + Args: + layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found + when comparing an equalized model against a float model) + num_layers_to_equalize: Number of layers with the highest quantization + errors to equalize + """ + + # Sort the layer_sqnr_dictionary values and get the layers with the lowest + # SQNR values (aka highest quantization errors) + layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1)) + layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize] + + # Constructs an equalization_qconfig_dict that specifies to only equalize + # the layers with the highest quantization errors + module_to_qconfig_list = [ + (item[0], default_equalization_qconfig) for item in layers_to_equalize + ] + equalization_qconfig_dict = {"module_name": module_to_qconfig_list} + return equalization_qconfig_dict diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf5b99aecde651947cc0d10c1f5adf85ad7a1c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -0,0 +1,1364 @@ +# mypy: allow-untyped-defs +import operator +from typing import Any, Callable, Optional, Union + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.quantization_mappings import get_quantized_operator +from torch.ao.quantization.utils import _parent_name +from torch.fx import GraphModule, map_arg, Node +from torch.fx.graph import Graph + +from .utils import ( + collect_producer_nodes, + create_node_from_old_node_preserve_meta, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_qconv_prepack_op, + graph_module_from_producer_nodes, +) + + +QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = { + torch._ops.ops.quantized.hardswish: ["inplace"], + torch._ops.ops.quantized.elu: ["inplace"], + torch._ops.ops.quantized.dropout: ["inplace"], + torch._ops.ops.quantized.instance_norm: [ + "running_mean", + "running_var", + "use_input_stats", + "momentum", + ], +} + + +def _is_node_in_list(node, modules, func_list, method_list, module_type_list): + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def is_fixed_qparams_node(node, modules): + func_list = [ + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.sigmoid, + torch.tanh, + ] + method_list = [ + "hardsigmoid", + "hardsigmoid_", + "sigmoid", + "sigmoid_", + "tanh", + "tanh_", + ] + module_type_list = [ + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, + torch.nn.Softmax, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_default_node(node, modules): + func_list = [ + torch.nn.functional.elu, + 
torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + method_list: list[Any] = [] + module_type_list = [ + nnqr.ConvTranspose1d, + nnqr.ConvTranspose2d, + nnqr.ConvTranspose3d, + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.ao.nn.intrinsic.BNReLU2d, + torch.ao.nn.intrinsic.BNReLU3d, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_copy_node(node, modules): + func_list = [ + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + operator.floordiv, + # F.channel_shuffle and torch.channel_shuffle are essentially the same thing + # so we only need to put one of them here + torch.channel_shuffle, + ] + method_list = [ + "clamp", + "mean", + "relu", + "relu_", + ] + module_type_list = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.ChannelShuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_general_tensor_shape_node(node, modules): + func_list = [ + torch.narrow, + torch.transpose, + torch.repeat_interleave, + torch.squeeze, + torch.stack, + torch.unsqueeze, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + ] + method_list = [ + "contiguous", + "detach", + "detach_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "size", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + module_type_list = [ + torch.nn.Identity, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_other_node(node, modules): + func_list = [ + torch.cat, + ] + method_list: list[Any] = [] + module_type_list: list[Any] = [] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_special_pattern_node(node, modules): + res_function, res_method, res_module = False, False, False + for checker in [ + is_fixed_qparams_node, + is_default_node, + is_copy_node, + is_general_tensor_shape_node, + is_other_node, + ]: + is_call_function, is_call_method, is_call_module = checker(node, modules) + res_function = res_function or is_call_function + res_method = res_method or is_call_method + res_module = res_module or is_call_module + return res_function, res_method, res_module + + +def is_dequantize_node(node): + return ( + isinstance(node, Node) + and node.op == "call_method" + and node.target == "dequantize" + ) + + +def is_getattr_tensor_metadata_node(node): + return ( + node.op == "call_function" + and 
node.target == getattr + and node.args[1] in ["shape"] + ) + + +def is_get_tensor_info_node(node): + return node.op == "call_method" and node.target in ["shape", "size"] + + +def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: dict[str, QConfigAny]): + """ + Return True if the op is configured with a None qconfig, False otherwise. + Note: maybe need to generalize this to also check for the dtype, and we + only lower when dtype matches, but right now fbgemm/qnnpack only support + a single dtype, so it is OK for now. + """ + return op.name in qconfig_map and qconfig_map[op.name] is None + + +# Mapping from reference module class to the replacement static quantized module class for lowering +STATIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[WeightedQuantizedModule]] = { + nnqr.Linear: nnq.Linear, + nnqr.Conv1d: nnq.Conv1d, + nnqr.Conv2d: nnq.Conv2d, + nnqr.Conv3d: nnq.Conv3d, +} + +# Mapping from reference module class to the replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = { + nnqr.Linear: nnqd.Linear, + nnqr.GRUCell: nnqd.GRUCell, + nnqr.LSTMCell: nnqd.LSTMCell, + nnqr.RNNCell: nnqd.RNNCell, + nnqr.LSTM: nnqd.LSTM, + nnqr.GRU: nnqd.GRU, +} + +# Mapping from reference module class to the replacement weight only quantized module class for lowering +# TODO: correct the namespace for these modules +WEIGHT_ONLY_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = { + nnqr.Embedding: nnq.Embedding, + nnqr.EmbeddingBag: nnq.EmbeddingBag, +} + +# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge +# _lower_static_weighted_ref_module and special_pattern_replacement +SPECIAL_PATTERN_LOWER_MODULE_MAP = { + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nnqr.ConvTranspose1d: nnq.ConvTranspose1d, + nnqr.ConvTranspose2d: nnq.ConvTranspose2d, + nnqr.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.LeakyReLU: nnq.LeakyReLU, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.Dropout: nnq.Dropout, + nn.Softmax: nnq.Softmax, + nn.PReLU: nnq.PReLU, + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]] +] = { + nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU), + # TODO: LinearLeakyReLU is registered as global but it is only fused and + # lowered when ondnn's backend config is used. Maybe need to separate + # registration and lowering functions for different backends in the future. + nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU), + nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh), + nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d), + nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d), + nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d), +} + +# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP: +# The refer node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs. 
+# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]] +] = { + nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d), + nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d), +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_FUSED_MODULE_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[nn.Module]] +] = { + nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU), +} + +# Mapping from a functional to lower to a 2-tuple of +# 1) The quantized version of the op +# 2) The quantized version of the op fused with relu, if it exists, else None +STATIC_LOWER_FUNCTIONAL_MAP: dict[Callable, tuple[Callable, Optional[Callable]]] = { + F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu), + F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu), + F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu), + F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu), + F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None), + F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None), + F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None), +} + +WEIGHT_PREPACK_OPS: set[Callable] = { + torch._ops.ops.quantized.linear_prepack, + torch._ops.ops.quantized.linear_prepack_fp16, + torch._ops.ops.quantized.conv1d_prepack, + torch._ops.ops.quantized.conv2d_prepack, + torch._ops.ops.quantized.conv3d_prepack, + torch.ops.quantized.conv_transpose1d_prepack, + torch.ops.quantized.conv_transpose2d_prepack, + torch.ops.quantized.conv_transpose3d_prepack, +} + +# Mapping from a functional to a dictionary, where the key is a 2-tuple of +# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of +# 1) The dynamically quantized version of the op +# 2) The dynamically quantized version of the op fused with relu, if it exists, else None +DYNAMIC_LOWER_FUNCTIONAL_MAP: dict[ + Callable, dict[tuple[torch.dtype, torch.dtype], tuple[Callable, Optional[Callable]]] +] = { + F.linear: { + (torch.quint8, torch.qint8): ( + torch.ops.quantized.linear_dynamic, + torch.ops.quantized.linear_relu_dynamic, + ), + (torch.float16, torch.float16): ( + torch.ops.quantized.linear_dynamic_fp16, + torch.ops.quantized.linear_relu_dynamic_fp16, + ), + }, + # dynamic conv + relu is not available yet + F.conv1d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None), + }, + F.conv2d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None), + }, + F.conv3d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None), + }, +} + +CONV_FUNCTIONAL_OPS: set[Callable] = { + F.conv1d, + F.conv2d, + F.conv3d, +} + +CONV_TRANSPOSE_FUNCTIONAL_OPS: set[Callable] = { + F.conv_transpose1d, + F.conv_transpose2d, + F.conv_transpose3d, +} + +# TODO: add tests for lowering these ops +QBIN_OP_MAPPING: dict[Union[Callable, str], Callable] = { + operator.add: torch.ops.quantized.add, + torch.add: torch.ops.quantized.add, + operator.mul: torch.ops.quantized.mul, + operator.matmul: torch.ops.quantized.matmul, + torch.mul: torch.ops.quantized.mul, + torch.matmul: torch.ops.quantized.matmul, +} +QBIN_RELU_OP_MAPPING: dict[Union[Callable, str], Callable] = { + operator.add: 
torch.ops.quantized.add_relu, + torch.add: torch.ops.quantized.add_relu, + operator.mul: torch.ops.quantized.mul_relu, + torch.mul: torch.ops.quantized.mul_relu, +} + +ORIGINAL_WEIGHTS_LOOKUP = "original_weights_lookup" + + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and isinstance( + getattr(self, attr_name), torch._C.ScriptObject + ): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + + +def _load_packed_weight( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + +def fold_weight( + quantized_model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """ + Trace back from the weight node util we hit getattr, reconstruct the + graph module with the traced nodes and run the graph module to pack the + weight. then replace the original chain of ops with the packed weight. + """ + packed_weights = {} + # map from folded node name to the prepacked weight name + folded_nodes = {} + original_weights_lookup: dict[str, list] = {} + lookup_counter = 0 + # get packed weights + for node in quantized_model.graph.nodes: + if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS: + nodes_to_fold = collect_producer_nodes(node) + if nodes_to_fold is not None: + for node_to_fold in nodes_to_fold: + folded_nodes[node_to_fold.name] = node + + prepacking_module = graph_module_from_producer_nodes( + quantized_model, nodes_to_fold + ) + packed_weight = prepacking_module() + packed_weights[node.name] = packed_weight + if keep_original_weights: + original_weights = list(prepacking_module.state_dict().values()) + original_weights_lookup[str(lookup_counter)] = sorted( + original_weights, key=lambda x: x.numel(), reverse=True + ) + if len(original_weights_lookup[str(lookup_counter)]) == 1: + # bias is None + original_weights_lookup[str(lookup_counter)].append(None) + lookup_counter += 1 + lookup_counter = 0 + + # remove folded nodes and replace the prepacking node with getattr + folded_graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in quantized_model.graph.nodes: + prepack_node = folded_nodes.get(node.name, None) + if prepack_node is node: + packed_weight = packed_weights[node.name] + # add a prepacked attribute to root + op_node = next(iter(prepack_node.users)) + module_path, _ = node_name_to_scope[op_node.name] + get_new_packed_weight_name = get_new_attr_name_with_prefix( + module_path + "_packed_weight_" + ) + packed_weight_name = get_new_packed_weight_name(quantized_model) + setattr(quantized_model, packed_weight_name, packed_weight) + # replace prepack node with a getattr node + env[node.name] = folded_graph.create_node( + "get_attr", packed_weight_name, (), {} + ) + if keep_original_weights: + key_name = ( + packed_weight_name.replace(":", "_") + .replace("/", "_") + .replace("|", "_") + .replace(" ", "") + .lower() + ) + 
original_weights_lookup[key_name] = original_weights_lookup[ + str(lookup_counter) + ] + del original_weights_lookup[str(lookup_counter)] + lookup_counter += 1 + elif prepack_node is not None: + # remove the foled node + continue + else: + # copy other nodes + env[node.name] = folded_graph.node_copy(node, load_arg) + + quantized_model = GraphModule(quantized_model, folded_graph) + quantized_model._register_state_dict_hook(_save_packed_weight) + quantized_model.register_load_state_dict_pre_hook(_load_packed_weight) + + if keep_original_weights: + setattr( # noqa: B010 + quantized_model, ORIGINAL_WEIGHTS_LOOKUP, original_weights_lookup + ) + + return quantized_model + + +def _get_module(node: Node, modules: dict[str, nn.Module]) -> Optional[nn.Module]: + """ + Return the `torch.nn.Module` that corresponds to the specified node's target. + If no such node exists, return None. + """ + if node.op == "call_module" and str(node.target) in modules: + return modules[str(node.target)] + else: + return None + + +def _match_static_pattern( + node: Node, + modules: dict[str, nn.Module], + qconfig_map: dict[str, QConfigAny], + matching_modules_or_ops: list[Callable], + dequantize_node_arg_indices: list[int], +) -> Union[tuple[Node, Node, Node], tuple[None, None, None]]: + """ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 3-tuple of: + 1) q_node: the quantize node, + 2) relu_node: a relu node wrapping the ref_node, and + 3) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 3-tuple of (None, None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + dequantize_node_arg_indices: A list of indices in the reference node args where dequantize + nodes may be present. An empty list means skipping the check for dequantize nodes. + """ + SKIP_LOWERING_VALUE = (None, None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + assert isinstance(ref_node, Node) + + # Handle cases where the node is wrapped in a ReLU + if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or ( + ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU + ): + relu_node = ref_node + ref_node = relu_node.args[0] + assert isinstance(ref_node, Node) + else: + relu_node = None + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass( + matching_modules_or_ops[0], nn.Module + ): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + expected_op = "call_function" + match_key = ref_node.target # type: ignore[assignment] + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Match dequantize node(s). 
Both of the following conditions must pass: + # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node + # (2) There must be at least one dequantize node + matched_dequantize = False + for i in dequantize_node_arg_indices: + assert i < len(ref_node.args), ( + f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + ) + arg = ref_node.args[i] + if is_dequantize_node(arg): + matched_dequantize = True + elif isinstance(arg, Node): + return SKIP_LOWERING_VALUE + if not matched_dequantize: + return SKIP_LOWERING_VALUE + + return (q_node, relu_node, ref_node) # type: ignore[return-value] + + +def _match_static_pattern_with_two_inputs( + node: Node, + modules: dict[str, nn.Module], + qconfig_map: dict[str, QConfigAny], + matching_modules_or_ops: list[Callable], +) -> Union[tuple[Node, Node], tuple[None, None]]: + """ + (dequantize \ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 2-tuple of: + 1) q_node: the quantize node, + 2) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 2-tuple of (None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + """ + SKIP_LOWERING_VALUE = (None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + assert isinstance(ref_node, Node) + + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass( + matching_modules_or_ops[0], nn.Module + ): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + # This pass only support op of "call_module" + return SKIP_LOWERING_VALUE + + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Check ref_node has 2 input nodes, both are dq node. + if len(ref_node.args) != 2: + return SKIP_LOWERING_VALUE + for i in range(len(ref_node.args)): + arg = ref_node.args[i] + if not is_dequantize_node(arg): + return SKIP_LOWERING_VALUE + + return (q_node, ref_node) + + +def _lower_static_weighted_ref_module( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and find dequantize - ref module - quantize patterns + and replace them with the quantized version of the ref module. 
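+
+    Schematically, a simplified sketch of the rewrite for a reference Linear:
+
+        dequantize -> nnqr.Linear -> quantize_per_tensor   ==>   nnq.Linear
+
+    where the output scale/zero_point of the quantize node are folded into the
+    quantized module via ``from_reference``.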
+ """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list( + STATIC_LOWER_FUSED_MODULE_MAP.keys() + ) + q_node, _relu_node, ref_node = _match_static_pattern( + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + dequantize_node_arg_indices=[0], + ) + if q_node is None: + continue + assert ref_node is not None + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + assert isinstance(scale_node, Node) + assert isinstance(zero_point_node, Node) + assert issubclass(ref_class, nn.Module) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) != inner_ref_class: # type: ignore[index] + continue + else: + q_class = STATIC_LOWER_MODULE_MAP[ref_class] + output_scale = getattr(model, scale_node.target) # type: ignore[arg-type] + output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type] + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + assert len(ref_node.args) == 1 + dq_node = ref_node.args[0] + assert isinstance(dq_node, Node) + ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + + +def _lower_static_weighted_ref_module_with_two_inputs( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and find patterns + dequantize dequantize + \\ // + ref module + \\ + quantize + and replace them with the quantized version of the ref module. 
+ """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # (dequantize \ + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys()) + (q_node, ref_node) = _match_static_pattern_with_two_inputs( + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + ) + if q_node is None: + continue + assert ref_node is not None + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + assert isinstance(scale_node, Node) + assert isinstance(zero_point_node, Node) + assert issubclass(ref_class, nn.Module) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[ + ref_class + ] + if type(ref_module[0]) != inner_ref_class: # type: ignore[index] + continue + else: + continue + output_scale = getattr(model, scale_node.target) # type: ignore[arg-type] + output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type] + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + assert len(ref_node.args) == 2 + for arg in ref_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + assert isinstance(dq_node, Node) + ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] + + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + + +def _lower_dynamic_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns + and replace them with the dynamically quantized version of the ref module. 
+ """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or type(named_modules[str(n.target)]) not in set( + DYNAMIC_LOWER_MODULE_MAP.keys() + ).union(set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())): + continue + ref_node = n + dq_node = ref_node.args[0] + if dq_node.op != "call_method" or dq_node.target != "dequantize": + continue + + input_dynamic_q_node = dq_node.args[0] + + if ( + input_dynamic_q_node.op != "call_function" + or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic + ): + continue + + activation_dtype = input_dynamic_q_node.args[1] + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) != inner_ref_class: + continue + else: + q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment] + # TODO: maybe define a WeightedDynamicallyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[attr-defined] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0]) + + +def _lower_weight_only_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find ref_module patterns + and replace them with the weight only quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or type(named_modules[str(n.target)]) not in set( + WEIGHT_ONLY_LOWER_MODULE_MAP.keys() + ): + continue + ref_node = n + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class) + # TODO: WeightedQuantizedModule is currently assuming static quant apis + # with output_scale, output_zero_point in from_reference, we may want to + # relax that, or rename this + # TODO: maybe define a WeightedWeightOnlyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[union-attr] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + + +def _lower_static_weighted_ref_functional( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and replace functional reference patterns with their quantized versions. 
+ """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize) + matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys()) + (q_node, relu_node, func_node) = _match_static_pattern( + n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1] + ) + if q_node is None: + continue + assert func_node is not None + (_, output_scale_node, output_zp_node, _) = q_node.args + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + assert isinstance(output_zp_node, Node) + assert isinstance(input_dq_node, Node) + assert isinstance(weight_dq_node, Node) + quantized_weight = weight_dq_node.args[0] + assert isinstance(quantized_weight, Node) + if quantized_weight.op != "call_function" or quantized_weight.target not in ( + torch.quantize_per_tensor, + torch.quantize_per_channel, + ): + continue + + # Step 1: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + if func_node.target == F.linear: + weight_dtype = quantized_weight.args[-1] + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target == F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv_transpose1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target == F.conv_transpose1d: + # Note prepack_args[5] is groups. + for i in [2, 3, 4, 6]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + # swap dilation and groups + # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups} + # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation} + if len(prepack_args) > 6: + prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5] + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(output_scale_node): # type: ignore[arg-type] + # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack) + # They are not needed for compute op (i.e., quantized::linear) + kwargs = func_node.kwargs + # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias + if func_node.target == F.linear and "bias" in kwargs: + kwargs = kwargs.copy() + kwargs["B"] = kwargs["bias"] + del kwargs["bias"] + packed_weight = model.graph.create_node( + "call_function", prepack_op, tuple(prepack_args), kwargs + ) + + # Step 2: Replace reference pattern with the corresponding quantized op + (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target] # type: ignore[index] + # conv_transpose does not support fusion with relu yet. 
q_relu_func is None in such cases + if q_relu_func is not None: + func_node.target = q_relu_func if relu_node is not None else q_func + else: + func_node.target = q_func + func_node.args = ( + input_dq_node.args[0], + packed_weight, + output_scale_node, + output_zp_node, + ) + # kwargs for func_node has been moved to kwargs for prepack op + func_node.kwargs = {} + q_node.replace_all_uses_with(func_node) + # Move func_node after output_zp_node in the graph + output_zp_node.append(func_node) + + # Clean up: Remove quantize node, and the relu node if it exists + model.graph.erase_node(q_node) + if relu_node is not None and q_relu_func is not None: + model.graph.erase_node(relu_node) + + +def _lower_dynamic_weighted_ref_functional( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and replace functional reference patterns with their dynamically + quantized versions. + Examples: + quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic + to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16 + """ + modules = dict(model.named_modules(remove_duplicate=False)) + # we want to search in reserved order so that we can match the larger patterns first + # e.g. we want to match linear - relu before linear. + for n in reversed(model.graph.nodes): + # Step 0: Find nodes that match this pattern + # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op) + # We search for the pattern backwards, starting with the quantize node + # Quantize node args: (func, scale, zp, dtype) + func_node = n + # Handle cases where the functional op is wrapped in a ReLU + if ( + func_node.op == "call_function" + and func_node.target == F.relu + or func_node.op == "call_module" + and type(modules[str(func_node.target)]) == torch.nn.ReLU + ): + relu_node = func_node + func_node = relu_node.args[0] + else: + relu_node = None + if should_skip_lowering(func_node, qconfig_map): + continue + # Linear args: (dequantized inputs, dequantized weights[, bias]) + # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups]) + if ( + func_node.op != "call_function" + or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP + ): + continue + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + if ( + input_dq_node.op != "call_method" + or input_dq_node.target != "dequantize" + or weight_dq_node.op != "call_method" + or weight_dq_node.target != "dequantize" + ): + continue + + input_dynamic_q_node = input_dq_node.args[0] + + if ( + input_dynamic_q_node.op != "call_function" + or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic + ): + continue + + reduce_range_node = None + (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + quantized_weight = weight_dq_node.args[0] + weight_dtype = quantized_weight.args[-1] + + # Step 1: Try to select reference pattern with the corresponding quantized op + dynamic_quant_dtype_key = (activation_dtype, weight_dtype) + if ( + dynamic_quant_dtype_key + not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target] + ): + print( + f"Didn't find dtype combination {dynamic_quant_dtype_key} during " + f"dynamic quantized op lowering for {func_node.target}" + ) + continue + (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][ + dynamic_quant_dtype_key + ] + + if 
q_func is None or q_relu_func is None: + print( + "Didn't find corresponding quantized function or quantized relu function " + f"for {func_node.target}, {dynamic_quant_dtype_key}" + ) + continue + + # Step 2: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + prepack_kwargs = {} + if func_node.target == F.linear: + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + kwargs = func_node.kwargs.copy() + if "bias" in kwargs: + prepack_kwargs["B"] = kwargs["bias"] + del kwargs["bias"] + func_node.kwargs = kwargs + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target == F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(func_node): + packed_weight = model.graph.create_node( + "call_function", prepack_op, tuple(prepack_args), prepack_kwargs + ) + + # Step 3: Replace reference pattern with the corresponding quantized op + func_node.target = q_relu_func if relu_node is not None else q_func + if is_int8: + func_node.args = (pattern_input, packed_weight, reduce_range_node) + else: + func_node.args = (pattern_input, packed_weight) + + if relu_node is not None: + relu_node.replace_all_uses_with(func_node) + + # Step 4: Remove the relu node if it exists + if relu_node is not None: + model.graph.erase_node(relu_node) + + +def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfigAny]): + binary_ops_to_lower: list[Callable] = [ + operator.add, + torch.add, + operator.mul, + torch.mul, + torch.matmul, + ] + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + (q_node, relu_node, bop_node) = _match_static_pattern( + n, + modules, + qconfig_map, + binary_ops_to_lower, + dequantize_node_arg_indices=[0, 1], + ) + if q_node is None: + continue + assert bop_node is not None + (_, scale_node, zero_point_node, _) = q_node.args + + # Step 1: Remove dequant nodes + num_dq_nodes = 0 + for arg in bop_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + assert isinstance(dq_node, Node) + dn_input = dq_node.args[0] + bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type] + num_dq_nodes += 1 + assert num_dq_nodes > 0 + + # Step 2: Swap binary op to quantized binary op + assert bop_node.target in QBIN_OP_MAPPING + binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING + qbin_op = binop_to_qbinop[bop_node.target] + # prepare the args for quantized binary op + # (x, y) + qop_node_args = list(bop_node.args) + # (x, y, scale, zero_point) + # add scale and zero_point arguments for Tensor - Tensor operation + if num_dq_nodes == 2: + qop_node_args.extend([scale_node, zero_point_node]) + # insert a call to quantized binary op and remove the original binary op + with model.graph.inserting_after(q_node): + qop_node = 
create_node_from_old_node_preserve_meta( + model.graph, + ("call_function", qbin_op, tuple(qop_node_args), {}), + bop_node, + ) + q_node.replace_all_uses_with(qop_node) + + # Step 3: Remove quantize node, binary op node, and relu node if any + model.graph.erase_node(q_node) + if relu_node is not None: + model.graph.erase_node(relu_node) + model.graph.erase_node(bop_node) + + +def special_pattern_replacement(model: GraphModule): + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + q_node = n + is_quantize = q_node.target == torch.quantize_per_tensor + is_to_fp16 = ( + q_node.op == "call_method" + and q_node.target == "to" + and len(q_node.args) == 2 + and q_node.args[1] == torch.float16 + ) + if not (is_quantize or is_to_fp16): + continue + ref_node = q_node.args[0] + # get output scale/zero_point/dtype from the quantize node + # ref_node, scale_node, zero_point_node, dtype = q_node.args + # TODO: add safety checks that users for the ref_node and dq_node needs to be one + is_call_function, is_call_method, is_call_module = is_fixed_qparams_node( + ref_node, modules + ) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + # warnings.warn( + # "Only reference patterns are currently supported for {dtype} dtype with {op} op" + # "".format(dtype=dtypes, op=ref_node)) + continue + + is_call_function, is_call_method, is_call_module = is_default_node( + ref_node, modules + ) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + continue + + # This check includes all supported ops + is_call_function, is_call_method, is_call_module = is_special_pattern_node( + ref_node, modules + ) + if not (is_call_module or is_call_function or is_call_method): + continue + assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0 + dq_node_or_nodes = ( + ref_node.args[0] + if len(ref_node.args) > 0 + else next(iter(ref_node.kwargs.values())) + ) + assert isinstance(dq_node_or_nodes, (Node, tuple, list)) + is_dequantize = False + if isinstance(dq_node_or_nodes, Node): + is_dequantize = ( + dq_node_or_nodes.op == "call_method" + and dq_node_or_nodes.target == "dequantize" + ) + elif isinstance(dq_node_or_nodes, (tuple, list)): + is_dequantize = all( + x.op == "call_method" and x.target == "dequantize" + for x in dq_node_or_nodes + ) + + if not is_dequantize: + continue + + # TODO: enable we have patterns that needs to swap the modules + if is_call_module: + ref_module = modules[ref_node.target] + if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize: + qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module)) + scale_node = q_node.args[1] + zero_point_node = q_node.args[2] + output_scale = getattr(model, scale_node.target) + output_zero_point = getattr(model, zero_point_node.target) + + qmodule = qmodule_cls.from_reference( # type:ignore[union-attr] + ref_module, output_scale, output_zero_point + ) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, qmodule) + + # reroute around dq node: + dq_nodes: list[Node] = [] + if isinstance(dq_node_or_nodes, Node): + dq_nodes = [dq_node_or_nodes] + elif isinstance(dq_node_or_nodes, (tuple, list)): + dq_nodes = list(dq_node_or_nodes) + + for dq_node in dq_nodes: + dn_input = dq_node.args[0] + 
ref_node.replace_input_with(dq_node, dn_input) + + # store q node args + qnode_qparams = list(q_node.args)[1:] + # replace uses of q node with input and remove q node + q_node_input = q_node.args[0] + q_node.replace_all_uses_with(q_node_input) + model.graph.erase_node(q_node) + + is_call_function, is_call_method, is_call_module = is_default_node( + ref_node, modules + ) + if is_call_function: + # pass scale/zer_point arguments from quantize_per_tensor to the default node operator + # insert an op after the zero_point node so that the scale/zero_point + # nodes are is available + qop = get_quantized_operator(ref_node.target) + args = list(ref_node.args) + kwargs = dict(ref_node.kwargs) + if qop in QOP_TO_ARG_NAMES_TO_SKIP: + args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop] + for arg in args_to_skip: + if arg in kwargs: + kwargs.pop(arg) + kwargs["output_scale"] = qnode_qparams[0] + kwargs["output_zero_point"] = qnode_qparams[1] + with model.graph.inserting_after(qnode_qparams[1]): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, ("call_function", qop, tuple(args), kwargs), ref_node + ) + ref_node.replace_all_uses_with(qop_node) + model.graph.erase_node(ref_node) + else: + # remove scale/zero_point node for quantize node + for n in qnode_qparams: + if isinstance(n, Node): + model.graph.erase_node(n) + + return model + + +def _lower_getattr_tensor_metadta_op(model: GraphModule): + """Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if is_getattr_tensor_metadata_node(n): + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + + +def _lower_get_tensor_info_op(model: GraphModule): + """Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if not is_get_tensor_info_node(n): + continue + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + + +def _lower_to_native_backend( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to the native backend in PyTorch (fbgemm/qnnpack), both backends shares the same + operator signature so they can be lowered with the same function + """ + _lower_static_weighted_ref_module(model, qconfig_map) + _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map) + _lower_dynamic_weighted_ref_module(model) + _lower_weight_only_weighted_ref_module(model) + _lower_static_weighted_ref_functional(model, qconfig_map) + _lower_dynamic_weighted_ref_functional(model, qconfig_map) + _lower_quantized_binary_op(model, qconfig_map) + _lower_getattr_tensor_metadta_op(model) + _lower_get_tensor_info_op(model) + special_pattern_replacement(model) + model.graph.eliminate_dead_code() + model = fold_weight(model, node_name_to_scope, keep_original_weights) + model.graph.eliminate_dead_code() + model.recompile() + model.graph.lint() + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__init__.py 
b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49db78739811cd60c0d5a68f6e3189f74033f780 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2477c4d80b73ebe96d2834fd10ac55c9d16d5d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f309373bd81f831df91270179d18189e7578e7f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cec22bbdcc467f9bca0f76197cc4d4504501afbb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55bf8c72c3c513cc19f2907ebcbe3c991497cae6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/detector.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..776ea0553b6da568c8ac49a6dcdfb65b81cb9968 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/detector.py @@ -0,0 +1,1733 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from typing import Any, Callable + +import torch +import torch.ao.nn.qat as nnqat +import torch.nn as nn +from torch.ao.quantization.fake_quantize import FakeQuantize +from torch.ao.quantization.fx._equalize import ( + default_equalization_qconfig, + EqualizationQConfig, +) +from torch.ao.quantization.fx._model_report.model_report_observer import ( + ModelReportObserver, +) +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ( + 
_is_activation_post_process, + default_dynamic_quant_observer, + default_observer, + default_per_channel_weight_observer, + default_weight_observer, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + _assert_valid_qconfig, + default_qconfig, + QConfig, +) + + +# Names for observer insert keys +DETECTOR_TARGET_NODE_KEY = "target_node" +DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert" +DETECTOR_IS_POST_OBS_KEY = "is_post_observer" +DETECTOR_OBS_ARGS_KEY = "observer_args" + + +# Mapping related code +class DetectorQConfigInfo: + r""" + This class contains the QConfig information for a single module. + The list of variables / values this contains can grow depending on the + extensibility of the qconfig mapping feature set but this currently includes: + - if activation observer is dynamic + - if weight observer is per channel + + + Args: + module_fqn (str): The fully qualified name (fqn) of the module that this + information contains info relevant to qconfig for + """ + + def __init__(self, module_fqn: str): + super().__init__() + self.module_fqn = module_fqn + + # populate this section with all the variables we might find important + # change from none if your detector is actually using this + self.is_activation_dynamic = False + self.is_weight_per_channel = False + + # equalization related options + self.is_equalization_recommended = False + + def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig: + r""" + Args: + module (torch.nn.Module) The module we are generating + the qconfig for + + Returns the generated quantization QConfig according to what a valid configuration is + """ + # Apply suggestions to new qconfig + module_qconfig = default_qconfig + + # keep track of dynamic and per_channel recommendations + recommendations_list = [] + # append as if a list of combinations + recommendations_list.append( + (self.is_activation_dynamic, self.is_weight_per_channel) + ) + recommendations_list.append( + (self.is_activation_dynamic, False) + ) # only trying dynamic rec + recommendations_list.append( + (False, self.is_weight_per_channel) + ) # only trying dynamic + + # now we try each of the combinations + for rec in recommendations_list: + # rec[0] -> dynamic recommended + # rec[1] -> per channel recommended + activation = default_dynamic_quant_observer if rec[0] else default_observer + weight = ( + default_per_channel_weight_observer + if rec[1] + else default_weight_observer + ) + test_config = QConfig(activation, weight) + try: + _assert_valid_qconfig(test_config, module) + module_qconfig = test_config + break + except AssertionError: + # if not a valid configuration, we move on to the next one in priority + continue + + # return the QConfig chosen + return module_qconfig + + def generate_equalization_qconfig(self) -> EqualizationQConfig: + r""" + This returns the equalization configuration for a module. + + For now, it just returns the default, but as more equalization options become + possible, this method can get more fleshed out with more nuanced granularity. + + + Returns the generated equalization QConfig according to what a valid configuration is + """ + # in this case, we just return default equalization config + # we know this is valid because only valid modules would even + # have this option + return default_equalization_qconfig + + +# Adding base class for detectors +class DetectorBase(ABC): + r"""Base Detector Module + Any detector class should derive from this class. 
+ + Concrete detectors should follow the same general API, which includes: + - A method to calculate and return observer insertion points + - Should return both the fqns and the Observer class to insert + - A method to return a report based on the detector + - Should return a str-based report and dict info in Tuple[str,Dict] format + """ + + def __init__(self) -> None: + super().__init__() + self.detector_config_info = None + + @abstractmethod + def determine_observer_insert_points(self, model) -> dict: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict. + This dict maps string keys to detector specific information + """ + + @abstractmethod + def get_detector_name(self) -> str: + r"""Returns the name of the current detector""" + + @abstractmethod + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + + def _get_targeting_node( + self, prepared_fx_model: GraphModule, target_fqn: str + ) -> torch.fx.node.Node: + r""" + Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn. + + If it's not found, it means it is most likely inside a fused layer + We just go one layer up in terms of the fqn we are searching for until we find parent node + If we get to empty string, then we know that it doesn't exist + + The reason for the recursion is that if the model that we are looking for got fused, + we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module, + which would have fqn as x.linear so they will not match. + To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear, + or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module + even in cases with fusion + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + target_fqn (str): The fqn of the layer we are trying to target + + Returns the node object we are trying to add observers around + """ + for node in prepared_fx_model.graph.nodes: + # if the node's target is our target, return it + if node.target == target_fqn: + return node + + # getting here means node not found + # if no "." we are already at base and failed + parent_fqn_sep_index = target_fqn.rfind(".") + if parent_fqn_sep_index == -1: + raise ValueError("passed in target_fqn not found in graph's targets.") + else: + # recursively call it with parent fqn + return self._get_targeting_node( + prepared_fx_model, target_fqn[:parent_fqn_sep_index] + ) + + @abstractmethod + def generate_detector_report(self, model) -> tuple[str, dict[str, Any]]: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Tuple of two elements: + Str: string report of the suggested improvements + Dict: contains useful data collected by the observer pertinent to this report + """ + + +class PerChannelDetector(DetectorBase): + r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization. 
+ Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + per_channel quantization can lead to major benefits in the form of accuracy. + Therefore, if the backend used by the user supports it, it is recommended to use + + Args: + backend (str, optional): the backend the user wishes to use in production + Default value is current torch.backends.quantized.engine + """ + + # Keys for return dictionary + BACKEND_KEY = "backend" + PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported" + PER_CHAN_USED_KEY = "per_channel_quantization_used" + + # Default map for representing supported per channel quantization modules for different backends + DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: dict[str, set[Any]] = { + "fbgemm": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "qnnpack": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "onednn": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "x86": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + } + + def __init__(self, backend: str = torch.backends.quantized.engine): + super().__init__() + + # store the backend information + self.backend_chosen = backend + self.supported_modules = set() + if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: + self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[ + self.backend_chosen + ] + else: + raise ValueError( + f"Not configured to work with {self.backend_chosen}. Try a different default backend" + ) + + def get_detector_name(self) -> str: + r"""returns the string name of this detector""" + return "per_channel_detector" + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in per_channel_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + per_chan_supported: bool = per_channel_info[module_fqn][ + self.PER_CHAN_SUPPORTED_KEY + ] + detector_qconfig_info.is_weight_per_channel = per_chan_supported + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points(self, model: nn.Module) -> dict: + r""" + There is no observers inserted for the PerChannelDetector. + + Returns an empty dictionary since no observers are added or needed + """ + return {} + + def _detect_per_channel_helper(self, model: nn.Module): + r""" + determines if per_channel quantization is supported in modules and submodules. + + Returns a dictionary in the higher level _detect_per_channel function. 
+ Each entry maps the fully-qualified-name to information on whether per_channel quantization. + + Args: + model: The current module that is being checked to see if it is per_channel quantizable + + Returns dictionary mapping fqns to if per_channel quantization is possible + """ + # create dict we will return + per_channel_info: dict = {} + + # get the fully qualified name and check if in list of modules to include and list of modules to ignore + for fqn, module in model.named_modules(): + is_in_include_list = any( + isinstance(module, x) for x in self.supported_modules + ) + + # check if the module per_channel is supported + # based on backend + per_channel_supported = False + + if is_in_include_list: + per_channel_supported = True + + # assert statement for MyPy + q_config_file = module.qconfig + assert isinstance(q_config_file, QConfig) + + # this object should either be fake quant or observer + q_or_s_obj = module.qconfig.weight.p.func() + assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)) + + per_channel_used = False # will be true if found in qconfig + + if hasattr( + q_or_s_obj, "ch_axis" + ): # then we know that per_channel quantization used + # all fake quants have channel axis so need to check is_per_channel + if isinstance(q_or_s_obj, FakeQuantize): + if ( + hasattr(q_or_s_obj, "is_per_channel") + and q_or_s_obj.is_per_channel + ): + per_channel_used = True + elif isinstance(q_or_s_obj, ObserverBase): + # should be an observer otherwise + per_channel_used = True + else: + raise ValueError("Should be either observer or fake quant") + + per_channel_info[fqn] = { + self.PER_CHAN_SUPPORTED_KEY: per_channel_supported, + self.PER_CHAN_USED_KEY: per_channel_used, + self.BACKEND_KEY: self.backend_chosen, + } + + return per_channel_info + + def generate_detector_report(self, model: nn.Module) -> tuple[str, dict[str, Any]]: + r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + Looks at q_config format and backend to determine if per_channel can be utilized. + Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support + + Args: + model: The prepared and calibrated model we want to check if using per_channel + + Returns a tuple with two elements: + String report of potential actions to improve model (if per_channel quantization is available in backend) + Dictionary mapping per_channel quantizable elements to: + whether per_channel quantization is supported by the backend + if it is being utilized in the current model + """ + + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # String to let the user know of further optimizations + further_optims_str = ( + f"Further Optimizations for backend {self.backend_chosen}: \n" + ) + + optimizations_possible = False + for fqn in per_channel_info: + fqn_dict = per_channel_info[fqn] + if ( + fqn_dict[self.PER_CHAN_SUPPORTED_KEY] + and not fqn_dict[self.PER_CHAN_USED_KEY] + ): + optimizations_possible = True + further_optims_str += ( + f"Module {fqn} can be configured to use per_channel quantization.\n" + ) + + if optimizations_possible: + further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer." + else: + further_optims_str += "No further per_channel optimizations possible." 
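+
+        # A typical end-to-end use of this detector (a sketch; assumes `model`
+        # has already been prepared with qconfigs attached):
+        #
+        #     detector = PerChannelDetector(backend="fbgemm")
+        #     report_str, per_channel_info = detector.generate_detector_report(model)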
+ + # return the string and the dictionary form of same information + return (further_optims_str, per_channel_info) + + +class DynamicStaticDetector(DetectorBase): + r""" + Determines whether dynamic or static quantization is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. + + If the distribution of data right after the module is non-stationary, recommend dynamic quantization + Otherwise recommend static quantization + + Args: + tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5 + """ + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer" + DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer" + + # naming conventions for stationary vs non-stationary data + STATIONARY_STR = "stationary" + NON_STATIONARY_STR = "non-stationary" + + # naming for activation + INPUT_ACTIVATION_PREFIX = "input_activation_" + OUTPUT_ACTIVATION_PREFIX = "output_activation_" + + # naming conventions for the keys of the return module info + TOLERANCE_KEY = "dynamic_static_tolerance" + DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended" + PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + PRE_OBS_DATA_DIST_KEY = ( + INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + ) + POST_OBS_DATA_DIST_KEY = ( + OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + ) + IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported" + + # modules that are supported both dynamic and static for this report function + DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear} + + # modules that will be supported soon for both + DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d} + + def __init__(self, tolerance=0.5): + super().__init__() + + # set tolerance level and initialize a set to keep track of useful fqn locations + self.tolerance = tolerance + self.useful_observer_fqns: set[str] = set() + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r""" + Determines where observers need to be inserted for the Dynamic vs Static detector. + For this detector, we want to place observers on either side of linear layers in the model. 
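+
+        The classification applied downstream (see ``_generate_dict_info``) is,
+        roughly::
+
+            S = average_batch_activation_range / epoch_activation_range
+            dynamic_recommended = S_post_observer <= tolerance  # non-stationary output
+
+        so the observers inserted here only need to record per-batch and
+        per-epoch activation ranges around each supported module.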
+ + Currently inserts observers for: + linear layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # make sure module is supported + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + # add entry for post-observer + post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME + + obs_fqn_to_info[post_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: True, + DETECTOR_OBS_ARGS_KEY: (targeted_node,), + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""returns the string name of this detector""" + return "dynamic_vs_static_detector" + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + dynamic_static_info = self._generate_dict_info(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in dynamic_static_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + dynamic_static_recommended: bool = dynamic_static_info[module_fqn][ + self.DEFAULT_DYNAMIC_REC_KEY + ] + detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = any( + isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED + ) + + # check if it will be supported + future_supported_type = any( + isinstance(module, x) for x in 
self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED + ) + + # supported + supported = is_supported_type or future_supported_type + + # this is check for observer insertion + if insert: + return supported + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr( + module, self.DEFAULT_POST_OBSERVER_NAME + ) + return supported and has_obs + + def _generate_dict_info(self, model: GraphModule) -> dict[str, Any]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a Dictionary mapping modules with ModelReportObservers around them to: + whether dynamic quantization is recommended + their S metric of input to module + whether input to module is stationary or non-stationary + their S metric of output of module + whether output of module is stationary or non-stationary + the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future + """ + # store modules dynamic vs static information + module_dynamic_static_info = {} + + # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info + # This information primary includes whether the data distributions around a supported module is stationary or not + # Based on this, it is recorded whether dynamic or static quantization is recommended + + # loop through all submodules included nested ones + for fqn, module in model.named_modules(): + # if module is Linear has the ModelReportObserver attached to it + if self._is_supported(module): + # get pre and post observers for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME) + + # get the statistics for each module + pre_stat = pre_obs.get_batch_to_epoch_ratio() + post_stat = post_obs.get_batch_to_epoch_ratio() + + # record module, pre and post stat, and whether to do dynamic or static based off it + # true if post observer data distribution is non-stationary, false if it's stationary + dynamic_recommended = post_stat <= self.tolerance + + # specify the classifications for whether data distributions considered stationary or non-stationary + pre_obs_dist_classif = ( + self.STATIONARY_STR + if pre_stat > self.tolerance + else self.NON_STATIONARY_STR + ) + post_obs_dist_classif = ( + self.STATIONARY_STR + if post_stat > self.tolerance + else self.NON_STATIONARY_STR + ) + + # check if current support or future support + is_supported_type = any( + isinstance(module, x) + for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED + ) + + # store the set of important information for this module + module_info = { + self.TOLERANCE_KEY: self.tolerance, + self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended, + self.PRE_OBS_COMP_STAT_KEY: pre_stat, + self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif, + self.POST_OBS_COMP_STAT_KEY: post_stat, + self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif, + self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type, + } + + module_dynamic_static_info[fqn] = module_info + + return module_dynamic_static_info + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether dynamic or static quantization 
is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. + + If the distribution of data right after the module is non-stationary, recommend dynamic quantization + Otherwise recommend static quantization + + This will then generate suggestions for dynamic vs static quantization focused around Linear. + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether dynamic or static quantization is recommended for certain modules + Dictionary mapping modules with ModelReportObservers around them to: + whether dynamic quantization is recommended + their S metric of input to module + whether input to module is stationary or non-stationary + their S metric of output of module + whether output of module is stationary or non-stationary + the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future + """ + + # get the dictionary of the information to format the string report + module_dynamic_static_info = self._generate_dict_info(model) + + dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n" + + modules_added: bool = False # check to make sure at least 1 module added. + + dynamic_benefit = ( + " You will get more accurate results if you use dynamic quantization" + ) + static_benefit = ( + " You can increase model efficiency if you use static quantization" + ) + future_support_str = ( + ". This layer is not yet supported for dynamic quantization" + ) + # This for loop goes through the information collected in module_dynamic_static_info and: + # Populates the string based report with the information from module_dynamic_static_info + # Compiles the complete report by appending relevant formatted strings + + for module_fqn in module_dynamic_static_info.keys(): + # there is at least 1 module for suggestion + modules_added = True + module_info = module_dynamic_static_info[module_fqn] + suggestion_string_template = ( + "For module {} it is suggested to use {} quantization because {}.\n" + ) + + # decide what string formatting values will be + quantization_type = "" + quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}." + + benefit_str = "" + + # strings for if dynamic quantized per tensor is needed + recommend_per_tensor = ( + ". We recommend to add a {} before this module if it is static." 
+ ) + rec_lay_to_add = "dynamic quantize per tensor layer" + dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add) + dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution" + + # start composing explanation + if module_info[self.DEFAULT_DYNAMIC_REC_KEY]: + quantization_type = "dynamic" + # check if currently supported or future supported + benefit_str = dynamic_benefit + if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]: + benefit_str += future_support_str + else: + quantization_type = "static" + benefit_str = static_benefit + + # now set the quantization explanation string + quantization_reasoning = ( + quantization_reasoning.format( + module_fqn, + module_info[self.PRE_OBS_DATA_DIST_KEY], + module_info[self.POST_OBS_DATA_DIST_KEY], + ) + + benefit_str + ) + + # if we have a non-stationary input -> linear -> stationary we suggested static + # however, we want to also recommend they add a dynamic quantize per tensor right if this change is made + if ( + module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR + and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR + ): + quantization_reasoning = ( + quantization_reasoning + + dynamic_per_tensor_string + + dynamic_per_tensor_reasoning_string + ) + + # format the overall suggestion string with the specific inputs + module_suggestion_string = suggestion_string_template.format( + module_fqn, quantization_type, quantization_reasoning + ) + + # append to overall suggestion + dynamic_vs_static_string += module_suggestion_string + + if not modules_added: + dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n" + + # return the string as well as the dictionary of information + return (dynamic_vs_static_string, module_dynamic_static_info) + + +class InputWeightEqualizationDetector(DetectorBase): + r""" + Determines whether input-weight equalization can help improve quantization for certain modules. 
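+
+    A minimal usage sketch (assumes ``model`` is a prepared GraphModule whose
+    ``ModelReportObserver`` pre-observers have already been inserted and
+    calibrated)::
+
+        detector = InputWeightEqualizationDetector(ratio_threshold=0.4)
+        report_str, equalization_info = detector.generate_detector_report(model)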
+ + Specifically, this list of modules includes: + linear + conv + + Determines whether input-weight equalization is recommended based on the comp stat: + s_c = sqrt(w_c/W)/sqrt(i_c/I) + where: + w_c is range of weight for channel c, W is range of weight over all channels + i_c is range of input for channel c, I is range of input over all channels + + if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization + + Args: + ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested + Should be between 0 and 1 (both non-inclusive) + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization + + * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + SUPPORTED_MODULES: set[Callable] = { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + } + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # weight / activation prefix for each of the below info + WEIGHT_PREFIX = "weight_" + ACTIVATION_PREFIX = "input_activation_" + + # string names for keys of info dictionaries + PER_CHANNEL_MAX_KEY = "per_channel_max" + PER_CHANNEL_MIN_KEY = "per_channel_min" + GLOBAL_MAX_KEY = "global_max" + GLOBAL_MIN_KEY = "global_min" + + # keys for return dict of recommendations + RECOMMENDED_KEY = "input_weight_equalization_recommended" + COMP_METRIC_KEY = "input_weight_channel_comparison_metrics" + THRESHOLD_KEY = "input_weight_threshold" + CHANNEL_KEY = "input_weight_channel_axis" + + # default weight and info strings + WEIGHT_STR = "weight" + INPUT_STR = "input" + + # default for what ratio we recommend input weight + DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4 + + def __init__(self, ratio_threshold: float, ch_axis: int = 1): + # ensure passed in inputs are valid + if ratio_threshold <= 0 or ratio_threshold >= 1: + raise ValueError("Make sure threshold is > 0 and < 1") + + # initialize attributes based on args + self.ratio_threshold: float = ratio_threshold + self.ch_axis: int = ch_axis + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES) + + # this is check for observer insertion + if insert: + return is_supported_type + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + return is_supported_type and has_obs + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping 
from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + # find the range of inputs + input_values: dict[str, dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: dict[str, dict] = self._extract_weight_info(model) + + # calculate per_channel comparison statistic s_c + comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values( + input_values, weight_values + ) + + # generate the return dictionary + input_weight_equalization_info: dict[str, dict] = self._generate_dict_info( + input_values, weight_values, comp_stats + ) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in input_weight_equalization_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + input_weight_recommended: bool = input_weight_equalization_info[module_fqn][ + self.RECOMMENDED_KEY + ] + detector_qconfig_info.is_equalization_recommended = input_weight_recommended + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r"""Determines where observers need to be inserted for the Input Weight Equalization Detector. + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + linear layers + conv layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "input_weight_equalization_detector" + + def _extract_input_info(self, model: GraphModule) -> dict[str, dict]: + r""" + Takes in a calibrated GraphModule and then finds the relevant observers. 
+ It then extracts the input information for each observer returns it + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns (str) to a dict with keys: + "input_activation_per_channel_max" : maps to the per_channel max values + "input_activation_per_channel_min" : maps to the per_channel min values + "input_activation_global_max" : maps to the global max recorded + "input_activation_global_min" : maps to the global min recorded + """ + + # return dictionary mapping observer fqns to desired info + input_info: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # get pre observer for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + input_info[fqn] = { + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val, + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val, + self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val), + self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val), + } + + return input_info + + def _extract_weight_info(self, model: GraphModule) -> dict[str, dict]: + r""" + Takes in a calibrated GraphModule and then finds the relevant observers. + It then extracts the weight information for each layer an observer is attached to. + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping module fqns (str) to a dict with keys: + "per_channel_max" : maps to the per_channel max values + "per_channel_min" : maps to the per_channel min values + "global_max" : maps to the global max recorded + "global_min" : maps to the global min recorded + """ + # return dictionary mapping observer fqns to desired info + weight_info: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # we don't need actual observer, just the module weights + # calculate min and max vals + device = module.weight.device + min_val: torch.Tensor = torch.tensor([float("inf")], device=device) + max_val: torch.Tensor = torch.tensor([float("-inf")], device=device) + x_copy = module.weight + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + weight_info[fqn] = { + self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val, + self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val, + self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val), + self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val), + } + + return weight_info + + def _calculate_range_ratio( + self, info_dict: dict, info_str: str, module_fqn: str + ) -> torch.Tensor: + r""" + Takes in an info dict and calculates the s_c matrix. 
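+
+        Concretely, for each channel ``c`` (names are local to this sketch)::
+
+            ratio_c = (per_channel_max[c] - per_channel_min[c]) / (global_max - global_min)
+
+        These per-channel ratios are combined later in
+        ``_generate_comparison_values`` as ``s_c = sqrt(weight_ratio_c) / sqrt(input_ratio_c)``,
+        and a channel is flagged for equalization roughly when
+        ``ratio_threshold <= s_c <= 1 / ratio_threshold``.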
+ + Args: + info_dict (dict): A dictionary of either input or weight range info + info_str (str): A str describing whether currently looking at weight or input info + Either "weight" or "input" + module_fqn (str): The fqn of the module we are looking at + + Returns a tensor of values, where each value is the s_c stat for a different channel + """ + # calculate the ratios of the info + # get the prefix str + prefix_str = ( + self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX + ) + + per_channel_range = ( + info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] + - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY] + ) + global_range = ( + info_dict[prefix_str + self.GLOBAL_MAX_KEY] + - info_dict[prefix_str + self.GLOBAL_MIN_KEY] + ) + + if global_range == 0: + range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information." + raise ValueError( + f"The range of the {info_str} data for module {module_fqn} is 0, " + f"which means you have a constant value channel. {range_zero_explanation}" + ) + + ratio = per_channel_range / global_range + + return ratio + + def _generate_comparison_values( + self, input_info: dict, weight_info: dict + ) -> dict[str, torch.Tensor]: + r""" + Takes in the information on the min and max values of the inputs and weights and: + Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I) + + Args: + input_info (dict): A dict mapping each observer to input range information + weight_info (dict): A dict mapping each observer to weight range information + + Returns a dict mapping relevant observer fqns (str) to a 1-D tensor. + Each value is a different s_c value for a different channel + """ + # create return dictionary for each observer + module_fqn_to_channel: dict[str, torch.Tensor] = {} + + # for each module (both passed in dicts should have same keys) + for module_fqn in input_info: + # raise error if not in weight info + if module_fqn not in weight_info: + raise KeyError( + f"Unable to find weight range stats for module {module_fqn}" + ) + + # calculate the ratios of the weight info and input info + weight_ratio = self._calculate_range_ratio( + weight_info[module_fqn], self.WEIGHT_STR, module_fqn + ) + input_ratio = self._calculate_range_ratio( + input_info[module_fqn], self.INPUT_STR, module_fqn + ) + + # if mismatched size, because of grouping, we want to replicate weight enough times + weight_channels = len(weight_ratio) + input_channels = len(input_ratio) + if weight_channels != input_channels: + # we try to replicate + assert input_channels % weight_channels == 0, ( + "input channels should be divisible by weight channels." + ) + # get replication factor + rep_factor: int = input_channels // weight_channels + + # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n + weight_ratio = weight_ratio.repeat(rep_factor) + + # calculate the s metric per channel + s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio) + module_fqn_to_channel[module_fqn] = s + + # return compiled observer ratios + return module_fqn_to_channel + + def _generate_dict_info( + self, input_info: dict, weight_info: dict, comp_stats: dict + ) -> dict[str, dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. 
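+
+ As a rough sketch of the per-channel recommendation rule applied below (the
+ threshold value here is only an example), a channel is recommended for equalization
+ when its s_c value lies within [ratio_threshold, 1 / ratio_threshold]:
+
+ >>> # xdoctest: +SKIP
+ >>> ratio_threshold = 0.4
+ >>> s_c_values = [0.1, 0.5, 1.8, 3.0]
+ >>> [ratio_threshold <= v <= 1 / ratio_threshold for v in s_c_values]
+ [False, True, True, False]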
+ This process is done as specified in generate_detector_report documentation + + Args: + input_info (dict): A dict mapping each module to input range information + weight_info (dict): A dict mapping each module to weight range information + comp_stats (dict): A dict mapping each module to its corresponding comp stat + + Returns a dictionary mapping each module with relevant ModelReportObservers around them to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + # store modules input weight equalization info + input_weight_equalization_info: dict[str, dict] = {} + + # for each module we add separate set of suggestions + for module_fqn in input_info: + # get relevant info for this module + mod_input_info: dict = input_info[module_fqn] + mod_weight_info: dict = weight_info[module_fqn] + mod_comp_stat: dict = comp_stats[module_fqn] + + # decide if each channel should have input weight equalization or not + channel_rec_vals: list = [] + + for val in mod_comp_stat: + float_rep: float = val.item() + + # decide if recommending input weight equalization + recommended: bool = ( + float_rep >= self.ratio_threshold + and float_rep <= 1 / self.ratio_threshold + ) + channel_rec_vals.append(recommended) + + # build the return dict input + # also unpack input and weight dicts into it + input_weight_equalization_info[module_fqn] = { + self.RECOMMENDED_KEY: channel_rec_vals, + self.COMP_METRIC_KEY: mod_comp_stat, + self.THRESHOLD_KEY: self.ratio_threshold, + self.CHANNEL_KEY: self.ch_axis, + **mod_input_info, + **mod_weight_info, + } + + # return our compiled info for each module + return input_weight_equalization_info + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. 
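+
+ A minimal usage sketch (assumes the observers for this detector were already
+ inserted via ModelReport.prepare_detailed_calibration and the model has been
+ calibrated; the model name and threshold value here are hypothetical):
+
+ >>> # xdoctest: +SKIP
+ >>> detector = InputWeightEqualizationDetector(ratio_threshold=0.4)
+ >>> report_str, per_module_info = detector.generate_detector_report(calibrated_graph_module)
+ >>> print(report_str)  # human-readable per-module suggestions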
+ + Takes advantage of the ModelReport Observer which records per channel information of input range + It then uses the weight info in conjunction with this to compute the desired ratio + Finally, it gives suggestions based on this information for each module of interest + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of whether input weight equalization is recommended for certain modules + Dictionary mapping modules of interest to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + + # find the range of inputs + input_values: dict[str, dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: dict[str, dict] = self._extract_weight_info(model) + + # calculate per_channel comparison statistic s_c + comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values( + input_values, weight_values + ) + + # generate the return dictionary + input_weight_equalization_info: dict[str, dict] = self._generate_dict_info( + input_values, weight_values, comp_stats + ) + + # now we can generate report based on this information + input_weight_string = "Input-Weight Equalization suggestions: \n" + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = ( + "\tWe suggest {} input weight equalization because {}\n" + ) + use_str = "to use" + no_use_str = "to not use" + input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error." + input_weight_non_benefit_reasoning = ( + "{}/{} channels benefitting from input-weight equalization being applied." + ) + input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}" + + # added module check + added_module: bool = False + + # compile the suggestion string + for module_fqn in input_weight_equalization_info: + # we added at least 1 module + added_module = True + # add the module level description + input_weight_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + + mod_info: dict[str, Any] = input_weight_equalization_info[module_fqn] + + # gather info on how many channels would benefit from input weight equalization + recommendation_per_channel: torch.Tensor = mod_info[self.RECOMMENDED_KEY] + num_recs = sum(recommendation_per_channel) + + if ( + num_recs / len(recommendation_per_channel) + >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO + ): + input_benefit_formatted = input_weight_benefit_str.format( + num_recs, len(recommendation_per_channel) + ) + channel_str = channel_suggestion_str.format( + use_str, input_benefit_formatted + ) + input_weight_string += channel_str + else: + non_benefit_reason_formatted = ( + input_weight_non_benefit_reasoning.format( + num_recs, len(recommendation_per_channel) + ) + ) + non_benefit_str = input_weight_non_benefit_str.format( + non_benefit_reason_formatted + ) + channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str) + input_weight_string += channel_str + + # if no modules looked at, amend return string + if not added_module: + input_weight_string += ( + "No applicable layers for suggestions.
Only linear and conv valid.\n" + ) + + # return a tuple with the string explanation and the compiled dict info + return (input_weight_string, input_weight_equalization_info) + + +class OutlierDetector(DetectorBase): + r""" + Determines whether there are significant outliers in activation data around a certain layer. + + This is ideally used in conjunction with information on stationary vs. non-stationary distribution: + If the data is stationary, and there are significant outliers, then we want to flag them + We want to do this on a per channel basis for detecting outliers + + Determines whether activation data is flagged as outlier based on if data is stationary and: + p_r = avg(100th percentile / "reference_percentile"th percentile) + where: + p_r is average percentile ratio across all batches in the epoch + reference_percentile is a percentile values between 0 and 100 exclusive + + if p_r is above some threshold, then we consider the activations to have significant outliers + + Args: + ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations + Should be >= 1 + Default: 3.5 + reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile + Should be between 0 and 1 + Default: 0.975 + fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier + If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user + regardless of whether we detected outliers or not in channel to take a closer look at channel results + Should be between 0 and 1 + Default: 0.95 + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations + The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold + If it is significantly greater, then we consider it an outlier + This threshold was calculated based on the ratio of the percentiles in a normal distribution + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile + Should be between 0 and 1 + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this + Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine outliers + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + # names for the pre observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # pre activation prefix + INPUT_ACTIVATION_PREFIX = "input_activation_" + + # names for dict keys + OUTLIER_KEY = "outliers_detected" + NUM_BATCHES_KEY = "outlier_detection_batches_used" + IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches" + COMP_METRIC_KEY = "outlier_detection_percentile_ratios" + RATIO_THRES_KEY = "outlier_detection_ratio_threshold" + REF_PERCENTILE_KEY = 
"outlier_detection_reference_percentile" + CHANNEL_AXIS_KEY = "outlier_detection_channel_axis" + MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max" + CONSTANT_COUNTS_KEY = "constant_batch_counts" + + def __init__( + self, + ratio_threshold: float = 3.5, + reference_percentile: float = 0.975, + fraction_batches_used_threshold: float = 0.95, + ch_axis: int = 1, + ): + # initialize the variables of interest + self.ratio_threshold = ratio_threshold + + # make sure passed in percentile is valid + assert reference_percentile >= 0 and reference_percentile <= 1 + assert ( + fraction_batches_used_threshold >= 0 + and fraction_batches_used_threshold <= 1 + ) + self.reference_percentile = reference_percentile + self.fraction_batches_used_threshold = fraction_batches_used_threshold + self.ch_axis = ch_axis + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "outlier_detector" + + def _supports_insertion(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for observers insertion + + Any module that doesn't have children and isn't an observer itself is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + # case for insertion of module + # check if the module has any children and isn't observer + num_children = len(list(module.children())) + return num_children == 0 and not _is_activation_post_process(module) + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # currently doesn't do anything for outlier detector + return {} + + def _supports_report_gen(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for report generation + + Any module that has a model report pre-observer is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r"""Determines where observers need to be inserted for the Outlier Detector. + + For this detector, we want to place observers in front of supported layers. 
+ + Currently inserts observers for: + all layers that do not have children (leaf level layers) + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._supports_insertion(module): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr( + ch_axis=self.ch_axis, comp_percentile=self.reference_percentile + ), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def _calculate_outlier_info( + self, + percentile_ratios: torch.Tensor, + counted_batches: torch.Tensor, + total_batches: int, + ) -> dict[str, list[bool]]: + r""" + Gives info on whether the percentile ratios calculated would be considered outliers + Also gives information on whether the collected data is statistically significant to make this claim + + Args: + percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer + counted_batches (torch.Tensor): The number of batches used for average calculation per tensor + total_batches (int): The total number of batches that passed through observer in this epoch + + Returns a dictionary mapping: + "outliers_detected" : list of bools per channel that are true if it is considered an outlier + "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold: + where o_r = counted_batches / total_batches + """ + outlier_dict: dict[str, list[bool]] = { + self.OUTLIER_KEY: [], + self.IS_SUFFICIENT_BATCHES_KEY: [], + } + + # get both as flattened lists for easy mapping + ratios_list: list = percentile_ratios.tolist() + num_batches_list: list = counted_batches.tolist() + + # calculate whether channels were statistically significant + significant_size = [ + batch_size / total_batches >= self.fraction_batches_used_threshold + for batch_size in num_batches_list + ] + outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size + + # calculate for each channel whether it's an outlier or not based on ratio + outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list] + outlier_dict[self.OUTLIER_KEY] = outlier_detected + + # return the dictionary with the two lists + return outlier_dict + + def _generate_info_dict(self, model: GraphModule) -> dict[str, dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. 
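+
+ The per-channel decisions themselves come from _calculate_outlier_info; as a rough
+ sketch with made-up numbers, a channel is flagged when its averaged percentile ratio
+ p_r exceeds ratio_threshold, and its batch coverage is deemed sufficient when
+ counted_batches / total_batches reaches fraction_batches_used_threshold:
+
+ >>> # xdoctest: +SKIP
+ >>> ratios, counted, total = [1.2, 5.0], [10, 8], 10
+ >>> [r > 3.5 for r in ratios]             # ratio_threshold = 3.5
+ [False, True]
+ >>> [c / total >= 0.95 for c in counted]  # fraction_batches_used_threshold = 0.95
+ [True, False]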
+ This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # return dictionary mapping observer fqns to desired info + info_dict: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._supports_report_gen(module): + # get pre observer for the module + pre_obs: ModelReportObserver = getattr( + module, self.DEFAULT_PRE_OBSERVER_NAME + ) + + # get the number of batches and calculated ratio thresholds + num_batches: torch.Tensor = pre_obs.percentile_batches_tracked + average_ratios: torch.Tensor = pre_obs.average_percentile_ratio + channel_batch_cnts: torch.Tensor = pre_obs.constant_channels + total_batches: int = pre_obs.num_batches_tracked + + # also get the max values + max_vals: torch.Tensor = pre_obs.max_val + + # we have to specifically modify how we are recording negative ratio for pre-relu layers + for index, ratio_val in enumerate(average_ratios): + # check if we have a negative ratio + # a ratio might be negative if we have a situation where the 100th percentile is + # > 0 while the nth percentile is < 0, in which case this would not be detected + # as an outlier. Since we care more about magnitude, we make it positive. + if ratio_val.item() < 0: + # first make it positive + average_ratios[index] = -ratio_val + + if ratio_val.item() < 1: + # if it's less than 1 we have to flip it as well + average_ratios[index] = 1 / ratio_val + + outlier_calcs = self._calculate_outlier_info( + average_ratios, num_batches, total_batches + ) + + # calculate whether ratios were outliers + info_dict[fqn] = { + self.CHANNEL_AXIS_KEY: self.ch_axis, + self.REF_PERCENTILE_KEY: self.reference_percentile, + self.RATIO_THRES_KEY: self.ratio_threshold, + self.COMP_METRIC_KEY: average_ratios, + self.NUM_BATCHES_KEY: num_batches, + self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY], + self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[ + self.IS_SUFFICIENT_BATCHES_KEY + ], + self.CONSTANT_COUNTS_KEY: channel_batch_cnts, + self.MAX_VALS_KEY: max_vals, + } + + return info_dict + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether there are significant outliers in the activation data around a given module.
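+
+ A minimal usage sketch (assumes the observers were inserted through
+ ModelReport.prepare_detailed_calibration and the model has been calibrated;
+ the model name and argument value here are hypothetical):
+
+ >>> # xdoctest: +SKIP
+ >>> detector = OutlierDetector(reference_percentile=0.95)
+ >>> report_str, per_module_info = detector.generate_detector_report(calibrated_graph_module)
+ >>> print(report_str)  # lists channels with suspected outliers per module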
+ + Takes advantage of the ModelReport Observer which records the relevant percentile information + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of whether there are outliers in the activations around certain modules + Dictionary mapping modules of interest to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # generate the information dictionary of outlier information + info_dict = self._generate_info_dict(model) + + # now we can generate report based on this information + outlier_string = "Outlier detection report: \n" + + # added module check + added_module: bool = False + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n" + channel_max_value_str = "a max value across all batches of {}" + note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results." + note_distribution = "stationary distributions" + note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary" + + # suggestion for constant batch check since that can result in no outliers being reported + constant_str = "\tFor channel {}, we found {} constant value batches. {}\n" + constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why." + + # compile the suggestion string + for module_fqn in info_dict: + # get module specific info + mod_info: dict[str, Any] = info_dict[module_fqn] + # check to see if we already added high level model desc + added_model_desc = False + # look at each individual channel and add a suggestion + for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]): + if outlier_detected: + # we found at least 1 outlier + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + added_model_desc = True + + # we mark that we found at least one outlier + added_module = True + max_value_found_str = channel_max_value_str.format( + mod_info[self.MAX_VALS_KEY][index] + ) + channel_str = channel_suggestion_str.format( + index, max_value_found_str + ) + outlier_string += channel_str + + # also check if we found constant batch + if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0: + # make sure we add a module level highlight.
+ if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + added_model_desc = True + + constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][ + index + ] + formatted_str = constant_str.format( + index, constant_values_for_channel, constant_suggestion + ) + outlier_string += formatted_str + # we also added at least one thing to description + added_module = True + + # if found outlier, give suggestion, else give default response + if added_module: + # compose the note string + note_composed = note_string.format(note_distribution, note_rec) + outlier_string += note_composed + else: + outlier_string += "There were no outliers found in the activations.\n" + + return (outlier_string, info_dict) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report.py new file mode 100644 index 0000000000000000000000000000000000000000..598abb063f5772d9020da3fd69d655cdd3597c1a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report.py @@ -0,0 +1,664 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict +from typing import Any, Callable + +import torch +from torch.ao.quantization.fx._equalize import EqualizationQConfig +from torch.ao.quantization.fx._model_report.detector import ( + DETECTOR_IS_POST_OBS_KEY, + DETECTOR_OBS_ARGS_KEY, + DETECTOR_OBS_TO_INSERT_KEY, + DETECTOR_TARGET_NODE_KEY, + DetectorBase, + DetectorQConfigInfo, +) +from torch.ao.quantization.fx._model_report.model_report_visualizer import ( + ModelReportVisualizer, +) +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping + + +class ModelReport: + r""" + The ModelReport class aims to provide users an easy way to diagnose issues that they run into + with their models. The class works with all traceable GraphModules to help diagnose issues, + though the requirements on the type of model more-so depends on the specific report the user + is trying to generate. With respect to the reports, the ModelReport class is initialized with + a set of Detector classes, each of which generate reports on quantization configuration + issues a use might have. + + Currently supports generating reports on: + - Suggestions for per-channel vs. per-tensor quantization (nn.Module) + - Suggestions for dynamic vs static quantization for linear layers (Graph Modules) + - Suggestions for input-weight equalization for linear and conv layers (Graph Modules) + - Suggestions for outlier detection for all layers (Graph Modules) + + The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver) + where needed for each detector to gather the information it needs, and then after callibration, the ModelReport + class compiles the report generated by each Detector class into a single report to return to the user. It also + has the capability to remove all the observers it inserted as well. + + * :attr:`_model` The model we wish to generate the report for. 
Must be a traceable GraphModule + + * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class + Make sure that these are all unique types of detectors [do not have more than 1 of the same class] + + * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors. + This set is generated by calling the get_detector_name() of each detector + + * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest + The purpose of this is to keep track of what observers were inserted for each detector, so that they + can be removed at the end if desired + + * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not + This is to ensure we only insert observers once with the ModelReport instance + + * :attr:`_removed_observers` A boolean to track if we have removed observers already + The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport + instance. This also allows the functionality where we can generate the report multiple times + as long as we haven't removed the observers yet. + + Note: + This class was initially designed to work with the Fx Graph Mode workflow in mind. However, + full functionality is available as long as there is a traceable GraphModule that is being used. + One method to get a traceable GraphModule without going through the Fx workflow is to use + the QuantizationTracer class. + + General Flow for Fx workflow: + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration to add relevant observers + 4.) Callibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + Optional + 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance + 7.) To help in parsing report information and debugging, view report info as a: + - Table + - Histogram + - Line plot + 8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions + + Example (with QuantizationTracer): + >>> # xdoctest: +SKIP + >>> # get the necessary qconfig + >>> config = PrepareCustomConfig() + >>> skipped_module_names, skipped_module_classes = ( + ... get_skipped_module_name_and_classes(config, False) + ... ) + + >>> # initialize our model and get GraphModule + >>> model = SomeModel() + >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) + >>> graph_module = GraphModule(model, tracer.trace(model)) + + >>> # get our set of detectors and ModelReport instance + >>> detector_set = set( + ... [ + ... DynamicStaticDetector(tolerance=0.5), + ... InputWeightEqualizationDetector(ratio_threshold=0.7), + ... ] + ... ) + >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) + + >>> # now we insert the observers and callibrate the model + >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() + >>> for i in range(num_callibration_batches): + >>> example_input = get_callibration_input() + >>> tracer_model_with_observers(example_input) + + >>> # finally we generate the reports and optionally remove the observers we inserted + >>> reports = tracer_reporter.generate_model_report( + ... remove_inserted_observers=True + ... 
) + + >>> # Optional: we can generate the qconfig mapping based on the suggestions + >>> qconfigs = model_report.generate_qconfig_mapping() + + >>> # Optional: we can generate the equalization mapping based on the suggestions + >>> qconfigs = model_report.generate_equalization_mapping() + + >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired + >>> model_report_visualizer = tracer_reporter.generate_visualizer() + + """ + + def __init__(self, model: GraphModule, desired_report_detectors: set[DetectorBase]): + if len(desired_report_detectors) == 0: + raise ValueError("Should include at least 1 desired report") + + # keep track of the model we wish to generate report for + self._model: GraphModule = model + + # keep the reports private so they can't be modified + self._desired_report_detectors = desired_report_detectors + self._desired_detector_names = { + detector.get_detector_name() for detector in desired_report_detectors + } + + # keep a mapping of desired reports to observers of interest + # this is to get the readings, and to remove them, can create a large set + # this set can then be used to traverse the graph and remove added observers + self._detector_name_to_observer_fqns: dict[str, set[str]] = {} + + # initialize each report to have empty set of observers of interest + for desired_report in self._desired_detector_names: + self._detector_name_to_observer_fqns[desired_report] = set() + + # flags to ensure that we can only prepare and remove observers once + self._prepared_flag = False + self._removed_observers = False + + # store the reports that we generated for visualization purposes + # initially empty since no reports generated + self._generated_reports: dict[str, dict] = {} + + def get_desired_reports_names(self) -> set[str]: + """Returns a copy of the desired reports for viewing""" + return self._desired_detector_names.copy() + + def get_observers_of_interest(self) -> dict[str, set[str]]: + """Returns a copy of the observers of interest for viewing""" + return self._detector_name_to_observer_fqns.copy() + + def prepare_detailed_calibration(self) -> GraphModule: + r""" + Takes in a graph model and inserts the following observers: + - ModelReportObserver + + Each observer is inserted based on the desired_reports into the relevant locations + + Right now, each report in self._desired_detector_names has independent insertions + However, if a module already has a Observer of the same type, the insertion will not occur + This is because all of the same type of Observer collect same information, so redundant + + Returns the same GraphModule with the observers inserted + """ + + # if already prepared once, cannot prepare again + if self._prepared_flag: + raise ValueError( + "Already ran preparing detailed callibration. Run the report generation next after callibration." 
+ ) + + # loop through each detector, find where placements should be, and keep track + insert_observers_fqns: dict[str, Any] = {} + + for detector in self._desired_report_detectors: + # determine observer points for each detector + obs_fqn_to_info = detector.determine_observer_insert_points(self._model) + # map each insert point to the observer to use + insert_observers_fqns.update(obs_fqn_to_info) + # update the set of observers this report cares about + self._detector_name_to_observer_fqns[detector.get_detector_name()] = set( + obs_fqn_to_info.keys() + ) + + # now insert all the observers at their desired locations + for observer_fqn in insert_observers_fqns: + target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY] + insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY] + insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY] + observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY] + self._insert_observer_around_module( + observer_fqn, target_node, insert_obs, observer_args, insert_post + ) + + self._prepared_flag = True + + return self._model + + def _insert_observer_around_module( + self, + obs_fqn: str, + target_node: torch.fx.node.Node, + obs_to_insert: ObserverBase, + observer_args: tuple, + insert_post: bool, + ): + r""" + Helper function that inserts the observer into both the graph structure and the module of the model + + Args + node_fqn (str): The fully qualified name of the observer we want to insert + target_node (torch.fx.node.Node): The node in model we are inserting observers around + obs_to_insert (ObserverBase): The observer we are inserting around target_node + observer_args (Tuple): The arguments we want to pass into the observer + insert_post (bool): whether this is meant to be a post observer for this node + """ + # if we are inserting post, then our target node is the next node + if insert_post: + target_node = target_node.next + + with self._model.graph.inserting_before(target_node): + self._model.add_submodule(obs_fqn, obs_to_insert) + self._model.graph.create_node( + op="call_module", target=obs_fqn, args=observer_args + ) + + # recompile model after inserts are made + self._model.recompile() + + def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a node fqn and returns the node based on the fqn + + Args + node_fqn (str): The fully qualified name of the node we want to find in model + + Returns the Node object of the given node_fqn otherwise returns None + """ + node_to_return = None + for node in self._model.graph.nodes: + # if the target matches the fqn, it's the node we are looking for + if node.target == node_fqn: + node_to_return = node + break + + if node_to_return is None: + raise ValueError("The node_fqn is was not found within the module.") + + # assert for MyPy + assert isinstance(node_to_return, torch.fx.node.Node) + + return node_to_return + + def generate_model_report( + self, remove_inserted_observers: bool + ) -> dict[str, tuple[str, dict]]: + r""" + Generates all the requested reports. 
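+
+ As a rough sketch of how the returned mapping is typically consumed (the
+ model_report variable is hypothetical and stands for an already prepared and
+ calibrated ModelReport instance):
+
+ >>> # xdoctest: +SKIP
+ >>> reports = model_report.generate_model_report(remove_inserted_observers=True)
+ >>> for detector_name, (text_summary, stats_dict) in reports.items():
+ ...     print(detector_name)
+ ...     print(text_summary)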
+ + Note: + You should have callibrated the model with relevant data before calling this + + The reports generated are specified by the desired_reports specified in desired_reports + + Can optionally remove all the observers inserted by the ModelReport instance + + Args: + remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance + + Returns a mapping of each desired report name to a tuple with: + The textual summary of that report information + A dictionary containing relevant statistics or information for that report + + Note: + Throws exception if we try to generate report on model we already removed observers from + Throws exception if we try to generate report without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate report yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for callibration" + ) + + # if we already removed the observers, we cannot generate report + if self._removed_observers: + raise Exception( # noqa: TRY002 + "Cannot generate report on model you already removed observers from" + ) + + # keep track of all the reports of interest and their outputs + reports_of_interest = {} + + for detector in self._desired_report_detectors: + # generate the individual report for the detector + report_output = detector.generate_detector_report(self._model) + reports_of_interest[detector.get_detector_name()] = report_output + + # if user wishes to remove inserted observers, go ahead and remove + if remove_inserted_observers: + self._removed_observers = True + # get the set of all Observers inserted by this instance of ModelReport + all_observers_of_interest: set[str] = set() + for desired_report in self._detector_name_to_observer_fqns: + observers_of_interest = self._detector_name_to_observer_fqns[ + desired_report + ] + all_observers_of_interest.update(observers_of_interest) + + # go through all_observers_of_interest and remove them from the graph and model + for observer_fqn in all_observers_of_interest: + # remove the observer from the model + self._model.delete_submodule(observer_fqn) + + # remove the observer from the graph structure + node_obj = self._get_node_from_fqn(observer_fqn) + + if node_obj: + self._model.graph.erase_node(node_obj) + else: + raise ValueError("Node no longer exists in GraphModule structure") + + # remember to recompile the model + self._model.recompile() + + # save the generated reports for visualization purposes + saved_reports: dict[str, dict] = { + report_name: report_tuple[1] + for report_name, report_tuple in reports_of_interest.items() + } + + self._generated_reports = saved_reports + + # return the reports of interest + return reports_of_interest + + def _is_same_info_for_same_key(self, info_dict_a: dict, info_dict_b: dict) -> bool: + r""" + Takes in two dictionaries and ensures that any common keys between the two have the same + values. 
+ + Args: + info_dict_a (Dict): First dictionary we wish to compare + info_dict_b (Dict): Second dictionary we wish to compare + + Returns True if all shared keys have same values, false otherwise + """ + # get the set of keys for both + dict_a_keys: set = set(info_dict_a.keys()) + dict_b_keys: set = set(info_dict_b.keys()) + + # get the insersection keys and check if same value for both dicts + intersecting_keys: set = dict_a_keys.intersection(dict_b_keys) + + for key in intersecting_keys: + dict_a_val = info_dict_a[key] + dict_b_val = info_dict_b[key] + + # if it's a tensor we have to handle separately + if type(dict_a_val) == torch.Tensor: + # if dict_b_val not tensor, automatically false + if ( + type(dict_b_val) != torch.Tensor + or sum(dict_a_val != dict_b_val) != 0 + ): + return False + else: + # for non-tensor vals + if dict_a_val != dict_b_val: + return False + + # if no non matching shared keys found, return true + return True + + def _reformat_reports_for_visualizer(self) -> OrderedDict: + r""" + Takes the generated reports and reformats them into the format that is desired by the + ModelReportVisualizer + + Returns an OrderedDict mapping module_fqns to their features + """ + # we want to reorder and reformat the information so it is ordered in terms of order + # found in the model + + # first create new dict with all modules as keys and features under respective module + module_fqns_to_features: dict[str, dict] = {} + + for report_name in self._generated_reports: + # get mod -> feature dict and go through + module_info = self._generated_reports[report_name] + + for module_fqn in module_info: + # check if already in our accumulation dict + if module_fqn in module_fqns_to_features: + # we merge all the features together + new_info: dict = module_info[module_fqn] + present_info: dict = module_fqns_to_features[module_fqn] + + # merge them together into the new unioned dict + # same features keys -> same info, so okay if override + + # do safety check to make sure shared keys have same info + if self._is_same_info_for_same_key(new_info, present_info): + module_fqns_to_features[module_fqn] = { + **new_info, + **present_info, + } + else: + error_str = "You have the same key with different values across detectors. " + error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors." + raise ValueError(error_str) + else: + # we just set it + module_fqns_to_features[module_fqn] = module_info[module_fqn] + + # our ordered dict so that modules can be ordered in order of how they appear in model + features_by_module: OrderedDict[str, dict] = OrderedDict() + + # we loop through modules in graph in order + for fqn, _module in self._model.named_modules(): + # find that fqn in fqns_to_features + if fqn in module_fqns_to_features: + # add it to our ordered dict + features_by_module[fqn] = module_fqns_to_features[fqn] + + # return the ordered dict of info we created + return features_by_module + + def generate_visualizer(self) -> ModelReportVisualizer: + r""" + Generates a ModelReportVisualizer instance using the reports generated + by the generate_model_report() method. 
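+
+ A minimal usage sketch (assumes generate_model_report was already called on this
+ instance; the table-view method name on the visualizer is an assumption here and
+ is shown only for illustration):
+
+ >>> # xdoctest: +SKIP
+ >>> visualizer = model_report.generate_visualizer()
+ >>> visualizer.generate_table_visualization()  # tabular view of collected stats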
+ + Returns the generated ModelReportVisualizer instance initialized + + Note: + Throws exception if attempt to get visualizers without generating report + """ + # check if user has generated reports at least once + if len(self._generated_reports) == 0: + raise Exception( # noqa: TRY002 + "Unable to generate visualizers without first generating reports" + ) + + # get the ordered dict mapping modules to their full set of collected features / stats + module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer() + + # create and return ModelReportVisualizer instance + visualizer: ModelReportVisualizer = ModelReportVisualizer( + module_fqns_to_features + ) + + return visualizer + + def _generate_qconfig_mapping_helper( + self, + detector_qconfig_info_combined: dict[str, DetectorQConfigInfo], + generation_function: Callable, + ) -> QConfigMapping: + r""" + This helper takes in the compiled detector qconfig info that + has been compiled together and merges it into a QConfigMapping + """ + # keep track of the qconfigmapping + qconfig_mapping = QConfigMapping() + + # loop through each module / fqn and attempt to create QConfigMapping + for fqn, module in self._model.named_modules(): + # if we have a qconfig info for this module + if fqn in detector_qconfig_info_combined: + qconfig_info_compiled = detector_qconfig_info_combined[fqn] + + # now generate the qconfig and add it to the mapping + generated_qconfig = generation_function(qconfig_info_compiled, module) + + # add to our config + qconfig_mapping.set_module_name(fqn, generated_qconfig) + + # return compiled mapping + return qconfig_mapping + + def _update_detector_quantizaiton_qconfig_info( + self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo + ): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + combined_info.is_activation_dynamic = ( + combined_info.is_activation_dynamic or new_info.is_activation_dynamic + ) + combined_info.is_weight_per_channel = ( + combined_info.is_weight_per_channel or new_info.is_weight_per_channel + ) + + def _update_detector_equalization_qconfig_info( + self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo + ): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + is_equalization_recommended = ( + combined_info.is_equalization_recommended + or new_info.is_equalization_recommended + ) + combined_info.is_equalization_recommended = is_equalization_recommended + + def _generate_module_fqn_to_detector_info_mapping( + self, update_qconfig_info_function: Callable + ) -> dict[str, DetectorQConfigInfo]: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. 
+ + Args: + update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo + and updates the one that is being compiled + + Returns a Dict mapping module_fqns to DetectorQConfigInfo objects + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate mapping yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for callibration" + ) + + # if we already removed the observers, we cannot mapping + if self._removed_observers: + raise Exception( # noqa: TRY002 + "Cannot generate report on model you already removed observers from" + ) + + # keep track of qconfig info for each module across detectors + detector_qconfig_info_combined: dict[str, DetectorQConfigInfo] = {} + + for detector in self._desired_report_detectors: + # get the info from the detector + detector_info: dict[str, DetectorQConfigInfo] = detector.get_qconfig_info( + self._model + ) + + # we go through the modules + for module_fqn in detector_info: + # see if we already have info on it + if module_fqn in detector_qconfig_info_combined: + # we combine the current options with what is there + current_options = detector_qconfig_info_combined[module_fqn] + detector_options = detector_info[module_fqn] + + update_qconfig_info_function(current_options, detector_options) + else: + # we just use this for now + detector_qconfig_info_combined[module_fqn] = detector_info[ + module_fqn + ] + + return detector_qconfig_info_combined + + def generate_qconfig_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. 
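+
+ A minimal sketch of consuming the result (assumes the reports were already
+ generated; the prepare_fx call and the example_inputs name are illustrative
+ assumptions, not part of this module):
+
+ >>> # xdoctest: +SKIP
+ >>> qconfig_mapping = model_report.generate_qconfig_mapping()
+ >>> from torch.ao.quantization.quantize_fx import prepare_fx
+ >>> prepared = prepare_fx(model, qconfig_mapping, example_inputs)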
+ + Returns a QConfigMapping for the quantization configuration + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for callibration + """ + # get the mapping info + detector_qconfig_info_combined = ( + self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_quantizaiton_qconfig_info + ) + ) + + # we will do a bit of processing and remove fqns that don't have input weight recommended + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, self._quantization_config_generator + ) + + # return the generated mapping + return mapping + + def _quantization_config_generator( + self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module + ) -> QConfig: + r""" + Returns the quantization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_quantization_qconfig(module) + + def _equalization_config_generator( + self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module + ) -> EqualizationQConfig: + r""" + We ignore the module argument here, and only focus on thedetector_qconfig_info + + Returns the equalization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_equalization_qconfig() + + def generate_equalization_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API for equalization. The generated mapping encompasses all the + different types of feedback from the input-weight equalization detector. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the equalization configuration + """ + # get the mapping info + detector_qconfig_info_combined = ( + self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_equalization_qconfig_info + ) + ) + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, self._equalization_config_generator + ) + + # return the generated mapping + return mapping diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..8232f70c9d460a04df7ea7e38bfca906e2386325 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import torch +from torch.ao.quantization.observer import ObserverBase + + +class ModelReportObserver(ObserverBase): + r"""This observer is used to record additional information regarding keeping track + of S = average_batch_activation_range/epoch_activation_range. + + The purpose of this information is to prepare a report to present to users on whether + Dynamic or Static Quantization is more appropriate for their model given the general + distributions of their data. 
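+
+ As a rough numeric sketch of the S statistic tracked here (numbers are made up for
+ illustration only): if three calibration batches have ranges 2.0, 2.2 and 1.8 while
+ the epoch-wide range (epoch max - epoch min) is 6.0, then
+ S = mean(2.0, 2.2, 1.8) / 6.0 = 2.0 / 6.0 ≈ 0.33. Values of S close to 1 mean each
+ batch individually spans most of the epoch range.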
+ + Args: + ch_axis (int, optional): The channel axis for which the range and outlier stats are computed + Default: 1 + comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers + Should be between 0 and 1 exclusive + Default: 0.9 + + * :attr:`num_batches_tracked` specifies number of batches passed through the observer + + * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through + + * :attr:`epoch_activation_min` defines the minimum value passed through the observer + + * :attr:`epoch_activation_max` defines the maximum value passed through the observer + + * :attr:`ch_axis` defines the channel being used to compute per channel min max stats + + * :attr:`min_val` defines the per channel minimum values passed through + + * :attr:`max_val` defines the per channel maximum values passed through + + * :attr:`comp_percentile` defines comparison percentile to find outliers + + * :attr:`average_percentile_ratio` defines the per channel average percentile ratios + + * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel + + * :attr:`constant_channels` defines the number of batches that aren't constant channels per channel + + Note: this tool is meant for FX Graph Mode Quantization + """ + + epoch_activation_min: torch.Tensor + epoch_activation_max: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + comp_percentile: torch.Tensor + average_percentile_ratio: torch.Tensor + percentile_batches_tracked: torch.Tensor + constant_channels: torch.Tensor + + def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9): + super().__init__(torch.qint8) + self.num_batches_tracked = 0 + + # keep track of the min and mix of the range for average batch and epoch as a whole + self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0)) + self.register_buffer("epoch_activation_min", torch.tensor(float("inf"))) + self.register_buffer("epoch_activation_max", torch.tensor(float("-inf"))) + + # keep track of per channel min max information using the given channel + self.ch_axis: int = ch_axis + self.register_buffer("min_val", torch.tensor([])) + self.register_buffer("max_val", torch.tensor([])) + + # keep track of percentile ratio information per channel + self.register_buffer("comp_percentile", torch.tensor([comp_percentile])) + self.register_buffer("average_percentile_ratio", torch.tensor([])) + self.register_buffer("percentile_batches_tracked", torch.tensor([])) + self.register_buffer("constant_channels", torch.tensor([])) + + def forward(self, x): + x_copy = x.detach() # avoid keeping autograd tape + x_copy = x_copy.to(self.epoch_activation_min.dtype) + + x_copy = self._calculate_range_stats(x_copy) + x_copy = self._calculate_min_max_stats(x_copy) + x_copy = self._calculate_percentile_stats(x_copy) + + # return the passed in the value + return x + + def _calculate_range_stats(self, x_copy): + r"""Calculates and stores range stats with forward values. 
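+
+ The average batch range is maintained as a running mean; a rough numeric sketch of
+ the update performed in the body (values are made up for illustration):
+
+ >>> # xdoctest: +SKIP
+ >>> avg_range, n = torch.tensor(2.0), 4       # running average over first 4 batches
+ >>> current_batch_range = torch.tensor(3.0)
+ >>> (avg_range * n + current_batch_range) / (n + 1)
+ tensor(2.2000)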
+ + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the min, max values of the data + min_val_cur, max_val_cur = torch.aminmax(x_copy) + + # calculate new epoch range values + epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur) + epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur) + + self.epoch_activation_min.copy_(epoch_min_val) + self.epoch_activation_max.copy_(epoch_max_val) + + # calculate the average batch activation range + current_batch_range = max_val_cur - min_val_cur + new_range = ( + self.average_batch_activation_range * self.num_batches_tracked + + current_batch_range + ) / (self.num_batches_tracked + 1) + + self.average_batch_activation_range = new_range + self.num_batches_tracked += 1 # new batch was processed + + return x_copy + + def _calculate_min_max_stats(self, x_copy): + r"""Calculates and stores the per_channel min, max stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the current min and max vals + min_val = self.min_val + max_val = self.max_val + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x_copy + + def _calculate_percentile_stats(self, x_copy): + r"""Calculates and stores the per_channel percentile stats with forward values. 
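+
+ Per channel, the quantity of interest is the ratio of the 100th percentile to the
+ comp_percentile-th percentile, which is then averaged over batches. A rough
+ single-batch sketch with made-up numbers:
+
+ >>> # xdoctest: +SKIP
+ >>> hundredth = torch.tensor([4.0, 9.0])   # per-channel 100th percentile
+ >>> comp = torch.tensor([2.5, 1.5])        # per-channel comp_percentile value
+ >>> hundredth / comp
+ tensor([1.6000, 6.0000])
+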
+ Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the dimension of the copy + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + y = y.to(dtype=self.min_val.dtype, device="cpu") + + # find the percentile values along the axis + # we want both 100th percentile and comp_percentile + # we also want to find 0th quartile to see if we have constant channel + quantiles_list = [0, self.comp_percentile, 1.00] + quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype) + + # find the quantiles + desired_quantiles = torch.quantile( + y, quantiles_to_find, dim=self.ch_axis, interpolation="lower" + ) + zero_quantile = desired_quantiles[0] + comp_quantile = desired_quantiles[1] + hundreth_quartile = desired_quantiles[2] + + # if any of the channels have 0s, we ignore that channel for this calculation + any_non_zero_quantile_value: torch.Tensor = ( + comp_quantile != torch.tensor([0]) + ) | (hundreth_quartile != torch.tensor([0])) + any_non_zero_quantile_value = ( + any_non_zero_quantile_value.int() + ) # transform boolean values to int values + + # we also check if we have a constant channel + any_constant_channels: torch.Tensor = ( + hundreth_quartile - zero_quantile + ) == torch.tensor([0]) + any_constant_channels = ( + any_constant_channels.int() + ) # transform boolean values to int values + + # possibilities to get nan as an answer + # will ignore any of these three cases with 0s and just not deal with them for now + # case (1) 0 in numerator: issue if 0 is largest, all negative, and rest are really negative + # case (2) 0 in denominator: is possible unless case 3, we just ignore + # case (3) 0 in both: not outlier, channel just kinda useless, ignore + + # get the ratio and get rid of nan values + quantile_ratios = hundreth_quartile / comp_quantile + quantile_ratios = torch.nan_to_num(quantile_ratios) + # update averages, remembering to only update if didn't have zeros + ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios + + # if num_batches and average_ratio are not initialized, we want to initialize them + if ( + self.percentile_batches_tracked.shape[0] == 0 + or self.average_percentile_ratio.shape[0] == 0 + ): + self.percentile_batches_tracked = torch.zeros_like( + any_non_zero_quantile_value + ) + self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero) + + # also initialize the constant channel var if that is not initialized separately + if self.constant_channels.shape[0] == 0: + self.constant_channels = torch.zeros_like(any_constant_channels) + + # get current num batches and average ratio + num_batches = self.percentile_batches_tracked + average_ratio = self.average_percentile_ratio + + # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches + new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value + new_ratios: torch.Tensor = ( + (average_ratio * num_batches) + ratio_if_not_zero + ) / new_number_of_batches + new_ratios = torch.nan_to_num(new_ratios) + + # update the number of non-constant channels + new_constant_count: torch.Tensor = ( + self.constant_channels + 
any_constant_channels + ) + + # update the values locally + self.percentile_batches_tracked.copy_(new_number_of_batches) + self.average_percentile_ratio.copy_(new_ratios) + self.constant_channels.copy_(new_constant_count) + + return x_copy + + @torch.jit.export + def get_batch_to_epoch_ratio(self): + epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min + + if epoch_activation_range == torch.tensor(float(0)): + raise ValueError("Range for Epoch is 0") + elif epoch_activation_range == torch.tensor(float("inf")): + raise ValueError( + "No data has been run through observer or infinity value present" + ) + else: + return self.average_batch_activation_range / epoch_activation_range + + @torch.jit.export + def reset_batch_and_epoch_values(self): + # set all the values back to their original defaults for a new epoch + # keep device + device = self.max_val.device + self.num_batches_tracked = 0 + self.average_batch_activation_range = torch.tensor(float(0), device=device) + self.epoch_activation_min = torch.tensor(float("inf"), device=device) + self.epoch_activation_max = torch.tensor(float("-inf"), device=device) + self.min_val = torch.tensor([], device=device) + self.max_val = torch.tensor([], device=device) + self.average_percentile_ratio = torch.tensor([], device=device) + self.percentile_batches_tracked = torch.tensor([], device=device) + self.constant_channels = torch.tensor([], device=device) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ModelReportObserver" + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..37c0a491883b70232839650b833c9def941a854b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -0,0 +1,710 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict, OrderedDict as OrdDict +from typing import Any + +import torch + + +# try to import tablate +got_tabulate = True +try: + from tabulate import tabulate +except ImportError: + got_tabulate = False + + +# var to see if we could import matplotlib +got_matplotlib = True +try: + import matplotlib.pyplot as plt +except ImportError: + got_matplotlib = False + + +class ModelReportVisualizer: + r""" + The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics + that were generated by the ModelReport API. However, at a higher level, the class aims to provide + some level of visualization of statistics to PyTorch in order to make it easier to parse data and + diagnose any potential issues with data or a specific model. With respect to the visualizations, + the ModelReportVisualizer class currently supports several methods of visualizing data. 
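To make the expected input concrete, here is a small hedged sketch (module fqns and feature values are invented) of constructing a `ModelReportVisualizer` from a `generated_reports`-style ordered dictionary and querying it; it assumes the import path of the file added in this diff.

```python
# Hedged sketch: driving ModelReportVisualizer with a hand-built reports dict.
# The module fqns and feature values below are invented for illustration only.
from collections import OrderedDict

import torch
from torch.ao.quantization.fx._model_report.model_report_visualizer import (
    ModelReportVisualizer,
)

reports = OrderedDict(
    {
        "block1.linear1": {"per_tensor_max": torch.tensor(0.9), "comment": "ok"},
        "block1.linear2": {"per_tensor_max": torch.tensor(1.2)},
    }
)

viz = ModelReportVisualizer(reports)
print(viz.get_all_unique_module_fqns())    # {'block1.linear1', 'block1.linear2'}
print(viz.get_all_unique_feature_names())  # {'per_tensor_max'} (tensor-valued features only)
```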
+ + Supported Visualization Methods Include: + - Table format + - Plot format (line graph) + - Histogram format + + For all of the existing visualization methods, there is the option to filter data based on: + - A module fqn prefix + - Feature [required for the plot and histogram] + + * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below + Ensure sure that features that are the same across different report contain the same name + Ensure that objects representing the same features are the same type / dimension (where applicable) + + Note: + Currently, the ModelReportVisualizer class supports visualization of data generated by the + ModelReport class. However, this structure is extensible and should allow the visualization of + other information as long as the information is structured in the following general format: + + Report Structure + -- module_fqn [module with attached detectors] + | + -- feature keys [not every detector extracts same information] + [same collected info has same keys, unless can be specific to detector] + + + The goal behind the class is that the generated visualizations can be used in conjunction with the generated + report for people to get a better understanding of issues and what the fix might be. It is also just to provide + a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as + that grows in size. + + General Use Flow Expected + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers + 4.) Callibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance + 7.) Use instance to view different views of data as desired, applying filters as needed + 8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram + + """ + + # keys for table dict + TABLE_TENSOR_KEY = "tensor_level_info" + TABLE_CHANNEL_KEY = "channel_level_info" + + # Constants for header vals + NUM_NON_FEATURE_TENSOR_HEADERS = 2 + NUM_NON_FEATURE_CHANNEL_HEADERS = 3 + + # Constants for row index in header + CHANNEL_NUM_INDEX = 2 + + def __init__(self, generated_reports: OrderedDict[str, Any]): + r""" + Initializes the ModelReportVisualizer instance with the necessary reports. + + Args: + generated_reports (Dict[str, Any]): The reports generated by the ModelReport class + can also be a dictionary generated in another manner, as long as format is same + """ + self.generated_reports = generated_reports + + def get_all_unique_module_fqns(self) -> set[str]: + r""" + The purpose of this method is to provide a user the set of all module_fqns so that if + they wish to use some of the filtering capabilities of the ModelReportVisualizer class, + they don't need to manually parse the generated_reports dictionary to get this information. + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. 
+ """ + # returns the keys of the ordered dict + return set(self.generated_reports.keys()) + + def get_all_unique_feature_names( + self, plottable_features_only: bool = True + ) -> set[str]: + r""" + The purpose of this method is to provide a user the set of all feature names so that if + they wish to use the filtering capabilities of the generate_table_view(), or use either of + the generate_plot_view() or generate_histogram_view(), they don't need to manually parse + the generated_reports dictionary to get this information. + + Args: + plottable_features_only (bool): True if the user is only looking for plottable features, + False otherwise + plottable features are those that are tensor values + Default: True (only return those feature names that are plottable) + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + unique_feature_names = set() + for module_fqn in self.generated_reports: + # get dict of the features + feature_dict: dict[str, Any] = self.generated_reports[module_fqn] + + # loop through features + for feature_name in feature_dict: + # if we need plottable, ensure type of val is tensor + if ( + not plottable_features_only + or type(feature_dict[feature_name]) == torch.Tensor + ): + unique_feature_names.add(feature_name) + + # return our compiled set of unique feature names + return unique_feature_names + + def _get_filtered_data( + self, feature_filter: str, module_fqn_filter: str + ) -> OrderedDict[str, Any]: + r""" + Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed. + + Args: + feature_filter (str): The feature filter, if we want to filter the set of data to only include + a certain set of features that include feature_filter + If feature = "", then we do not filter based on any features + module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with + this prefix will be included + If module_fqn_filter = "" we do not filter based on module fqn, and include all modules + + First, the data is filtered based on module_fqn, and then filtered based on feature + Returns an OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + """ + # create return dict + filtered_dict: OrderedDict[str, Any] = OrdDict() + + for module_fqn in self.generated_reports: + # first filter based on module + if module_fqn_filter == "" or module_fqn_filter in module_fqn: + # create entry for module and loop through features + filtered_dict[module_fqn] = {} + module_reports = self.generated_reports[module_fqn] + for feature_name in module_reports: + # check if filtering on features and do so if desired + if feature_filter == "" or feature_filter in feature_name: + filtered_dict[module_fqn][feature_name] = module_reports[ + feature_name + ] + + # we have populated the filtered dict, and must return it + + return filtered_dict + + def _generate_tensor_table( + self, + filtered_data: OrderedDict[str, dict[str, Any]], + tensor_features: list[str], + ) -> tuple[list, list]: + r""" + Takes in the filtered data and features list and generates the tensor headers and table + + Currently meant to generate the headers and table for both the tensor information. 
+ + Args: + filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + tensor_features (List[str]): A list of the tensor level features + + Returns a tuple with: + A list of the headers of the tensor table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the tensor information table + tensor_table: list[list[Any]] = [] + tensor_headers: list[str] = [] + + # append the table row to the table only if we have features + if len(tensor_features) > 0: + # now we add all the data + for index, module_fqn in enumerate(filtered_data): + # we make a new row for the tensor table + tensor_table_row = [index, module_fqn] + for feature in tensor_features: + # we iterate in same order of added features + + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if isinstance(feature_val, torch.Tensor): + feature_val = feature_val.item() + + # we add to our list of values + tensor_table_row.append(feature_val) + + tensor_table.append(tensor_table_row) + + # add row of headers of we actually have something, otherwise just empty + if len(tensor_table) != 0: + tensor_headers = ["idx", "layer_fqn"] + tensor_features + + return (tensor_headers, tensor_table) + + def _generate_channels_table( + self, + filtered_data: OrderedDict[str, Any], + channel_features: list[str], + num_channels: int, + ) -> tuple[list, list]: + r""" + Takes in the filtered data and features list and generates the channels headers and table + + Currently meant to generate the headers and table for both the channels information. 
+ + Args: + filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + channel_features (List[str]): A list of the channel level features + num_channels (int): Number of channels in the channel data + + Returns a tuple with: + A list of the headers of the channel table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the table for the channel information table + channel_table: list[list[Any]] = [] + channel_headers: list[str] = [] + + # counter to keep track of number of entries in + channel_table_entry_counter: int = 0 + + if len(channel_features) > 0: + # now we add all channel data + for module_fqn in filtered_data: + # we iterate over all channels + for channel in range(num_channels): + # we make a new row for the channel + new_channel_row = [channel_table_entry_counter, module_fqn, channel] + for feature in channel_features: + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature][channel] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if type(feature_val) is torch.Tensor: + feature_val = feature_val.item() + + # add value to channel specific row + new_channel_row.append(feature_val) + + # add to table and increment row index counter + channel_table.append(new_channel_row) + channel_table_entry_counter += 1 + + # add row of headers of we actually have something, otherwise just empty + if len(channel_table) != 0: + channel_headers = ["idx", "layer_fqn", "channel"] + channel_features + + return (channel_headers, channel_table) + + def generate_filtered_tables( + self, feature_filter: str = "", module_fqn_filter: str = "" + ) -> dict[str, tuple[list, list]]: + r""" + Takes in optional filter values and generates two tables with desired information. + + The generated tables are presented in both a list-of-lists format + + The reason for the two tables are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + idx layer_fqn channel feature_1 feature_2 feature_3 .... 
feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Returns a dictionary with two keys: + (Dict[str, Tuple[List, List]]) A dict containing two keys: + "tensor_level_info", "channel_level_info" + Each key maps to a tuple with: + A list of the headers of each table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_filtered_tables( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) # generates table with per_channel_min info for all modules in block 1 of the model + """ + # first get the filtered data + filtered_data: OrderedDict[str, Any] = self._get_filtered_data( + feature_filter, module_fqn_filter + ) + + # now we split into tensor and per-channel data + tensor_features: set[str] = set() + channel_features: set[str] = set() + + # keep track of the number of channels we have + num_channels: int = 0 + + for module_fqn in filtered_data: + for feature_name in filtered_data[module_fqn]: + # get the data for that specific feature + feature_data = filtered_data[module_fqn][feature_name] + + # check if not zero dim tensor + is_tensor: bool = isinstance(feature_data, torch.Tensor) + is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0 + + if is_not_zero_dim or isinstance(feature_data, list): + # works means per channel + channel_features.add(feature_name) + num_channels = len(feature_data) + else: + # means is per-tensor + tensor_features.add(feature_name) + + # we make them lists for iteration purposes + tensor_features_list: list[str] = sorted(tensor_features) + channel_features_list: list[str] = sorted(channel_features) + + # get the tensor info + tensor_headers, tensor_table = self._generate_tensor_table( + filtered_data, tensor_features_list + ) + + # get the channel info + channel_headers, channel_table = self._generate_channels_table( + filtered_data, channel_features_list, num_channels + ) + + # let's now create the dictionary to return + table_dict = { + self.TABLE_TENSOR_KEY: (tensor_headers, tensor_table), + self.TABLE_CHANNEL_KEY: (channel_headers, channel_table), + } + + # return the two tables + return table_dict + + def generate_table_visualization( + self, feature_filter: str = "", module_fqn_filter: str = "" + ): + r""" + Takes in optional filter values and prints out formatted tables of the information. + + The reason for the two tables printed out instead of one large one are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... 
feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_table_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + >>> # prints out neatly formatted table with per_channel_min info + >>> # for all modules in block 1 of the model + """ + # see if we got tabulate + if not got_tabulate: + print("Make sure to install tabulate and try again.") + return None + + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # get the table string and print it out + # now we have populated the tables for each one + # let's create the strings to be returned + table_str = "" + # the tables will have some headers columns that are non-feature + # ex. table index, module name, channel index, etc. + # we want to look at header columns for features, that come after those headers + if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS: + # if we have at least one tensor level feature to be added we add tensor table + table_str += "Tensor Level Information \n" + table_str += tabulate(tensor_table, headers=tensor_headers) + if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS: + # if we have at least one channel level feature to be added we add tensor table + table_str += "\n\n Channel Level Information \n" + table_str += tabulate(channel_table, headers=channel_headers) + + # if no features at all, let user know + if table_str == "": + table_str = "No data points to generate table with." 
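As a rough illustration of the rendering step above: `tabulate` takes the row list and the header list separately, which is exactly the shape of the `(headers, table)` tuples produced by `generate_filtered_tables`. The headers and rows below are invented for the example.

```python
# Hedged sketch: rendering a (headers, rows) pair with tabulate.
# Header names and row values are made up for illustration.
from tabulate import tabulate

tensor_headers = ["idx", "layer_fqn", "per_tensor_min", "per_tensor_max"]
tensor_table = [
    [0, "block1.linear1", -0.42, 0.91],
    [1, "block1.linear2", -1.10, 1.35],
]

print("Tensor Level Information")
print(tabulate(tensor_table, headers=tensor_headers))
```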
+ + print(table_str) + + def _get_plottable_data( + self, feature_filter: str, module_fqn_filter: str + ) -> tuple[list, list[list], bool]: + r""" + Takes in the feature filters and module filters and outputs the x and y data for plotting + + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str): Only includes modules that contains this string + + Returns a tuple of three elements + The first is a list containing relevant x-axis data + The second is a list containing the corresponding y-axis data + If the data is per channel + """ + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # make sure it is only 1 feature that is being plotted + # get the number of features in each of these + tensor_info_features_count = ( + len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + ) + channel_info_features_count = ( + len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + ) + + # see if valid tensor or channel plot + is_valid_per_tensor_plot: bool = tensor_info_features_count == 1 + is_valid_per_channel_plot: bool = channel_info_features_count == 1 + + # offset should either be one of tensor or channel table or neither + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + table = tensor_table + + # if a per_channel plot, we have different offset and table + if is_valid_per_channel_plot: + feature_column_offset = ( + ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + ) + table = channel_table + + x_data: list = [] + y_data: list[list] = [] + # the feature will either be a tensor feature or channel feature + if is_valid_per_tensor_plot: + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if not type(row_value) == str: + x_data.append(x_val_to_append) + y_data.append(row_value) + elif is_valid_per_channel_plot: + # gather the x_data and multiple y_data + # calculate the number of channels + num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1 + + # separate data list per channel + y_data.extend([] for _ in range(num_channels)) + + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + current_channel = row[ + self.CHANNEL_NUM_INDEX + ] # initially chose current channel + new_module_index: int = table_row_num // num_channels + x_val_to_append = new_module_index + + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if not type(row_value) == str: + # only append if new index we are appending + if len(x_data) == 0 or x_data[-1] != x_val_to_append: + x_data.append(x_val_to_append) + + # append value for that channel + y_data[current_channel].append(row_value) + else: + # more than one feature was chosen + error_str = "Make sure to pick only a single feature with your filter to plot a graph." + error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names." + error_str += " Pick one of those features to plot." 
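For a sense of the `(x_data, y_data, per_channel)` tuple this helper returns, the hedged sketch below (with invented values) collapses per-channel series into one averaged point per module, which is the behaviour the plotting docstring further below describes.

```python
# Hedged sketch: averaging per-channel y-values into one point per module.
# x_data has one entry per module; y_data has one list per channel, aligned with x_data.
x_data = [0, 1, 2]                # module indices on the x-axis
y_data = [
    [0.1, 0.4, 0.2],              # channel 0 values for modules 0..2
    [0.3, 0.2, 0.6],              # channel 1 values for modules 0..2
]

num_channels = len(y_data)
avg_vals = [sum(per_module) / num_channels for per_module in zip(*y_data)]
print(avg_vals)                   # approximately [0.2, 0.3, 0.4] -> one point per module
```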
+ raise ValueError(error_str) + + # return x, y values, and if data is per-channel + return (x_data, y_data, is_valid_per_channel_plot) + + def generate_plot_visualization( + self, feature_filter: str, module_fqn_filter: str = "" + ): + r""" + Takes in a feature and optional module_filter and plots of the desired data. + + For per channel features, it averages the value across the channels and plots a point + per module. The reason for this is that for models with hundreds of channels, it can + be hard to differentiate one channel line from another, and so the point of generating + a single average point per module is to give a sense of general trends that encourage + further deep dives. + + Note: + Only features in the report that have tensor value data are plottable by this class + When the tensor information is plotted, it will plot: + idx as the x val, feature value as the y_val + When the channel information is plotted, it will plot: + the first idx of each module as the x val, feature value as the y_val [for each channel] + The reason for this is that we want to be able to compare values across the + channels for same layer, and it will be hard if values are staggered by idx + This means each module is represented by only 1 x value + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_plot_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + >>> # outputs line plot of per_channel_min information for all + >>> # modules in block1 of model each channel gets it's own line, + >>> # and it's plotted across the in-order modules on the x-axis + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data( + feature_filter, module_fqn_filter + ) + + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_ylabel(feature_filter) + ax.set_title(feature_filter + " Plot") + plt.xticks(x_data) # only show ticks for actual points + + if data_per_channel: + ax.set_xlabel("First idx of module") + # set the legend as well + # plot a single line that is average of the channel values + num_modules = len( + y_data[0] + ) # all y_data have same length, so get num modules + num_channels = len( + y_data + ) # we want num channels to be able to calculate average later + + avg_vals = [ + sum(y_data[:][index]) / num_channels for index in range(num_modules) + ] + + # plot the three things we measured + ax.plot( + x_data, avg_vals, label=f"Average Value Across {num_channels} Channels" + ) + ax.legend(loc="upper right") + else: + ax.set_xlabel("idx") + ax.plot(x_data, y_data) + + # actually show the plot + plt.show() + + def generate_histogram_visualization( + self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10 + ): + r""" + Takes in a feature and optional module_filter and plots the histogram of desired data. 
+ + Note: + Only features in the report that have tensor value data can be viewed as a histogram + If you want to plot a histogram from all the channel values of a specific feature for + a specific model, make sure to specify both the model and the feature properly + in the filters and you should be able to see a distribution of the channel data + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + num_bins (int, optional): The number of bins to create the histogram with + Default = 10, the values will be split into 10 equal sized bins + + Example Use: + >>> # xdoctest: +SKIP + >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + # outputs histogram of per_channel_min information for all modules in block1 of model + information is gathered across all channels for all modules in block 1 for the + per_channel_min and is displayed in a histogram of equally sized bins + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + _x_data, y_data, data_per_channel = self._get_plottable_data( + feature_filter, module_fqn_filter + ) + + # for histogram, we just care about plotting the y data + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_xlabel(feature_filter) + ax.set_ylabel("Frequency") + ax.set_title(feature_filter + " Histogram") + + if data_per_channel: + # set the legend as well + # combine all the data + all_data = [] + for channel_info in y_data: + all_data.extend(channel_info) + + _val, bins, _ = plt.hist( + all_data, + bins=num_bins, + stacked=True, + rwidth=0.8, + ) + plt.xticks(bins) + else: + _val, bins, _ = plt.hist( + y_data, + bins=num_bins, + stacked=False, + rwidth=0.8, + ) + plt.xticks(bins) + + plt.show() diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/convert.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..8b31d3fd57c104dea35292440027973e6acab9aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/convert.py @@ -0,0 +1,1264 @@ +# mypy: ignore-errors + +import copy +import operator +import warnings +from typing import Any, Callable, Optional, Union + +import torch +from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fused_module_classes, + get_pattern_to_dtype_configs, + get_qat_module_classes, + get_root_module_to_quantized_reference_module, +) +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quant_type import QuantType +from torch.ao.quantization.quantize import _remove_qconfig +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _parent_name, + 
activation_is_statically_quantized, + get_qparam_dict, + get_swapped_custom_module_class, + is_per_channel, + to_underlying_dtype, + weight_is_quantized, +) +from torch.fx import GraphModule +from torch.fx.graph import Argument, Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +from ._equalize import convert_eq_obs, update_obs_for_equalization +from .custom_config import ConvertCustomConfig, PrepareCustomConfig +from .graph_module import _is_observed_module, _is_observed_standalone_module +from .lower_to_fbgemm import lower_to_fbgemm +from .qconfig_mapping_utils import ( + _compare_prepare_convert_qconfig_mappings, + _generate_node_name_to_qconfig, + _is_qconfig_supported_by_dtype_configs, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from .utils import ( + _get_module, + _is_custom_module_lstm, + _is_custom_module_mha, + assert_and_get_unique_device, + collect_producer_nodes, + create_getattr_from_value, + get_custom_module_class_keys, + graph_module_from_producer_nodes, + node_arg_is_weight, +) + + +__all__ = [ + "convert", + "convert_custom_module", + "convert_standalone_module", + "convert_weighted_module", +] + +SUPPORTED_QDTYPES = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.uint8, + torch.int8, + torch.uint16, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_QSCHEME_TO_CHOOSE_QPARAMS_OP = { + torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, +} + + +def _replace_observer_with_quantize_dequantize_node_decomposed( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node working with decomposed Tensor + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel + """ + graph = model.graph + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + if hasattr(activation_post_process, "convert"): + activation_post_process.convert(model, node) + return + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + + # 1. 
extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment] + + def add_dequantize_op_kwargs(dequantize_op, input_node): + dequantize_op_kwargs = {} + if "val" in input_node.meta: + dq_out_dtype = input_node.meta["val"].dtype + if dq_out_dtype != torch.float32: + dequantize_op_kwargs = {"out_dtype": dq_out_dtype} + return dequantize_op_kwargs + + if dtype in SUPPORTED_QDTYPES and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = ( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + quant_min = activation_post_process.quant_min + quant_max = activation_post_process.quant_max + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and ( + not isinstance(value_or_node, (float, int)) + ): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
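For reference, the hedged sketch below shows the eager-mode equivalent of the decomposed quantize/dequantize pair that this branch inserts into the graph; it assumes a PyTorch build where the `quantized_decomposed` ops are registered (the `_decomposed` import at the top of this file performs that registration).

```python
# Hedged sketch: eager equivalent of the quantize -> dequantize pair inserted
# in place of an observer in the decomposed flow. Qparam values are illustrative.
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers quantized_decomposed ops)

x = torch.randn(4)
scale, zero_point, quant_min, quant_max = 0.05, 0, -128, 127

q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, quant_min, quant_max, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, quant_min, quant_max, torch.int8
)
print(q.dtype, dq.dtype)  # torch.int8 torch.float32
```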
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in dequantized_node.meta: + dequantized_node.meta[CUSTOM_KEY] = {} + dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + assert dtype_ in [torch.uint8, torch.int8], ( + "only uint8 and int8 are supported in reference flow for " + "dynamic quantization right now" + ) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] + eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_eps_": eps, + "_dtype_": dtype_, + } + + choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme] + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + for key, value in qparams.items(): + # we have quant_min, quant_max and dtype, all should be stored + # as literals + choose_qparams_op_inputs.append(value) + choose_qparams_node = graph.create_node( + "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 0), {} + ) + zero_point_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 1), {} + ) + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype, + } + + # 3. 
replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if NUMERIC_DEBUG_HANDLE_KEY in node.meta: + dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + NUMERIC_DEBUG_HANDLE_KEY + ] + graph.erase_node(node) + elif dtype == torch.float16: + # Insert to_fp16 -> to_fp32 node + dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + with graph.inserting_before(node): + input_node = node.args[0] + convert_fp16_node = graph.create_node( + "call_function", dtype_convert_op, (input_node, torch.float16), {} + ) + convert_fp32_node = graph.create_node( + "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {} + ) + node.replace_all_uses_with(convert_fp32_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +def _replace_observer_with_quantize_dequantize_node( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... 
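The hedged sketch below is the eager-tensor analogue of the Before/After diagram above, using invented qparams; in the graph pass the same pair is built with `graph.create_node` and `graph.call_method`.

```python
# Hedged sketch: the eager analogue of quantize_per_tensor -> dequantize.
import torch

x = torch.randn(4)
q = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)
dq = q.dequantize()
print(q.dtype, dq.dtype)  # torch.qint8 torch.float32
```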
+ """ + assert modules is not None + assert isinstance(node.target, str) + graph = model.graph + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + if dtype in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op: Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_dtype_": dtype, + } + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
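The replacement itself follows the standard FX graph-surgery pattern: insert the new node(s) before the observer node, reroute its users, then erase it. A minimal, self-contained sketch of that pattern on a toy module (unrelated to quantization, names invented) looks like this:

```python
# Hedged sketch: the insert / replace_all_uses_with / erase_node pattern used by
# this pass, demonstrated on a toy graph that swaps relu for sigmoid.
import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Toy())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(torch.sigmoid, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)  # the traced forward now calls torch.sigmoid instead of torch.relu
```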
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization branch + + node_type = "call_function" + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" # type: ignore[assignment] + qparams = {"_dtype_": dtype} + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. +def _replace_observer_or_dequant_stub_with_dequantize_node( + node: Node, graph: Graph +) -> None: + call_custom_module_node = node.args[0] + assert isinstance(call_custom_module_node, Node), ( + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + ) + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + _insert_dequantize_node(call_custom_module_node, graph) + + +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + return ( + (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) + or is_dynamic # type: ignore[return-value] + or dtype == torch.float16 + ) + + +def _has_none_qconfig( + node: Argument, node_name_to_qconfig: dict[str, QConfigAny] +) -> bool: + """Check if a node has a qconfig of None, i.e. 
user requested to not quantize + the node + """ + return ( + isinstance(node, Node) + and node.name in node_name_to_qconfig + and node_name_to_qconfig[node.name] is None + ) + + +def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None: + """Extract the subgraph that produces the weight for dynamic quant + or weight only quant node and run the subgraph to observe the weight. + Note that the observers of dynamic quant or weight only quant ops are + run during the convert step. + """ + for node in observed.graph.nodes: + if node.op != "call_function": + continue + for node_arg in node.args: + # node_arg is weight + if node_arg and node_arg_is_weight(node, node_arg): + weight_observer_nodes = collect_producer_nodes(node_arg) + if weight_observer_nodes is None: + continue + weight_observer_module = graph_module_from_producer_nodes( + observed, weight_observer_nodes + ) + # run the weight observer + weight_observer_module() + + +def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None: + """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node, + we'll recursively remove the dequantize Node + """ + if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize": + quantize_node = arg.args[0] + # we only replace the specific use since dequantize could be used by other nodes + # as well + node.replace_input_with(arg, quantize_node) + elif isinstance(arg, (list, tuple)): + for arg_element in arg: + _maybe_recursive_remove_dequantize(arg_element, node, graph) + elif isinstance(arg, dict): + for arg_element in arg.values(): + _maybe_recursive_remove_dequantize(arg_element, node, graph) + else: + warnings.warn( + f"Unsupported node type in recursive remove dequantize: {type(arg)}" + ) + + +def _get_module_path_and_prefix( + obs_node: Node, + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> tuple[str, str]: + """Given and observer node, get the `Scope` or the fully qualified name for + the submodule containing the observed node, also return a prefix of "_input" + when the observed node is an input of a F.linear op, and not the output of another + quantized op. + TODO: this logic is hacky, we should think about how to remove it or make it more + general + """ + observed_node = obs_node.args[0] + # an observer can be inserted for both input of the next operator or output of the previous + # operator (they can be the same) + # this flag identifies if the observer is inserted only because the observed node is + # the input of the next operator + assert isinstance(observed_node, Node), ( + f"Expecting observed node to be a Node, but got {observed_node}" + ) + is_input_observer_only = ( + node_name_to_qconfig[observed_node.name] is None + if observed_node.name in node_name_to_qconfig + else None + ) + if is_input_observer_only: + # if the quantize function is at the input of op, then we find the first user of the observer_node + # to get the path. If a linear call_function is in the user list, we return the first instance + # of linear node to get the FQN. 
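`_run_weight_observers` above amounts to feeding each weight through its observer so that qparams can later be read off it; in eager terms that is roughly the hedged sketch below (the observer choice and weight are illustrative).

```python
# Hedged sketch: "running a weight observer" in eager terms -- observe the weight
# once, then read the scale / zero_point it derived.
import torch
from torch.ao.quantization.observer import MinMaxObserver

weight = torch.randn(8, 4)
obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
obs(weight)  # record min/max of the weight
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())
```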
+ users = list(obs_node.users) + first_linear_use_or_first_use = users[0] if users else None + linear_node = None + for n in users: + if n.op == "call_function" and n.target == torch.nn.functional.linear: + linear_node = n + break + if linear_node: + first_linear_use_or_first_use = linear_node + prefix = "_input" + else: + # if the quantize function is at the output of the op, we use the observer input node to get the path + first_linear_use_or_first_use = observed_node + prefix = "" + + if ( + first_linear_use_or_first_use + and first_linear_use_or_first_use.name in node_name_to_scope + ): + module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] + else: + # TODO: it's not used, so actually we can skip quantization + # but this requires changing return type of quantize_node + # we can fix it later if needed + module_path = "" + return module_path, prefix + + +def _insert_dequantize_node(node: Node, graph: Graph) -> None: + """Inserts dequantize node for `node` in `graph`""" + with graph.inserting_after(node): + dequantize_node = graph.call_method("dequantize", (node,)) + for user_node in dict(node.users): + if user_node is not dequantize_node: + user_node.replace_input_with(node, dequantize_node) + + +def _maybe_get_observer_for_node( + node: Node, modules: dict[str, torch.nn.Module] +) -> Optional[torch.nn.Module]: + """ + If the node is observed, return the observer + instance. Otherwise, return None. + """ + for maybe_obs_node in node.users.keys(): + if maybe_obs_node.op == "call_module": + maybe_obs = modules[str(maybe_obs_node.target)] + if _is_activation_post_process(maybe_obs): + return maybe_obs + return None + + +def convert_standalone_module( + node: Node, + modules: dict[str, torch.nn.Module], + model: torch.fx.GraphModule, + is_reference: bool, + backend_config: Optional[BackendConfig], +) -> None: + """Converts a observed standalone module to a quantized standalone module by calling + the fx convert api, currently using the same `is_reference` flag as parent, but we may + changing this behavior in the future (e.g. 
separating quantization and lowering for + standalone module as well) + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - model: original model + - is_reference: a flag from parent provided by user to decide if we want to + produce a reference model or a fbgemm/qnnpack model + - backend_config: backend configuration of the target backend of quantization + """ + # TODO: remove is_reference flag + if is_reference: + convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx + else: + convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] + # We know that observed standalone module is a GraphModule since + # it's produced by us + observed_standalone_module: GraphModule = modules[str(node.target)] # type: ignore[assignment] + sm_input_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_input_quantized_idxs + # remove the dequantize nodes for inputs + args = list(node.args) + for idx in range(len(args)): + if idx in sm_input_quantized_idxs: + arg = args[idx] + if arg.op == "call_method" and arg.target == "dequantize": # type: ignore[union-attr] + quantize_node = arg.args[0] # type: ignore[union-attr] + node.replace_input_with(arg, quantize_node) + if len(arg.users) == 0: # type: ignore[union-attr] + model.graph.erase_node(arg) + # add dequantize node for output + sm_output_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_output_quantized_idxs + if len(sm_output_quantized_idxs) > 0: + assert sm_output_quantized_idxs[0] == 0, "Currently only quantized" + "output idxs = [0] is supported" + + # if it's non-empty, then it means the output is kept in quantized form + # we'll just add a dequantize node after this node + _insert_dequantize_node(node, model.graph) + + # TODO: allow convert_custom_config to override backend_config + # for standalone module + quantized_standalone_module = convert_fn( + observed_standalone_module, backend_config=backend_config + ) + parent_name, name = _parent_name(node.target) + # update the modules dict + setattr(modules[parent_name], name, quantized_standalone_module) + modules[str(node.target)] = quantized_standalone_module + + +def convert_weighted_module( + node: Node, + modules: dict[str, torch.nn.Module], + observed_node_names: set[str], + node_name_to_qconfig: dict[str, QConfigAny], + backend_config: BackendConfig, + is_decomposed: bool = False, + is_reference: bool = False, +) -> None: + """Convert a weighted module to reference quantized module in the model + If the QConfig of a QAT module is not set, the module will still be converted to + a float module. 
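+    For example, under the native backend config a float ``torch.nn.Linear`` whose
+    weight qconfig is supported is typically swapped for its reference quantized
+    counterpart (``torch.ao.nn.quantized.reference.Linear``).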
+ + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - observed_node_names: names for the set of observed fx node, we can skip + this conversion if the node is not observed + """ + original_module = modules[str(node.target)] + qconfig: QConfigAny = original_module.qconfig # type: ignore[assignment] + weight_post_process = None + qat_module_classes = get_qat_module_classes(backend_config) + + if isinstance(original_module, qat_module_classes): + # Converting qat module to a float module, we need to attach + # weight fake_quant to the module, weight fake_quant is assumed to be run during + # QAT so we don't need to run it again here + weight_post_process = original_module.weight_fake_quant + original_module = original_module.to_float() # type: ignore[operator] + # change qat module to float module + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, original_module) + + is_observed = node.name in observed_node_names + # If a qconfig is not defined for this node, then skip converting to a reference module + if ( + qconfig is None + or _has_none_qconfig(node, node_name_to_qconfig) + or not is_observed + ): + return + + # skip converting to reference quantized module if the qconfig is not supported + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs = pattern_to_dtype_configs.get(type(original_module), []) + if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs): + return + + # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized + is_weight_quantized = weight_is_quantized(qconfig) + + # the condition for swapping the module to reference quantized module is: + # weights need to be quantized + if not is_weight_quantized: + return + + fused_module = None + float_module = original_module + # extract the individual float_module and fused module + if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule): + fused_module = float_module + float_module = fused_module[0] # type: ignore[index] + + # TODO: move this to the reference quantized module + # weight_qparams or weight_qparams dict + wq_or_wq_dict = {"is_decomposed": is_decomposed} + if isinstance(float_module, torch.nn.RNNCellBase): + weight_post_process_ih = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_hh = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_ih(float_module.weight_ih) + weight_post_process_hh(float_module.weight_hh) + weight_qparams_ih = get_qparam_dict(weight_post_process_ih) + weight_qparams_hh = get_qparam_dict(weight_post_process_hh) + wq_or_wq_dict.update( + { + "weight_ih": weight_qparams_ih, + "weight_hh": weight_qparams_hh, + } + ) + elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): + # format for wq_or_wq_dict (flattened attributes): + # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...} + for wn in float_module._flat_weights_names: + if hasattr(float_module, wn) and wn.startswith("weight"): + weight = getattr(float_module, wn) + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if weight_post_process.dtype == torch.qint8: # type: ignore[union-attr] + weight_post_process(weight) # type: ignore[operator, misc] + wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process) + else: + # weight_post_process is None means the original module is not a QAT module + # we need to get weight_post_process from qconfig in this case + is_ptq = 
weight_post_process is None + if is_ptq: + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + device = assert_and_get_unique_device(float_module) + if device: + weight_post_process.to(device) + + # Call weight observer/fake_quant at least once to ensure the scales and zero points + # have the right shapes. Note: there are two cases where we don't have to do this: + # + # (1) QAT: The model's forward method already calls the weight observer/fake_quant, + # and this typically happens during training, so we don't need to do it here. + # + # (2) Non-reference (lowered) case: The quantized module's from_float method already + # calls the weight observer/fake_quant, so we don't have to do it here. + # + # Currently we ignore both cases and call the weight observer/fake_quant here + # regardless, which is technically incorrect. For (1), this is mainly to preserve BC + # in test code, which may not always train before convert. In the future, we should + # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941. + # + # For PT2, however, we don't need to preserve BC here, so we can skip this hack + # for QAT. We identify this case as (is_decomposed + is_reference + is_qat). + # Note that we still need it for PTQ in the PT2 flow since the model's forward + # method doesn't call the weight observer. + is_qat = not is_ptq + if not (is_decomposed and is_reference and is_qat): + weight_post_process(float_module.weight) # type: ignore[operator] + + wq_or_wq_dict.update(get_qparam_dict(weight_post_process)) + + # We use the same reference module for all modes of quantization: static, dynamic, weight_only + # root_module_to_quantized_reference_module: module mapping from root (floating point) module class + # to quantized reference module class, e.g. 
nn.Conv2d to nn.quantized._reference.Conv2d + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + ref_qmodule_cls = root_module_to_quantized_reference_module.get( + type_before_parametrizations(float_module), None + ) + assert ref_qmodule_cls is not None, ( + f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ) + ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] + if fused_module is not None: + fused_module[0] = ref_qmodule # type: ignore[operator] + else: + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, ref_qmodule) + + +def _remove_previous_dequantize_in_custom_module( + node: Node, prev_node: Node, graph: Graph +) -> None: + """ + Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows: + + Before: quantize - dequantize - custom_module + After: quantize - custom_module + \\ - dequantize + """ + # expecting the input node for a custom module node to be a Node + assert isinstance(prev_node, Node), ( + f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + ) + if prev_node.op == "call_method" and prev_node.target == "dequantize": + node.replace_input_with(prev_node, prev_node.args[0]) + # Remove the dequantize node if it doesn't have other users + if len(prev_node.users) == 0: + graph.erase_node(prev_node) + + +def convert_custom_module( + node: Node, + graph: Graph, + modules: dict[str, torch.nn.Module], + custom_module_class_mapping: dict[QuantType, dict[type, type]], + statically_quantized_custom_module_nodes: set[Node], +) -> None: + """Converts an observed custom module to a quantized custom module based on + `custom_module_class_mapping` + For static quantization, we'll also remove the previous `dequantize` node and + attach the observer node for output to the module, the observer for the node + will be converted to a dequantize node instead of quantize-dequantize pairs + later in the graph. In the end we would have a quantized custom module that + has the same interface as a default quantized module in nn.quantized namespace, + i.e. quantized input and quantized output. + + Args: + - node: The call_module node of the observed standalone module + - graph: The graph containing the node + - modules: named_module of original model + - custom_module_class_mapping: mapping from observed custom module class to + quantized custom module class, used to swap custom modules + - statically_quantized_custom_module_nodes: we'll add the custom module node + if we find it is statically quantized, this will be used later when converting + observers to quant/dequant node pairs, if the observed node is a statically + quantized custom module nodes, we'll convert the observer to a dequantize node, + this is to keep the interface the same as the default quantized module. + TODO: maybe we want to redesign this part to align with reference model design + as well, but there has been some discussions around the interface, so we can do + it later. 
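+    Illustrative shape of ``custom_module_class_mapping`` (the class names below are
+    placeholders, not real classes):
+    ``{QuantType.STATIC: {ObservedCustomLSTM: QuantizedCustomLSTM}}``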
+ """ + observed_custom_module = modules[str(node.target)] + qconfig = observed_custom_module.qconfig + if activation_is_statically_quantized(qconfig): + statically_quantized_custom_module_nodes.add(node) + if _is_custom_module_lstm(node, modules): + # The inputs are tuples in the form (input, (hidden0, hidden1)) + # Ensure all three input nodes are quantized + assert ( + len(node.args) == 2 + and isinstance(node.args[1], tuple) + and len(node.args[1]) == 2 + ) + (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] + assert isinstance(inputs, Node) + assert isinstance(hidden0, Node) + assert isinstance(hidden1, Node) + _remove_previous_dequantize_in_custom_module(node, inputs, graph) + _remove_previous_dequantize_in_custom_module(node, hidden0, graph) + _remove_previous_dequantize_in_custom_module(node, hidden1, graph) + elif _is_custom_module_mha(node, modules): + # Inputs are in the form (query, key, value) + # TODO: This is the first step in enabling the full fx custom module + # quantization path for MultiheadAttention, and only covers the inputs + # to the module. + # Additional handling is yet to be implemented for the outputs, similar + # to LSTM custom module + assert len(node.args) == 3 + query, key, value = node.args + assert isinstance(query, Node) + assert isinstance(key, Node) + assert isinstance(value, Node) + _remove_previous_dequantize_in_custom_module(node, query, graph) + _remove_previous_dequantize_in_custom_module(node, key, graph) + _remove_previous_dequantize_in_custom_module(node, value, graph) + else: + # remove the previous dequant node to ensure the inputs are quantized + arg = node.args[0] + assert isinstance(arg, Node) + _remove_previous_dequantize_in_custom_module(node, arg, graph) + # absorb the following observer into the module conversion + activation_post_process = _maybe_get_observer_for_node(node, modules) + assert activation_post_process is not None + observed_custom_module.activation_post_process = activation_post_process + + # swap the observed custom module to quantized custom module + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig + ) + quantized_custom_module = quantized_custom_module_class.from_observed( + observed_custom_module + ) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, quantized_custom_module) + + +def convert( + model: GraphModule, + is_reference: bool = False, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig_flag: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """ + We will convert an observed model (a module with observer calls) to a reference + quantized model, the rule is simple: + 1. for each observer module call in the graph, we'll convert it to calls to + quantize and dequantize functions based on the observer instance + 2. 
for weighted operations like linear/conv, we need to convert them to reference + quantized module, this requires us to know whether the dtype configured for the + weight is supported in the backend, this is done in prepare step and the result + is stored in observed_node_names, we can decide whether we need to swap the + module based on this set + + Args: + * `is_standalone_module`: when this flag is True, it means we are quantizing + a submodule that is not inlined in parent module, and will be quantized + separately as one unit. + + * `is_decomposed`: a boolean flag to indicate whether we want to use the + quantize operator for decomposed quantized tensor + (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone + quantized tensor (torch.quantize_per_tensor) + + Returns: + a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for :func:`~torch.ao.quantization.prepare_fx` for details + """ + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to convert is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = ( + QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None + ) + qconfig_mapping = copy.deepcopy(qconfig_mapping) + assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + if backend_config is None: + backend_config = get_native_backend_config() + + assert _is_observed_module(model), "incoming model must be produced by prepare_fx" + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + node_name_to_scope: dict[str, tuple[str, type]] = ( + observed_graph_module_attrs.node_name_to_scope + ) + prepare_custom_config: PrepareCustomConfig = ( + observed_graph_module_attrs.prepare_custom_config + ) + observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names + node_name_to_qconfig: dict[str, QConfigAny] = ( + observed_graph_module_attrs.node_name_to_qconfig + ) # type: ignore[assignment] + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + # We use remove_duplicate=False here because torch.cat uses + # the same activation_post_process module instance but different names + modules = dict(model.named_modules(remove_duplicate=False)) + + # TODO refactor this code once we update the prepare logic to have additional information on + # which graph nodes have been observed and share that with convert to decide which observers to ignore. 
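+    # Editorial sketch (not part of the original file): a typical flow that reaches
+    # this function looks roughly like the following, where `MyModel` and the example
+    # inputs are placeholders:
+    #
+    #     from torch.ao.quantization import get_default_qconfig_mapping
+    #     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
+    #
+    #     float_model = MyModel().eval()
+    #     example_inputs = (torch.randn(1, 3, 224, 224),)
+    #     observed = prepare_fx(
+    #         float_model, get_default_qconfig_mapping("fbgemm"), example_inputs
+    #     )
+    #     observed(*example_inputs)         # calibration populates the observers
+    #     quantized = convert_fx(observed)  # eventually calls into this function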
+ if qconfig_mapping: + prepare_qconfig_mapping: QConfigMapping = ( + observed_graph_module_attrs.qconfig_mapping + ) # type: ignore[assignment] + modules_copy = copy.deepcopy(modules) + + if observed_graph_module_attrs.is_qat: + _update_qconfig_for_qat(qconfig_mapping, backend_config) + _update_qconfig_for_fusion(model, qconfig_mapping) + + _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping, qconfig_mapping + ) # type: ignore[arg-type] + convert_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope + ) + # check the convert_node_name_to_qconfig generated and ensure that + # all the values either match what was set in prepare node_name_to_qconfig + # or are set to None in the convert_node_name_to_qconfig. + for k, v in node_name_to_qconfig.items(): + assert k in convert_node_name_to_qconfig, ( + f"Expected key {k} in convert node_name_to_qconfig" + ) + if convert_node_name_to_qconfig[k] is not None: + assert qconfig_equals(v, convert_node_name_to_qconfig[k]), ( + f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " + f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + ) + node_name_to_qconfig = convert_node_name_to_qconfig + + custom_module_classes = get_custom_module_class_keys( + convert_custom_config.observed_to_quantized_mapping + ) + custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping + + if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None: + # If we want to do equalization then do the following: + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + # always run weight observers in the top level forward method + # for dynamic quant ops or weight only quant ops + _run_weight_observers(model, backend_config) + + # additional state to override inputs to be quantized, if specified + # by the user + placeholder_node_seen_cnt = 0 + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + # convert tuples so that it can work with isinstance(module, tuple_of_classes) + root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) + qat_module_classes = get_qat_module_classes(backend_config) + fused_module_classes = get_fused_module_classes(backend_config) + statically_quantized_custom_module_nodes: set[Node] = set() + + for node in list(model.graph.nodes): + if node.op == "placeholder": + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: + # Inputs are assumed to be quantized if the user specified the + # input_quantized_idxs override. + # we need to dequantize the inputs since all operators took + # floating point inputs in reference quantized models + _insert_dequantize_node(node, model.graph) + elif node.op == "output": + # If the argument is empty we don't need to do anything + if len(output_quantized_idxs) == 0: + continue + # Result are kept quantized if the user specified the + # output_quantized_idxs override. 
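+            # Editorial sketch (illustrative, not part of the original file): the
+            # indexes handled here are requested at prepare time, e.g.
+            #
+            #     prepare_custom_config = (
+            #         PrepareCustomConfig()
+            #         .set_input_quantized_indexes([0])
+            #         .set_output_quantized_indexes([0])
+            #     )
+            #     observed = prepare_fx(
+            #         float_model,
+            #         qconfig_mapping,
+            #         example_inputs,
+            #         prepare_custom_config=prepare_custom_config,
+            #     )
+            #
+            # which is what populates `output_quantized_idxs` above.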
+ # Remove the dequantize operator for the node in the end if any + return_node = node + output = node.args[0] + # outputs can be Node, list, tuple, dict, other cases are not supported yet + if isinstance(output, (list, tuple)): + for idx in output_quantized_idxs: + _maybe_recursive_remove_dequantize( + output[idx], return_node, model.graph + ) + elif isinstance(output, (Node, dict)): + # we treat dict as a single argument currently, but it can be extended + # to support {"key": dtype} after we change output_quantized_idxs to + # dict + if 0 in output_quantized_idxs: + _maybe_recursive_remove_dequantize(output, return_node, model.graph) + else: + warnings.warn( + f"Unsupported node type for output_quantized_idxs: {type(output)}" + ) + elif node.op == "call_module": + mod = _get_module(node, modules) + assert mod is not None + if _is_activation_post_process(mod): + observed_node = node.args[0] + if observed_node in statically_quantized_custom_module_nodes: + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + else: + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + ) + else: + _replace_observer_with_quantize_dequantize_node( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + ) + elif isinstance(mod, DeQuantStub): + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + elif _is_observed_standalone_module(mod): + convert_standalone_module( + node, modules, model, is_reference, backend_config + ) + # below this point `type_before_parametrizations` is used + # instead of `type` to handle situations with fx quant + sparsity + elif type_before_parametrizations(mod) in set(root_module_classes).union( + qat_module_classes + ).union(fused_module_classes): + # extra check for fused module classes to make sure they are fused module classes + # of target modules + if ( + type_before_parametrizations(mod) in fused_module_classes + and type_before_parametrizations(mod[0]) not in root_module_classes + ): # type: ignore[index] + continue + convert_weighted_module( + node, + modules, + observed_node_names, + node_name_to_qconfig, + backend_config, + is_decomposed, + is_reference, + ) + elif type_before_parametrizations(mod) in custom_module_classes: + convert_custom_module( + node, + model.graph, + modules, + custom_module_class_mapping, + statically_quantized_custom_module_nodes, + ) + + # remove deadcode after converting observers to quant/dequant ops + model.graph.eliminate_dead_code() + model = GraphModule(model, model.graph) + + # TODO: maybe move this to quantize_fx.py + if not is_reference: + model = lower_to_fbgemm( + model, node_name_to_qconfig, node_name_to_scope, keep_original_weights + ) + + # TODO: this looks hacky, we want to check why we need this and see if we can + # remove this + # removes qconfig and activation_post_process modules + if _remove_qconfig_flag: + _remove_qconfig(model) + model.delete_all_unused_submodules() + model.meta.pop("_observed_graph_module_attrs", None) + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/custom_config.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/custom_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4b988ce16f5941fdfacaf4df01841fdbaeab27a3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/custom_config.py @@ -0,0 +1,521 @@ +# mypy: allow-untyped-defs +from __future__ import 
annotations + +from dataclasses import dataclass +from typing import Any, Optional + +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.quant_type import ( + _get_quant_type_to_str, + _quant_type_from_str, + QuantType, +) + + +__all__ = [ + "ConvertCustomConfig", + "FuseCustomConfig", + "PrepareCustomConfig", + "StandaloneModuleConfigEntry", +] + + +# TODO: replace all usages with these constants +STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name" +STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class" +FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class" +OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class" +NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name" +NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class" +INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs" +OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs" +PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes" + + +@dataclass +class StandaloneModuleConfigEntry: + # qconfig_mapping for the prepare function called in the submodule, + # None means use qconfig from parent qconfig_mapping + qconfig_mapping: Optional[QConfigMapping] + example_inputs: tuple[Any, ...] + prepare_custom_config: Optional[PrepareCustomConfig] + backend_config: Optional[BackendConfig] + + +class PrepareCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and + :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`. + + Example usage:: + + prepare_custom_config = PrepareCustomConfig() \ + .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \ + .set_non_traceable_module_names(["module2", "module3"]) \ + .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \ + .set_input_quantized_indexes([0]) \ + .set_output_quantized_indexes([0]) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self) -> None: + self.standalone_module_names: dict[str, StandaloneModuleConfigEntry] = {} + self.standalone_module_classes: dict[type, StandaloneModuleConfigEntry] = {} + self.float_to_observed_mapping: dict[QuantType, dict[type, type]] = {} + self.non_traceable_module_names: list[str] = [] + self.non_traceable_module_classes: list[type] = [] + self.input_quantized_indexes: list[int] = [] + self.output_quantized_indexes: list[int] = [] + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"PrepareCustomConfig({dict_nonempty})" + + def set_standalone_module_name( + self, + module_name: str, + qconfig_mapping: Optional[QConfigMapping], + example_inputs: tuple[Any, ...], + prepare_custom_config: Optional[PrepareCustomConfig], + backend_config: Optional[BackendConfig], + ) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_name``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. 
+ If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_names[module_name] = StandaloneModuleConfigEntry( + qconfig_mapping, example_inputs, prepare_custom_config, backend_config + ) + return self + + def set_standalone_module_class( + self, + module_class: type, + qconfig_mapping: Optional[QConfigMapping], + example_inputs: tuple[Any, ...], + prepare_custom_config: Optional[PrepareCustomConfig], + backend_config: Optional[BackendConfig], + ) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_class``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_classes[module_class] = StandaloneModuleConfigEntry( + qconfig_mapping, example_inputs, prepare_custom_config, backend_config + ) + return self + + def set_float_to_observed_mapping( + self, + float_class: type, + observed_class: type, + quant_type: QuantType = QuantType.STATIC, + ) -> PrepareCustomConfig: + """ + Set the mapping from a custom float module class to a custom observed module class. + + The observed module class must have a ``from_float`` class method that converts the float module class + to the observed module class. This is currently only supported for static quantization. + """ + if quant_type != QuantType.STATIC: + raise ValueError( + "set_float_to_observed_mapping is currently only supported for static quantization" + ) + if quant_type not in self.float_to_observed_mapping: + self.float_to_observed_mapping[quant_type] = {} + self.float_to_observed_mapping[quant_type][float_class] = observed_class + return self + + def set_non_traceable_module_names( + self, module_names: list[str] + ) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by name. + """ + self.non_traceable_module_names = module_names + return self + + def set_non_traceable_module_classes( + self, module_classes: list[type] + ) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by class. + """ + self.non_traceable_module_classes = module_classes + return self + + def set_input_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig: + """ + Set the indexes of the inputs of the graph that should be quantized. + Inputs are otherwise assumed to be in fp32 by default instead. + """ + self.input_quantized_indexes = indexes + return self + + def set_output_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig: + """ + Set the indexes of the outputs of the graph that should be quantized. + Outputs are otherwise assumed to be in fp32 by default instead. + """ + self.output_quantized_indexes = indexes + return self + + def set_preserved_attributes(self, attributes: list[str]) -> PrepareCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. 
+ """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict( + cls, prepare_custom_config_dict: dict[str, Any] + ) -> PrepareCustomConfig: + """ + Create a ``PrepareCustomConfig`` from a dictionary with the following items: + + "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "float_to_observed_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from float module classes to observed module classes, e.g. + {"static": {FloatCustomModule: ObservedCustomModule}} + + "non_traceable_module_name": a list of modules names that are not symbolically traceable + "non_traceable_module_class": a list of module classes that are not symbolically traceable + "input_quantized_idxs": a list of indexes of graph inputs that should be quantized + "output_quantized_idxs": a list of indexes of graph outputs that should be quantized + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + + def _get_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]: + """ + Convert the given object into a QConfigMapping if possible, else throw an exception. + """ + if isinstance(obj, QConfigMapping) or obj is None: + return obj + if isinstance(obj, dict): + return QConfigMapping.from_dict(obj) + raise ValueError( + f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + def _get_prepare_custom_config( + obj: Any, dict_key: str + ) -> Optional[PrepareCustomConfig]: + """ + Convert the given object into a PrepareCustomConfig if possible, else throw an exception. + """ + if isinstance(obj, PrepareCustomConfig) or obj is None: + return obj + if isinstance(obj, dict): + return PrepareCustomConfig.from_dict(obj) + raise ValueError( + f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + def _get_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]: + """ + Convert the given object into a BackendConfig if possible, else throw an exception. 
+ """ + if isinstance(obj, BackendConfig) or obj is None: + return obj + if isinstance(obj, dict): + return BackendConfig.from_dict(obj) + raise ValueError( + f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + conf = cls() + for ( + module_name, + qconfig_dict, + example_inputs, + _prepare_custom_config_dict, + backend_config_dict, + ) in prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping( + qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + prepare_custom_config = _get_prepare_custom_config( + _prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + backend_config = _get_backend_config( + backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + conf.set_standalone_module_name( + module_name, + qconfig_mapping, + example_inputs, + prepare_custom_config, + backend_config, + ) + for ( + module_class, + qconfig_dict, + example_inputs, + _prepare_custom_config_dict, + backend_config_dict, + ) in prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping( + qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + prepare_custom_config = _get_prepare_custom_config( + _prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + backend_config = _get_backend_config( + backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + conf.set_standalone_module_class( + module_class, + qconfig_mapping, + example_inputs, + prepare_custom_config, + backend_config, + ) + for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get( + FLOAT_TO_OBSERVED_DICT_KEY, {} + ).items(): + quant_type = _quant_type_from_str(quant_type_name) + for float_class, observed_class in custom_module_mapping.items(): + conf.set_float_to_observed_mapping( + float_class, observed_class, quant_type + ) + conf.set_non_traceable_module_names( + prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, []) + ) + conf.set_non_traceable_module_classes( + prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, []) + ) + conf.set_input_quantized_indexes( + prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, []) + ) + conf.set_output_quantized_indexes( + prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, []) + ) + conf.set_preserved_attributes( + prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``PrepareCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`. 
+ """ + + def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): + qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None + prepare_custom_config_dict = ( + e.prepare_custom_config.to_dict() if e.prepare_custom_config else None + ) + return ( + key, + qconfig_dict, + e.example_inputs, + prepare_custom_config_dict, + e.backend_config, + ) + + d: dict[str, Any] = {} + for module_name, sm_config_entry in self.standalone_module_names.items(): + if STANDALONE_MODULE_NAME_DICT_KEY not in d: + d[STANDALONE_MODULE_NAME_DICT_KEY] = [] + d[STANDALONE_MODULE_NAME_DICT_KEY].append( + _make_tuple(module_name, sm_config_entry) + ) + for module_class, sm_config_entry in self.standalone_module_classes.items(): + if STANDALONE_MODULE_CLASS_DICT_KEY not in d: + d[STANDALONE_MODULE_CLASS_DICT_KEY] = [] + d[STANDALONE_MODULE_CLASS_DICT_KEY].append( + _make_tuple(module_class, sm_config_entry) + ) + for ( + quant_type, + float_to_observed_mapping, + ) in self.float_to_observed_mapping.items(): + if FLOAT_TO_OBSERVED_DICT_KEY not in d: + d[FLOAT_TO_OBSERVED_DICT_KEY] = {} + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + float_to_observed_mapping + ) + if len(self.non_traceable_module_names) > 0: + d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names + if len(self.non_traceable_module_classes) > 0: + d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes + if len(self.input_quantized_indexes) > 0: + d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes + if len(self.output_quantized_indexes) > 0: + d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class ConvertCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`. + + Example usage:: + + convert_custom_config = ConvertCustomConfig() \ + .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self) -> None: + self.observed_to_quantized_mapping: dict[QuantType, dict[type, type]] = {} + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"ConvertCustomConfig({dict_nonempty})" + + def set_observed_to_quantized_mapping( + self, + observed_class: type, + quantized_class: type, + quant_type: QuantType = QuantType.STATIC, + ) -> ConvertCustomConfig: + """ + Set the mapping from a custom observed module class to a custom quantized module class. + + The quantized module class must have a ``from_observed`` class method that converts the observed module class + to the quantized module class. + """ + if quant_type not in self.observed_to_quantized_mapping: + self.observed_to_quantized_mapping[quant_type] = {} + self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class + return self + + def set_preserved_attributes(self, attributes: list[str]) -> ConvertCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. 
+ """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict( + cls, convert_custom_config_dict: dict[str, Any] + ) -> ConvertCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from observed module classes to quantized module classes, e.g.:: + { + "static": {FloatCustomModule: ObservedCustomModule}, + "dynamic": {FloatCustomModule: ObservedCustomModule}, + "weight_only": {FloatCustomModule: ObservedCustomModule} + } + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + for quant_type_name, custom_module_mapping in convert_custom_config_dict.get( + OBSERVED_TO_QUANTIZED_DICT_KEY, {} + ).items(): + quant_type = _quant_type_from_str(quant_type_name) + for observed_class, quantized_class in custom_module_mapping.items(): + conf.set_observed_to_quantized_mapping( + observed_class, quantized_class, quant_type + ) + conf.set_preserved_attributes( + convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``ConvertCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. + """ + d: dict[str, Any] = {} + for ( + quant_type, + observed_to_quantized_mapping, + ) in self.observed_to_quantized_mapping.items(): + if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: + d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + observed_to_quantized_mapping + ) + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class FuseCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`. + + Example usage:: + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + ["attr1", "attr2"] + ) + """ + + def __init__(self) -> None: + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"FuseCustomConfig({dict_nonempty})" + + def set_preserved_attributes(self, attributes: list[str]) -> FuseCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict(cls, fuse_custom_config_dict: dict[str, Any]) -> FuseCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + conf.set_preserved_attributes( + fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``FuseCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. 
+ """ + d: dict[str, Any] = {} + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/fuse.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..77641a340b5a4cae2b0c86a8617a8f5ac2fd3634 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/fuse.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Any, Callable, Union + +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fuser_method_mapping, + get_fusion_pattern_to_extra_inputs_getter, + get_fusion_pattern_to_root_node_getter, +) +from torch.ao.quantization.utils import NodePattern, Pattern +from torch.fx import GraphModule, map_arg, Node +from torch.fx.graph import Graph + +from .custom_config import FuseCustomConfig +from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler +from .match_utils import _is_match, MatchAllNode +from .pattern_utils import _sorted_patterns_dict + + +__all__ = [ + "fuse", + # TODO: We should make this private in the future + # This is currently needed for test_public_bindings for some reason + "FuseHandler", +] + + +def fuse( + model: GraphModule, + is_qat: bool, + fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. 
Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + named_modules = dict(model.named_modules()) + + if backend_config is None: + backend_config = get_native_backend_config() + + fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict( + _get_fusion_pattern_to_fuse_handler_cls(backend_config) + ) + fuser_method_mapping = get_fuser_method_mapping(backend_config) + fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter( + backend_config + ) + fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter( + backend_config + ) + + # find fusion + fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls) + # TODO: change this to inplace changes to graph, since we no longer construct + # new GraphModule anymore + fused_graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + def default_root_node_getter(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + + for node in model.graph.nodes: + ( + maybe_last_node, + pattern, + matched_node_pattern, + obj, + node_to_subpattern, + ) = fusion_pairs.get(node.name, (None, None, None, None, None)) + # get the corresponding subpattern for the current node + if node_to_subpattern is not None: + node_subpattern = node_to_subpattern.get(node, None) + else: + node_subpattern = None + if maybe_last_node is node: + assert obj is not None + root_node_getter = fusion_pattern_to_root_node_getter.get( + pattern, default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) # type: ignore[index] + extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get( + pattern, None + ) + extra_inputs = [] + if extra_inputs_getter is not None: + extra_inputs = extra_inputs_getter(matched_node_pattern) + # TODO: add validation that root_node is a module and has the same type + # as the root_module in the configuration + env[node.name] = obj.fuse( + load_arg, + named_modules, + fused_graph, + root_node, + extra_inputs, + matched_node_pattern, # type: ignore[arg-type] + fuse_custom_config, + fuser_method_mapping, + is_qat, + ) + elif maybe_last_node is None or node_subpattern is MatchAllNode: + env[node.name] = fused_graph.node_copy(node, load_arg) + # node matched in patterns and is not root is removed here + + model = GraphModule(model, fused_graph) + return model + + +def _find_matches( + root: GraphModule, + graph: Graph, + pattern_to_fuse_handler_cls: dict[Pattern, Callable], +) -> dict[str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]]: + modules = dict(root.named_modules()) + # node name -> (root_node, match_value) + match_map: dict[ + str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]] + ] = {} + # a map from node to the matched subpattern + node_to_subpattern: dict[Node, Any] = {} + + # TODO: dedup with quantization matching function in match_utils.py + def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern): + if isinstance(pattern, tuple): + s, *args = pattern + current_node_pattern: list[Node] = [] + apply_match(s, node, match, current_node_pattern, node_to_subpattern) + for subpattern, arg in zip(args, node.args): + apply_match( + subpattern, arg, match, current_node_pattern, node_to_subpattern + ) + matched_node_pattern.append(tuple(current_node_pattern)) + else: + # the first pattern matches will 
take precedence + if node.name not in match_map: + matched_node_pattern.append(node) + # MatchAllNode here is actually MatchAllInputNode which should not + # be added to match_map + if pattern is not MatchAllNode: + node_to_subpattern[node] = pattern + root_node, pattern, handler = match + match_map[node.name] = ( + root_node, + pattern, + matched_node_pattern, + handler, + node_to_subpattern, + ) + + for node in reversed(graph.nodes): + if node.name not in match_map: + for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items(): + matched_node_pattern: list[Node] = [] + if _is_match(modules, node, pattern): + apply_match( + pattern, + node, + (node, pattern, fuse_handler_cls(node)), + matched_node_pattern, + node_to_subpattern, + ) + break + + return match_map diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/fuse_handler.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/fuse_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..3c890d8945fba271b88e6829331ebecdb387241d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/fuse_handler.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from typing import Any, Callable, Union + +import torch +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new +from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern +from torch.fx.graph import Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +from .custom_config import FuseCustomConfig +from .match_utils import MatchAllNode + + +__all__ = [ + "DefaultFuseHandler", + "FuseHandler", +] + + +# ---------------------------- +# Fusion Pattern Registrations +# ---------------------------- + + +# Base Pattern Handler +class FuseHandler(ABC): + """Base handler class for the fusion patterns""" + + @abstractmethod + def __init__(self, node: Node): + pass + + @abstractmethod + def fuse( + self, + load_arg: Callable, + named_modules: dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: list[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]], + is_qat: bool, + ) -> Node: + pass + + +class DefaultFuseHandler(FuseHandler): + def __init__(self, node: Node): + super().__init__(node) # type:ignore[safe-super] + + def fuse( + self, + load_arg: Callable, + named_modules: dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: list[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]], + is_qat: bool, + ) -> Node: + assert root_node.op == "call_module", ( + "Expecting module node to be a call_module Node" + ) + root_module = named_modules[str(root_node.target)] + + def get_modules(pattern): + """Given a node pattern, extract the corresponding modules + e.g. 
input: (relu_node, (bn_node, conv_node)) + output: (relu_module, (bn_module, conv_module)) + """ + if isinstance(pattern, (tuple, list)): + n, *args = pattern + modules: list[torch.nn.Module] = [] + modules.append(get_modules(n)) + modules.extend(get_modules(a) for a in args) + return tuple(modules) + else: + n = pattern + if n.op == "call_module": + return named_modules[n.target] + elif n.op == "call_function" and n.target == torch.nn.functional.relu: + relu = torch.nn.ReLU() + relu.training = root_module.training + return relu + elif n.op == "call_function" or n.op == "call_method": + return n.target + else: + return MatchAllNode + + # since relu can be used multiple times, we'll need to create a relu module for each match + matched_modules = get_modules(matched_node_pattern) + + def get_matched_types(m): + if isinstance(m, tuple): + return tuple(map(get_matched_types, m)) + if isinstance(m, torch.nn.Module): + return type_before_parametrizations(m) + return m + + matched_module_types = get_matched_types(matched_modules) + module_parent_name, module_name = _parent_name(root_node.target) + fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping) + # TODO: change the signature for fuser_method to take matched module patterns + # as input + fused_module = fuser_method(is_qat, *matched_modules) + setattr(named_modules[module_parent_name], module_name, fused_module) + extra_args = [load_arg(input) for input in extra_inputs] + node = fused_graph.node_copy(root_node, load_arg) + args = list(node.args) + args.extend(extra_args) + node.args = tuple(args) + return node + + +def _get_fusion_pattern_to_fuse_handler_cls( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + fusion_pattern_to_fuse_handlers: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # TODO: is this logic right? + fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler + return fusion_pattern_to_fuse_handlers diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/graph_module.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..91eecaa797cb133b997693649444ff383d082946 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/graph_module.py @@ -0,0 +1,205 @@ +# mypy: allow-untyped-defs +import copy +from typing import Any, Union + +import torch +from torch.fx import GraphModule +from torch.fx.graph import Graph + + +__all__ = [ + "FusedGraphModule", + "ObservedGraphModule", + "ObservedStandaloneGraphModule", + "QuantizedGraphModule", +] + + +class FusedGraphModule(GraphModule): + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. 
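+    # Editorial note (illustrative): thanks to this override,
+    #     copied = copy.deepcopy(fused_gm)
+    # also carries over `preserved_attr_names` and the preserved attributes
+    # themselves, which a plain GraphModule deepcopy would otherwise drop.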
+ def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return FusedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +class ObservedGraphModule(GraphModule): + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = { + "_activation_post_process_map", + "_activation_post_process_indexes", + "_patterns", + "_node_name_to_qconfig", + "_prepare_custom_config", + "_equalization_node_name_to_qconfig", + "_node_name_to_scope", + "_qconfig_mapping", + "_is_qat", + "_observed_node_names", + }.union(preserved_attr_names) + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +def _is_observed_module(module: Any) -> bool: + return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta + + +def _get_observed_graph_module_attr( + model: Union[torch.nn.Module, GraphModule], attr_name: str +) -> Any: + if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta: # type: ignore[operator, index] + return getattr(model.meta["_observed_graph_module_attrs"], attr_name) # type: ignore[index] + return None + + +class ObservedStandaloneGraphModule(ObservedGraphModule): + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: Graph, + preserved_attr_names: set[str], + ): + preserved_attr_names = preserved_attr_names.union( + { + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs", + } + ) + super().__init__(root, graph, preserved_attr_names) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedStandaloneGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +def _is_observed_standalone_module(module: Any) -> bool: + return ( + _is_observed_module(module) + and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module + ) + + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and isinstance( + getattr(self, attr_name), torch._C.ScriptObject + ): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + + +class QuantizedGraphModule(GraphModule): + """This class is created to make sure PackedParams + (e.g. 
LinearPackedParams, Conv2dPackedParams) to appear in state_dict + so that we can serialize and deserialize quantized graph module with + torch.save(m.state_dict()) and m.load_state_dict(state_dict) + """ + + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + self._register_state_dict_hook(_save_packed_weight) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return QuantizedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..77832f0e7292855679990d8c11d2fa9ebe7454a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py @@ -0,0 +1,21 @@ +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import GraphModule + +from ._lower_to_native_backend import _lower_to_native_backend + + +__all__ = ["lower_to_fbgemm"] + + +def lower_to_fbgemm( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to fbgemm + """ + return _lower_to_native_backend( + model, qconfig_map, node_name_to_scope, keep_original_weights + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..4d19ecdc1f232d5d34319198aaa325420f476cfb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py @@ -0,0 +1,18 @@ +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import GraphModule + +from ._lower_to_native_backend import _lower_to_native_backend + + +__all__ = ["lower_to_qnnpack"] + + +def lower_to_qnnpack( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to qnnpack + """ + return _lower_to_native_backend(model, qconfig_map, node_name_to_scope) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/lstm_utils.py 
b/phivenv/Lib/site-packages/torch/ao/quantization/fx/lstm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2871814a8526d44afe6d1ff63990c1e79a922089 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/lstm_utils.py @@ -0,0 +1,221 @@ +import copy +import operator +from typing import Any, Callable, Optional + +import torch +from torch.ao.quantization import ( + default_weight_fake_quant, + default_weight_observer, + FakeQuantizeBase, + QConfig, + QConfigMapping, +) +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.observer import _PartialWrapper +from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx + + +# TODO: move all LSTM util functions from fx/utils.py to this file +def _get_lstm_with_individually_observed_parts( + float_lstm: torch.nn.LSTM, + example_inputs: tuple[Any, ...], + backend_config: Optional[BackendConfig] = None, + linear_output_obs_ctr: Optional[_PartialWrapper] = None, + sigmoid_obs_ctr: Optional[_PartialWrapper] = None, + tanh_obs_ctr: Optional[_PartialWrapper] = None, + cell_state_obs_ctr: Optional[_PartialWrapper] = None, + hidden_state_obs_ctr: Optional[_PartialWrapper] = None, + split_gates: bool = False, +) -> torch.ao.nn.quantizable.LSTM: + """ + Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` + with specific observers or fake quantizes assigned to the inner ops or submodules. + + In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is + used as an observed custom module, which is responsible for inserting its own + observers. By default, all inner ops inherit the parent custom module's QConfig. + Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM` + and use this helper function to customize the observer insertion logic. + + This is meant to be used to convert a float module to an observed module in the + custom module flow. + + Args: + `float_lstm`: The float LSTM module + `example_inputs`: example inputs for the forward function of the LSTM module + `backend_config`: BackendConfig to use to observe the LSTM module + `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b, + where W is the weight matrix, b is the bias, and x is either the inputs + or the hidden state from the previous layer (if any) + `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations + `tanh_obs_ctr`: observer or fake quantize for tanh activations + `cell_state_obs_ctr`: observer or fake quantize for the cell state + `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and + the output + + Return: + A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes + assigned to the inner ops. + """ + + def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: + """ + Make a QConfig with fixed qparams observers or fake quantizes. 
+ """ + if isinstance(obs_ctr(), FakeQuantizeBase): + weight = default_weight_fake_quant + else: + weight = default_weight_observer + return QConfig(activation=obs_ctr, weight=weight) + + quantizable_lstm = torch.ao.nn.quantizable.LSTM( + float_lstm.input_size, + float_lstm.hidden_size, + float_lstm.num_layers, + float_lstm.bias, + float_lstm.batch_first, + float_lstm.dropout, + float_lstm.bidirectional, + split_gates=split_gates, + ) + quantizable_lstm.qconfig = float_lstm.qconfig + + for idx in range(float_lstm.num_layers): + quantizable_lstm.layers[idx] = ( + torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float( + float_lstm, + idx, + float_lstm.qconfig, + batch_first=False, + split_gates=split_gates, + ) + ) + + # Build QConfigMapping for the LSTM cell + # Note: FloatFunctional qconfigs will be configured separately below + cell_qm = QConfigMapping().set_global(float_lstm.qconfig) # type: ignore[arg-type] + if sigmoid_obs_ctr is not None: + cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr)) + if tanh_obs_ctr is not None: + cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr)) + + # Insert observers into each LSTM cell + # TODO: maybe make this work for layer_bw as well + for layer in quantizable_lstm.layers: + cell = layer.layer_fw.cell # type: ignore[union-attr] + assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module" + cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config) + # HACK: Manually replace the activation_post_process following these ops. + # This is needed for FloatFunctional ops because there is currently no way + # to configure these ops in FX graph mode quantization today. This is because + # the FloatFunctional modules simply disappear from the graph after tracing. + # In the future, we should rewrite quantizable LSTM without FloatFunctionals. + if not split_gates: + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + else: + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add (input) + (torch.add, 1): linear_output_obs_ctr, # gates.add (forget) + (torch.add, 2): linear_output_obs_ctr, # gates.add (cell) + (torch.add, 3): linear_output_obs_ctr, # gates.add (output) + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 4): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + add_count = 0 + mul_count = 0 + for node in cell.graph.nodes: + op_index: Optional[tuple[Callable, int]] = None # e.g. 
(torch.add, 1) + if node.target == torch.add: + op_index = (torch.add, add_count) + add_count += 1 + elif node.target == torch.mul: + op_index = (torch.mul, mul_count) + mul_count += 1 + else: + # Neither torch.add nor torch.mul + continue + if op_index not in op_index_to_activation_post_process_ctr: + continue + assert len(node.users) == 1 + activation_post_process_name = next(iter(node.users.keys())).name + activation_post_process_ctr = op_index_to_activation_post_process_ctr[ + op_index + ] + if activation_post_process_ctr is not None: + setattr( + cell, activation_post_process_name, activation_post_process_ctr() + ) + layer.layer_fw.cell = cell # type: ignore[union-attr] + return quantizable_lstm + + +def _get_reference_quantized_lstm_module( + observed_lstm: torch.ao.nn.quantizable.LSTM, + backend_config: Optional[BackendConfig] = None, +) -> torch.ao.nn.quantized.LSTM: + """ + Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM` + with observers or fake quantizes inserted through `prepare_fx`, e.g. from + `_get_lstm_with_individually_observed_parts`. + + This is meant to be used to convert an observed module to a quantized module in the + custom module flow. + + Args: + `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx` + `backend_config`: BackendConfig to use to produce the reference quantized model + + Return: + A reference `torch.ao.nn.quantized.LSTM` module. + """ + quantized_lstm = torch.ao.nn.quantized.LSTM( + observed_lstm.input_size, + observed_lstm.hidden_size, + observed_lstm.num_layers, + observed_lstm.bias, + observed_lstm.batch_first, + observed_lstm.dropout, + observed_lstm.bidirectional, + ) + + for i, layer in enumerate(quantized_lstm.layers): + cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr] + cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type] + assert isinstance(cell, torch.fx.GraphModule) + # HACK: Manually remove input quantize nodes and output dequantize nodes, + # since custom modules expect quint8 inputs and outputs for now. Note that + # this functionality is supposedly handled through PrepareCustomConfig's + # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that + # API doesn't currently handle tuple inputs and outputs, so we have to do + # this manually for now. In the future we should (1) relax the restriction + # on custom module input/output dtypes, and (2) expand support for complex + # input/output structures. 
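+        # Illustrative sketch of the rewrite below (node names are whatever
+        # prepare_fx / convert_to_reference_fx produced, nothing new is assumed):
+        #     quantize_per_tensor(x, ...)    ->  x
+        #     output((dequantize(a), ...))   ->  output((a, ...))
+        # i.e. input quantize nodes are bypassed and output dequantize nodes
+        # are unhooked, leaving the cell with quint8 inputs and outputs.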
+ for node in cell.graph.nodes: + if node.target == torch.quantize_per_tensor: + arg = node.args[0] + # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1]) + if arg.target == "x" or ( + arg.target == operator.getitem and arg.args[0].target == "hidden" + ): + with cell.graph.inserting_before(node): + node.replace_all_uses_with(arg) + cell.graph.erase_node(node) + if node.target == "output": + # Remove all dequantize nodes in the output tuple + for arg in node.args[0]: + with cell.graph.inserting_before(node): + node.replace_input_with(arg, arg.args[0]) + cell.graph.eliminate_dead_code() + cell.recompile() + layer.layer_fw.cell = cell # type: ignore[union-attr] + return quantized_lstm diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/match_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..86d112e53e9ba378ee4dd050efd0b427081308ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/match_utils.py @@ -0,0 +1,228 @@ +# mypy: allow-untyped-defs +import sys +from collections.abc import Iterable +from typing import Any, Callable, Optional + +import torch +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.utils import MatchAllNode, Pattern +from torch.fx.graph import Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +from .graph_module import _is_observed_standalone_module +from .quantize_handler import QuantizeHandler + + +__all__: list[str] = [] + +# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type +# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]` +_MatchResult = tuple[Node, list[Node], Optional[Pattern], QuantizeHandler] + +_MatchResultWithQConfig = tuple[ + Node, list[Node], Optional[Pattern], QuantizeHandler, QConfigAny +] + + +# Note: The order of patterns is important! match function will take whatever is matched first, so we'll +# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu. +# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns, +# we'll start from the last node of the graph and traverse back. 
+def _is_match(modules, node, pattern, max_uses=sys.maxsize): + """Matches a node in fx against a pattern""" + if isinstance(pattern, tuple): + self_match, *arg_matches = pattern + if self_match is getattr: + assert len(pattern) == 2, "Expecting getattr pattern to have two elements" + arg_matches = [] + else: + self_match = pattern + arg_matches = [] + + if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): + return True + + if node == pattern: + return True + + if not isinstance(node, Node) or len(node.users) > max_uses: + return False + + if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): + if node.op != "call_module": + return False + if not type_before_parametrizations(modules[node.target]) == self_match: + return False + elif callable(self_match): + if node.op != "call_function" or node.target is not self_match: + return False + elif node.target is getattr: + if node.args[1] != pattern[1]: + return False + elif isinstance(self_match, str): + if node.op != "call_method" or node.target != self_match: + return False + elif node.target != self_match: + return False + + if not arg_matches: + return True + + if len(arg_matches) != len(node.args): + return False + + return all( + _is_match(modules, node, arg_match, max_uses=1) + for node, arg_match in zip(node.args, arg_matches) + ) + + +def _find_matches( + graph: Graph, + modules: dict[str, torch.nn.Module], + patterns: dict[Pattern, QuantizeHandler], + root_node_getter_mapping: dict[Pattern, Callable], + standalone_module_names: Optional[list[str]] = None, + standalone_module_classes: Optional[list[type]] = None, + custom_module_classes: Optional[list[Any]] = None, +) -> dict[str, _MatchResult]: + """ + Matches the nodes in the input graph to quantization patterns, and + outputs the information needed to quantize them in future steps. + + Inputs: + - graph: an fx.Graph object + - modules: a mapping of fully qualified module name to instance, + for example, {'foo': ModuleFoo, ...} + - patterns: a mapping from a tuple of nodes in reverse order to + uninitialized QuantizeHandler subclass. + + Outputs a map of + node_name -> + (node, matched_values, matched_pattern, QuantizeHandler instance, + qconfig) + + For example, { + 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu, + , QConfig(...)), + ... + } + """ + if custom_module_classes is None: + custom_module_classes = [] + + if standalone_module_classes is None: + standalone_module_classes = [] + + if standalone_module_names is None: + standalone_module_names = [] + + match_map: dict[str, _MatchResult] = {} + all_matched: set[str] = set() + + def _recursive_record_node_in_match_map( + last_node, match_map, node_pattern, matched_node_pattern, pattern, match_value + ): + if isinstance(node_pattern, Node): + match_map[node_pattern.name] = ( + last_node, + matched_node_pattern, + pattern, + match_value, + ) + elif not isinstance(node_pattern, Iterable): + return + else: + for n in node_pattern: + _recursive_record_node_in_match_map( + last_node, match_map, n, matched_node_pattern, pattern, match_value + ) + + # TODO: 1. merge with fuse matcher 2. 
document the code + def record_match(pattern, node, last_node, matched_node_pattern, match_map): + if isinstance(pattern, tuple): + s, *args = pattern + is_single_arg = len(args) == 1 + current_node_pattern: list[Node] = [] + record_match(s, node, last_node, matched_node_pattern, match_map) + if pattern[0] is not getattr: + for subpattern, arg in zip(args, node.args): + record_match(subpattern, arg, node, current_node_pattern, match_map) + if len(current_node_pattern) > 1: + # current_node_pattern is the node pattern we get from matching + # the subpattern with arguments of the node + # we use is_single_arg to recover the original structure of the pattern + # if the original pattern has a single argument, we will have + # (original_op, (original_arg, ...)) + # otherwise, we'll have a list of arguments + # (original_op, arg0, arg1, arg2, ...) + if is_single_arg: + matched_node_pattern.append(tuple(current_node_pattern)) + else: + matched_node_pattern.extend(list(current_node_pattern)) + else: + matched_node_pattern.append(current_node_pattern[0]) + else: + matched_node_pattern.append(node) + + for node in reversed(graph.nodes): + if node.name not in match_map and node.name not in all_matched: + for pattern, quantize_handler_cls in patterns.items(): + root_node_getter = root_node_getter_mapping.get(pattern, None) + if _is_match(modules, node, pattern) and node.name not in match_map: + matched_node_pattern: list[Node] = [] + record_match(pattern, node, node, matched_node_pattern, match_map) + quantize_handler = quantize_handler_cls( # type: ignore[operator] + matched_node_pattern, modules, root_node_getter + ) + last_node = node + # record the match for all nodes in the pattern + _recursive_record_node_in_match_map( + last_node, + match_map, + # we need to record all nodes in the matched pattern in the match_map + matched_node_pattern, + # this is a part of the value corresponding to the node + matched_node_pattern, + pattern, + quantize_handler, + ) + break + + # add custom module instances to the match result + assert modules is not None + for node in graph.nodes: + if ( + node.op == "call_module" + and type(modules[node.target]) in custom_module_classes + ): + match_map[node.name] = ( + node, + node, + None, + QuantizeHandler(node, modules, is_custom_module=True), + ) + + def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]): + assert modules is not None + return ( + node_target in standalone_module_names + or type(modules[node_target]) # type: ignore[operator] + in standalone_module_classes # type: ignore[operator] + ) + + # add standalone modules to the match + for node in graph.nodes: + if node.op == "call_module" and ( + is_standalone_module(node.target, modules) + or _is_observed_standalone_module(modules[node.target]) + ): + # add node to matched nodes + match_map[node.name] = ( + node, + node, + None, + QuantizeHandler(node, modules, is_standalone_module=True), + ) + + return match_map diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/pattern_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6377dcf445dbce504b9aef9c66847e37e06f49ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/pattern_utils.py @@ -0,0 +1,112 @@ +# mypy: allow-untyped-defs +import copy +from collections import OrderedDict +from typing import Any + +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.observer 
import ObserverBase +from torch.ao.quantization.utils import Pattern + + +__all__ = [ + "get_default_fusion_patterns", + "get_default_quant_patterns", + "get_default_output_activation_post_process_map", +] + +# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency) +QuantizeHandler = Any + +# pattern for conv bn fusion +_DEFAULT_FUSION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict() + + +def _register_fusion_pattern(pattern): + def insert(fn): + _DEFAULT_FUSION_PATTERNS[pattern] = fn + return fn + + return insert + + +def get_default_fusion_patterns() -> dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_FUSION_PATTERNS) + + +_DEFAULT_QUANTIZATION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict() + +# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation +# e.g. pattern: torch.sigmoid, +# output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant +_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: dict[Pattern, QuantizeHandler] = {} +_DEFAULT_OUTPUT_OBSERVER_MAP: dict[Pattern, QuantizeHandler] = {} + + +# Register pattern for both static quantization and qat +def _register_quant_pattern(pattern, fixed_qparams_observer=None): + def insert(fn): + _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn + if fixed_qparams_observer is not None: + _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = ( + FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer) + ) + _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer + return fn + + return insert + + +# Get patterns for both static quantization and qat +def get_default_quant_patterns() -> dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS) + + +# a map from pattern to output activation post process constructor +# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant +def get_default_output_activation_post_process_map( + is_training, +) -> dict[Pattern, ObserverBase]: + if is_training: + return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP) + else: + return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP) + + +# Example use of register pattern function: +# @_register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +# class ConvOrLinearBNReLUFusion(): +# def __init__(...): +# ... +# + + +def _sorted_patterns_dict( + patterns_dict: dict[Pattern, QuantizeHandler], +) -> dict[Pattern, QuantizeHandler]: + """ + Return a sorted version of the patterns dictionary such that longer patterns are matched first, + e.g. match (F.relu, F.linear) before F.relu. + This works for current use cases, but we may need to have a more clever way to sort + things to address more complex patterns + """ + + def get_len(pattern): + """this will calculate the length of the pattern by counting all the entries + in the pattern. 
+ this will make sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before + (nn.BatchNorm, nn.Conv2d) so that we can match the former first + """ + len = 0 + if isinstance(pattern, tuple): + for item in pattern: + len += get_len(item) + else: + len += 1 + return len + + return OrderedDict( + sorted( + patterns_dict.items(), + key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1, + ) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/prepare.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc14dbc2becfab252eb8e916a6a63906e5f084f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/prepare.py @@ -0,0 +1,2186 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from dataclasses import asdict +from typing import Any, Optional, Union + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import ( + _DerivedObserverOrFakeQuantize, + FixedQParamsFakeQuantize, + FixedQParamsObserver, + ObserverBase, + ObserverOrFakeQuantize, + PlaceholderObserver, +) +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fusion_pattern_to_root_node_getter, + get_module_to_qat_module, + get_pattern_to_dtype_configs, +) +from torch.ao.quantization.observer import _is_activation_post_process, _PartialWrapper +from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize import convert, propagate_qconfig_ +from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.ao.quantization.utils import ( + _parent_name, + get_qconfig_dtypes, + get_swapped_custom_module_class, + NodePattern, + Pattern, +) +from torch.fx import GraphModule +from torch.fx.graph import Graph, Node +from torch.fx.node import Argument + +from ._equalize import is_equalization_observer, node_supports_equalization +from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry +from .match_utils import _find_matches, _MatchResultWithQConfig +from .pattern_utils import _sorted_patterns_dict +from .qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, + _get_flattened_qconfig_dict, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from .quantize_handler import ( + _default_root_node_getter, + _get_pattern_to_quantize_handlers, + QuantizeHandler, +) +from .utils import ( + _insert_dequant_stubs_for_custom_module_lstm_output, + _is_custom_module_lstm, + _maybe_get_custom_module_lstm_from_node_arg, + _qconfig_satisfies_dtype_config_constraints, + all_node_args_have_no_tensors, + assert_and_get_unique_device, + get_custom_module_class_keys, + get_new_attr_name_with_prefix, + get_non_observable_arg_indexes_and_types, + node_arg_is_bias, + node_arg_is_weight, + NON_QUANTIZABLE_WEIGHT_OPS, + ObservedGraphModuleAttrs, +) + + +__all__ = [ + "insert_observers_for_model", + "prepare", + "propagate_dtypes_for_known_nodes", +] + + +# list of dtypes to not add observers to +_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None] +_OBS_DTYPE_LIST = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + 
torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float) + +# note: the following default target dtype info dicts are temporary, +# should be moved to the new programmable API class soon +_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, +} + +_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, +} + + +def _get_observer_kwargs( + quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec], +): + kwargs_dict = asdict(quant_spec) + return copy.deepcopy(kwargs_dict) + + +def _get_qspec_for_arg( + arg: Node, + input_qspec_map: dict[Node, QuantizationSpecBase], + named_modules: dict[str, torch.nn.Module], +) -> Optional[QuantizationSpecBase]: + while _is_activation_post_process_node(arg, named_modules): + arg = arg.args[0] # type: ignore[assignment] + return input_qspec_map.get(arg, None) + + +def _create_obs_or_fq_from_qspec( + quantization_spec: Optional[QuantizationSpecBase], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + """Create observer or fake quantize objects based on quantization spec + + Args: + quantization_spec: used to store parameters to create the observer or fake quantizer + obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant + instance, it may be reused for different edge/output depending on configuration + """ + if quantization_spec is None: + return None + if isinstance(quantization_spec, SharedQuantizationSpec): + edge_or_node = quantization_spec.edge_or_node + assert edge_or_node in obs_or_fq_map, ( + "please make sure only refer to edge or node that has " + f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" + ) + return obs_or_fq_map[edge_or_node] + elif isinstance(quantization_spec, DerivedQuantizationSpec): + # can't use asdict, so not calling get_observer_kwargs here + kwargs = { + "dtype": quantization_spec.dtype, + "derive_qparams_fn": quantization_spec.derive_qparams_fn, + "quant_min": quantization_spec.quant_min, + "quant_max": quantization_spec.quant_max, + "qscheme": quantization_spec.qscheme, + "ch_axis": quantization_spec.ch_axis, + } + edge_or_nodes = quantization_spec.derived_from + obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] + kwargs["obs_or_fqs"] = obs_or_fqs + return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() + elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): + kwargs = _get_observer_kwargs(quantization_spec) + observer_ctr = FixedQParamsObserver.with_args(**kwargs) + if is_qat: + return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)() + else: + return observer_ctr() + + assert isinstance(quantization_spec, QuantizationSpec) + observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr + kwargs = _get_observer_kwargs(quantization_spec) + kwargs.pop("observer_or_fake_quant_ctr") + # we will remove is_dynamic from QuantizationSpec because + # it seems that dynamic range quantization + obs_or_fq_class = observer_or_fake_quant_ctr + if isinstance(observer_or_fake_quant_ctr, _PartialWrapper): + obs_or_fq_class = 
observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment] + if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr] + kwargs.pop("ch_axis") + return observer_or_fake_quant_ctr.with_args(**kwargs)() + + +def _needs_obs_or_fq( + prev_output_dtype: Any, + prev_output_is_dynamic: bool, + cur_target_dtype: Any, + cur_target_is_dynamic: bool, + reuse_input_obs_or_fq: bool, + is_zeroth_arg: bool = False, +) -> bool: + """ + note: we will treat "not specified" as torch.float for now + utility function that checks if we should insert an observer or fake quant node + base on the requested dtype for the nodes from user + + is_zeroth_arg: we only dynamically quantize the first arg of the node right now + this should be removed when we enable configuring dynamic quantization + for a specific argument, this can be removed if we deprecate fx graph mode + quantization + + """ + + # need to insert placeholder observer for dynamic quantization so that it can + # be converted to choose_qparams -> q -> dq in convert step + if cur_target_is_dynamic: + assert cur_target_dtype in _OBS_DTYPE_LIST, ( + f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + ) + assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST + return is_zeroth_arg + if reuse_input_obs_or_fq: + return False + # non dynamic quantization + if cur_target_dtype in _OBS_DTYPE_LIST: + return ( + prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] + and cur_target_dtype != prev_output_dtype + ) + + # lots of error checking are skipped here for now + return False + + +def _is_activation_post_process_node( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == "call_module" + and _is_activation_post_process(named_modules[str(node.target)]) + ) + + +def _get_dtype_and_is_dynamic( + obs_or_fq: Optional[ObserverOrFakeQuantize], +) -> tuple[Optional[torch.dtype], bool]: + """Given a constructor for observer or fake quant module, returns + a Tuple of dtype and is_dynamic + """ + # TODO: instead of instantiating the instance, we can use inspect to get the default args + if obs_or_fq is None: + return None, False + else: + return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value] + + +def _is_input_arg_dtype_supported_by_backend( + arg: Argument, + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, + backend_config: BackendConfig, +) -> bool: + """Check if the configured qconfig for the argument + is supported by the backend or not + """ + if isinstance(arg, (list, tuple)): + return all( + _is_input_arg_dtype_supported_by_backend( + a, node, qconfig, dtype_config, backend_config + ) + for a in arg + ) + if not isinstance(arg, Node): + return True + # TODO: support check for standalone module + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + if is_activation: + input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr" + ) + input_act_obs_or_fq = ( + input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None + ) + qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic( + input_act_obs_or_fq + ) + # TODO(future PR): remove the cast to bool below after figuring + # out why backend_config has is_dynamic set to None in some cases. 
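+        # Plain-language reading of the check below: the argument is supported
+        # if the backend leaves input_dtype unconstrained, or if the qconfig's
+        # activation dtype and is_dynamic flag both match this DTypeConfig and
+        # the additional dtype constraints attached to the backend config hold.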
+ return (dtype_config.input_dtype is None) or ( + dtype_config.input_dtype == qconfig_dtype + and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) + and _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.input_dtype_with_constraints + ) + ) + elif is_weight: + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "weight_obs_or_fq_ctr", None + ) + weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None + qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq) + backend_config_weight_dtype = dtype_config.weight_dtype + dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False + ) + return backend_config_weight_dtype is None or ( + dtype_matches and qconfig_satisfies_constraints + ) + else: # bias + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "bias_obs_or_fq_ctr", None + ) + bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None + qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq) + backend_config_bias_dtype = dtype_config.bias_dtype + return ( + backend_config_bias_dtype is None + or qconfig_bias_dtype == backend_config_bias_dtype + ) + + +def _is_output_dtype_supported_by_backend( + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, +) -> bool: + """Check if the configured qconfig for the output + is supported by the backend or not + """ + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + backend_config_output_dtype = dtype_config.output_dtype + # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend + # from input activation check can be reused here + qconfig_output_dtype = None + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic( + output_act_obs_or_fq + ) + # TODO: this is a hack because we can only specify one activation_obs_or_fq for + # qconfig (qconfig.activation), and we are only supporting dynamically quantized + # linear op which has fp32 output dtype, this should be removed if we generalize + # the structure of qconfig in the future + if qconfig_output_is_dynamic: + qconfig_output_dtype = torch.float32 + dtype_matches = qconfig_output_dtype == backend_config_output_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.output_dtype_with_constraints + ) + return backend_config_output_dtype is None or ( + dtype_matches and qconfig_satisfies_constraints + ) + + +def _is_observer_in_same_graph( + node: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat, +): + """Check if observer in same graph + when the node output is not fp32 and input is 'placeholder' + the input is assumed to be quantized, so it is observed + in a different place rather than not observed. 
+ """ + node_output_dtype = _get_arg_target_dtype_as_output( + node, named_modules, obs_or_fq_map, is_qat + ) + if len(node.args) > 0 and isinstance(node.args[0], Node): + if ( + node_output_dtype in [torch.quint8, torch.uint8] + and node.args[0].op == "placeholder" + ): + return False + return True + + +def _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern: Optional[Pattern], + matched_node_pattern: Optional[list[Node]], + qconfig: QConfigAny, + backend_config: BackendConfig, +) -> bool: + """Check if the dtype configuration of a pattern is supported by + the backend or not, and whether the qconfig satisfies constraints + specified in the corresponding dtype config. + """ + if backend_config is None or pattern is None: + return True + assert matched_node_pattern is not None and len(matched_node_pattern) >= 1 + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + + root_node_getter = pattern_to_root_node_getter.get( + pattern, _default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) + input_node = root_node + output_node = matched_node_pattern[0] + for dtype_config in dtype_configs: + # check if arg dtype are supported + supported = True + for arg in list(input_node.args) + list(input_node.kwargs.values()): + supported = supported and _is_input_arg_dtype_supported_by_backend( + arg, input_node, qconfig, dtype_config, backend_config + ) + # check if output dtype is supported + supported = supported and _is_output_dtype_supported_by_backend( + output_node, qconfig, dtype_config + ) + if supported: + return True + return False + + +def _get_standalone_module_configs( + node: Node, + named_modules: dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, + parent_qconfig: QConfigAny, + parent_backend_config: Optional[BackendConfig], +) -> tuple[ + QConfigMapping, tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig] +]: + """ + Returns the standalone module QConfigMapping and PrepareCustomConfig + for `node`, assuming that the module pointed to by `node` is + a standalone modules. 
+ """ + module_name = str(node.target) + module_type = type(named_modules[module_name]) # type: ignore[index] + # name config has precedence over type config + config_entry = StandaloneModuleConfigEntry(None, (), None, None) + config_entry = prepare_custom_config.standalone_module_classes.get( + module_type, config_entry + ) + config_entry = prepare_custom_config.standalone_module_names.get( + module_name, config_entry + ) + # fallback to use parent module's qconfig if user didn't specify qconfig dict + qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global( + parent_qconfig + ) + example_inputs = config_entry.example_inputs + prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig() + backend_config = config_entry.backend_config or parent_backend_config + return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + + +def _qat_swap_modules( + root: torch.nn.Module, module_to_qat_module: dict[Pattern, type[torch.nn.Module]] +) -> None: + convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False) + + +def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]): + if isinstance(matched_node_pattern, Node): + s.add(matched_node_pattern.name) + elif isinstance(matched_node_pattern, (list, tuple)): + for maybe_node in matched_node_pattern: + _add_matched_node_name_to_set(maybe_node, s) + + +def _insert_obs_or_fq( + node: Node, + obs_or_fq: ObserverOrFakeQuantize, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Attaches `obs_or_fq` to `model`, and creates a node which calls + `obs_or_fq` on the output of `node`. + + obs_or_fq: an instance of Observer or FakeQuantize module + """ + model_device = assert_and_get_unique_device(model) + if model_device: + obs_or_fq.to(model_device) + # add obs_or_fq module as attribute + if is_equalization_observer(obs_or_fq): + prefix = node.name + "_equalization_process_" + else: + prefix = "activation_post_process_" + get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix) + obs_or_fq_name = get_new_obs_or_fq_name(model) + setattr(model, obs_or_fq_name, obs_or_fq) + named_modules[obs_or_fq_name] = obs_or_fq + with graph.inserting_after(node): + new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {}) + return new_obs + + +def _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern: NodePattern, + last_node: Node, + qconfig: QConfigAny, + qhandler: Optional[QuantizeHandler], + backend_config: BackendConfig, + named_modules: dict[str, torch.nn.Module], + cache_for_no_tensor_check: dict[Node, bool], + processed_nodes: set[Node], +) -> None: + """Sets the target_dtype_info for each node in matched_node_pattern + Note: processed_nodes is used to ensure we only process each node once + """ + if isinstance(matched_node_pattern, (list, tuple)): + for node_pattern in matched_node_pattern: + _set_target_dtype_info_for_matched_node_pattern( + node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # set target_dtype_info if matched_node_pattern is a Node + # other types of matched object, e.g. 
int, float literals, are ignored + elif isinstance(matched_node_pattern, Node): + # for pyre + assert isinstance(matched_node_pattern, Node) + node = matched_node_pattern + if node in processed_nodes: + return + processed_nodes.add(node) + + if qconfig is None: + return + # TODO: refactor the following code in terms of apply a qconfig to a pattern + # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1) + # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act, + # and set output_obs_or_fq_ctr based on qconfig.output_act + # this also requires we extend the structure of QConfig to support more fine + # grained configurations + target_dtype_info: dict[str, Any] = _get_target_activation_dtype_for_node( + node, + qconfig, + qhandler, + named_modules, + backend_config, + cache_for_no_tensor_check, + ) + node.meta["target_dtype_info"] = target_dtype_info + + +def _get_target_activation_dtype_for_node( + node: Node, + qconfig: QConfigAny, + qhandler: Optional[QuantizeHandler], + named_modules: dict[str, torch.nn.Module], + backend_config: BackendConfig, + cache_for_no_tensor_check: dict[Node, bool], +) -> dict[str, Any]: + """ + For each op attribute in the op's input activation, output activation, + weight, bias - returns the settings of dtype and is_dynamic we expect + for the `quantize` call in the reference model representation, or None + if there is no `quantize` call needed. + + For example, if we have a node corresponding to `op0` in + + x0 -> op0 -> x1 + + And we want a reference quantized representation to be + + x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1 + + Then this function will return + + { + "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + } + + TODO(future PR, if needed): explicitly spell out the non-Tensor + dtypes. + """ + args_have_no_tensors = all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check + ) + if args_have_no_tensors: + return { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + # get qconfig to determine the eventual dtype of this node + if qconfig is not None: + act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig) + + # Currently `QConfig` only has one `activation` field. + # For static quantization, it is reused for both input + # and output activation. For dynamic quantization, this + # field is currently only used for the input activation, + # with the output activation being in fp32. + # In the future this may change as we add more fields + # to the `QConfig` object. 
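+        # Illustrative summary of the rule below: bias is treated as fp16 only
+        # in the static fp16 case (activation and weight both torch.float16
+        # and the input activation not dynamically quantized); every other
+        # configuration keeps the bias in fp32.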
+ bias_dtype = ( + torch.float16 + if ( + act_dtype == torch.float16 + and weight_dtype == torch.float16 + and (not input_act_is_dynamic) + ) + else torch.float + ) + + is_general_tensor_value_op = ( + qhandler is not None and qhandler.is_general_tensor_value_op() + ) + + _is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + + weight_index = None + if ( + isinstance(node, Node) + and node.op == "call_function" + and node.target in backend_config._pattern_complex_format_to_config + ): + weight_index = backend_config._pattern_complex_format_to_config[ + node.target + ]._input_type_to_index.get("weight") + + bias_index = None + if ( + isinstance(node, Node) + and node.op == "call_function" + and node.target in backend_config._pattern_complex_format_to_config + ): + bias_index = backend_config._pattern_complex_format_to_config[ + node.target + ]._input_type_to_index.get("bias") + + return { + "input_act_obs_or_fq_ctr": qconfig.activation, + "weight_obs_or_fq_ctr": qconfig.weight, + "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype), + "weight_index": weight_index, + "bias_index": bias_index, + "output_act_obs_or_fq_ctr": qconfig.activation, + "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig), + "input_output_share_observers": is_general_tensor_value_op, + "_is_standalone_module": _is_standalone_module, + } + return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO) + + +def _get_output_act_obs_or_fq( + arg: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[ObserverOrFakeQuantize]: + """Get the constructor for observer or fake quant object for + the argument in the original graph as the output of previous node, + skipping inserted observers + + We are assuming that the observers are inserted correctly, and the dtype for + argument in quantized graph will match what is specified by the qconfig + """ + assert isinstance(arg, Node) + if "quantization_annotation" in arg.meta: + return _create_obs_or_fq_from_qspec( + arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + + # Custom module LSTM output is a tuple that we broke down into the internal nodes in order + # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + # Since we modified the graph in this case, we must trace back from the args through + # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would + # not be able to accurately detect whether this node is a consumer of custom module LSTM. 
+ custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg( + arg, named_modules + ) + output_act_obs_or_fq_ctr = None + if custom_module_lstm_node is not None: + output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + elif _is_activation_post_process_node(arg, named_modules): + observed_arg = arg.args[0] + assert isinstance(observed_arg, Node), ( + "Currently we only support observing Node" + ) + if "quantization_annotation" in observed_arg.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + observed_arg.meta["quantization_annotation"].output_qspec, + obs_or_fq_map, + is_qat, + ) + else: + assert "target_dtype_info" in observed_arg.meta + output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + else: + if "target_dtype_info" in arg.meta: + output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + else: + output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + + return output_act_obs_or_fq + + +def _get_arg_target_dtype_as_output( + arg: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[torch.dtype]: + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( + arg, named_modules, obs_or_fq_map, is_qat + ) + arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic( + arg_as_output_act_obs_or_fq + ) + return arg_as_output_target_dtype + + +def _get_arg_as_input_act_obs_or_fq( + arg: Node, + node: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[ObserverOrFakeQuantize]: + """Get the observer or fake quant constructor for the Argument `arg`, as input + to Node `node` + """ + assert isinstance(arg, Node) + # "input_qspec_map" is the more general design we'll use for pt2e path + # it is a map from input argument node to observer or fake quant constructor, for example + # for the following graph: + # x -> conv -> output + # + # we may annotate conv node like the following: + # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...) 
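+    # Illustrative consequence of the (assumed) annotation above: for the edge
+    # x -> conv, the lookup below finds x in input_qspec_map and builds a
+    # MinMaxObserver with dtype=torch.qint8 for it, falling back to the fp32
+    # placeholder observer when no qspec is recorded for that argument.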
+ # + if "quantization_annotation" in node.meta: + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules) + if input_arg_qspec is None: + input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR() + else: + input_arg_obs_or_fq = _create_obs_or_fq_from_qspec( + input_arg_qspec, obs_or_fq_map, is_qat + ) + return input_arg_obs_or_fq + + # we can remove the following path in the future if fx graph mode quantization is + # no longer used + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + obs_or_fq_ctr = None + if is_activation: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + elif is_weight: + if node.target not in NON_QUANTIZABLE_WEIGHT_OPS: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + else: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + return obs_or_fq_ctr() if obs_or_fq_ctr else None + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Union[Node, Any], + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + qhandler: Optional[QuantizeHandler], + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: Optional[BackendConfig] = None, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. + """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + assert isinstance(arg, Node) + # default (no observer) + new_arg = arg + + is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + # TODO: move this to a separate function + if not is_standalone_module: + # Note: qconfig can be None in this branch this we are getting act/fq from + # node.meta now + # regular flow for most nodes, except standalone modules + + if "quantization_annotation" in node.meta: + reuse_input_obs_or_fq = node.meta[ + "quantization_annotation" + ]._reuse_input_obs_or_fq + else: + assert "target_dtype_info" in node.meta + # TODO: we are assuming "target_dtype_info" exists here, maybe + # a default value also need to be provided here + target_dtype_info = node.meta["target_dtype_info"] + # for nodes that doesn't have `reuse_input_obs_or_fq` configured, + # we'll default to False, this makes configuring this field optional for users + reuse_input_obs_or_fq = target_dtype_info.get( + "reuse_input_obs_or_fq", False + ) + arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq( + arg, node, named_modules, obs_or_fq_map, is_qat + ) + ( + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq) + + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( + arg, named_modules, obs_or_fq_map, is_qat + ) + ( + 
arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) + + needs_obs_or_fq = _needs_obs_or_fq( + arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + reuse_input_obs_or_fq, + is_zeroth_arg=len(node.args) > 0 and arg is node.args[0], + ) + + else: + assert qconfig is not None + # custom flow for standalone modules + _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs( + node, named_modules, prepare_custom_config, qconfig, backend_config + ) + sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes + + # for args, this is set to the index of the current arg + # for kwargs, this is left at None + cur_input_idx = None + for arg_idx, arg_to_check in enumerate(node.args): + if arg_to_check is arg: + cur_input_idx = arg_idx + break + + if cur_input_idx is None: + needs_obs_or_fq = False + else: + arg_as_output_target_dtype = _get_arg_target_dtype_as_output( + arg, named_modules, obs_or_fq_map, is_qat + ) + arg_as_input_target_dtype = ( + torch.quint8 + if cur_input_idx in sm_input_quantized_idxs + else torch.float + ) + needs_obs_or_fq = ( + arg_as_output_target_dtype != arg_as_input_target_dtype + ) and (arg_as_input_target_dtype != torch.float) + + act_post_process_ctr = qconfig.activation + arg_as_input_act_obs_or_fq = ( + act_post_process_ctr() if act_post_process_ctr else None + ) + + if needs_obs_or_fq: + existing_obs_node = None + + # Before using the new observer, check if an observer + # of the correct type already exists. If it does, use it. + # This prevents duplicate observer insertions if a node is + # used by multiple nodes. + # TODO: this is looking into how the value is used in the future + # we should remove this + # removing this means we insert one observer for each use, even if they + # have the same dtype, we can have an extra pass that removes the extra observers + for maybe_obs_node in arg.users.keys(): + if maybe_obs_node.op == "call_module": + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if ( + type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) + and maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined] + ): + arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] + existing_obs_node = maybe_obs_node + break + + assert arg_as_input_act_obs_or_fq is not None + obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq + if existing_obs_node is None: + new_obs_node = _insert_obs_or_fq( + arg, arg_as_input_act_obs_or_fq, model, named_modules, graph + ) + # override this arg to be the observed arg + new_arg = new_obs_node + else: + new_arg = existing_obs_node + + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + qhandler: Optional[QuantizeHandler], + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: Optional[BackendConfig] = None, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + Note: backend_config only needed for standalone_module node + """ + # Look through every input arg. 
If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + ) + new_args.append(new_arg) + + new_kwargs = {} + for k, kwarg in node.kwargs.items(): + new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + kwarg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + ) + new_kwargs[k] = new_kwarg + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + node.kwargs = new_kwargs + + +def _maybe_insert_input_equalization_observers_for_node( + node: Node, + equalization_qconfig: Any, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + is_branch: bool, +) -> None: + """ + If `node` needs to be equalized, find the input/weight observers it needs in + `equalization_qconfig`, creates them, and inserts it into `graph`. + + If `node` does not need an equalization observer, returns None. + """ + if equalization_qconfig is None or not node_supports_equalization( + node, named_modules + ): + return + + if is_branch: + warnings.warn(f"Cannot equalize {node} because it is part of a branch.") + return + + new_args = [] + for arg in node.args: + if not isinstance(arg, Node) or node_arg_is_bias(node, arg): + new_args.append(arg) + continue + + is_weight = node_arg_is_weight(node, arg) + + act_eq_process_ctr = ( + equalization_qconfig.weight + if is_weight + else equalization_qconfig.input_activation + ) + + new_eq_obs_mod = act_eq_process_ctr() + new_eq_obs_node = _insert_obs_or_fq( + arg, new_eq_obs_mod, model, named_modules, graph + ) + + new_args.append(new_eq_obs_node) + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[Node]: + """ + If `node` needs an output observer, creates it, inserts it into `graph` + and returns it. + + If `node` does not need an output observer, returns None. 
+ + Note: inserting dynamic quantization ops for output is not supported in fx graph mode + quantization code path right now + """ + assert node.op != "output", "observer insertion for outputs is handled elsewhere" + + is_standalone_module = False + if "quantization_annotation" in node.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + else: + assert "target_dtype_info" in node.meta + is_standalone_module = node.meta["target_dtype_info"].get( + "_is_standalone_module", False + ) + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr" + ) + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) + # uncomment after we support reuse_input_obs_or_fq properly by having separate + # implemntations for this key instead of reusing the input_output_share_observers + # code + # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) + # for now we set this to False since reuse_input_obs_or_fq for + # the output of a node is implementation in the same code path as observer sharing, + # we should refactor this part to make it clearer in the future + # and we would be able to read this from config directly + reuse_input_obs_or_fq = False + + # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False + # because the prev_output is the output of an fp32 op, althought technically + # we should get the dtype of the output from node.meta["val"] in the future + # if we deprecate fx graph mode quantization + needs_obs_or_fq = _needs_obs_or_fq( + torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq + ) + # currently the activation in QConfig(activation=...,) is for both input + # and output, and when the activation is configured to be dynamic quantization + # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means + # the input should by dynamically quantized, but output should not be quantized + # + # there is no way we can specify different observer/fq for input and output + # activation through QConfig today, this limitation is lifted in the + # quantizer/annotation API in pytorch 2.0 export quantization code path, + # but since this code is reused, annotating output to be dynamically quantized + # would not work either for that. + # we can change QConfig to support input/output activation if we want + # to remove the following check, or if we can deprecate fx graph mode quantization + if target_is_dynamic: + needs_obs_or_fq = False + + # we never insert observers to output of standalone module, we assume + # if needed, they are inserted inside the standalone module + needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module) + + if needs_obs_or_fq: + obs_or_fq_map[node] = output_act_obs_or_fq + return _insert_obs_or_fq( + node, output_act_obs_or_fq, model, named_modules, graph + ) + else: + return None + + +def _maybe_insert_observers_before_graph_output( + graph_output_node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If the output needs to be quantized and there are any nodes + in the output which are not already observed, inserts observers + for those nodes. 
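+    Whether a returned value needs observation is decided by comparing the
+    dtype it already produces with the non-float dtype requested by its
+    `target_dtype_info`; dynamic quantization of graph outputs is not
+    handled by this helper.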
+ """ + + def _recursive_maybe_replace_node_with_obs( + maybe_node: Argument, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + ) -> Argument: + """ + Navigate an arbitrary data structure of lists, tuples, dicts. + For each container type, recurse on all inputs. Once any Node + is found, insert an observer if needed and do not recurse further. + + For example, given a structure of + + {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}} + + we recurse down to bar1 and bar3, observe them if necessary, + and if we inserted an observer then replace the original node + with its observer. + + Returns the data structure with all nodes needing observation being + replaced by their observers. + """ + if isinstance(maybe_node, Node): + # check dtype of this node + arg_as_output_target_dtype = _get_arg_target_dtype_as_output( + maybe_node, named_modules, obs_or_fq_map, is_qat + ) + observer_mod = None + arg_as_input_target_dtype = torch.float + if "target_dtype_info" in maybe_node.meta: + observer_cls = maybe_node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr", None + ) + if observer_cls is not None: + observer_mod = observer_cls() + arg_as_input_target_dtype = observer_mod.dtype + # TODO: this does not handle dynamic quantization yet + need_obs = ( + arg_as_output_target_dtype != arg_as_input_target_dtype + and arg_as_input_target_dtype != torch.float + ) + if need_obs: + assert observer_mod is not None + # insert observer + observer_node = _insert_obs_or_fq( + maybe_node, observer_mod, model, named_modules, graph + ) + return observer_node + else: + return maybe_node + elif isinstance(maybe_node, (list, tuple)): + results = [ + _recursive_maybe_replace_node_with_obs( + inner_node, model, named_modules, graph + ) + for inner_node in maybe_node + ] + if isinstance(maybe_node, list): + return results + else: + return tuple(results) + elif isinstance(maybe_node, dict): + results_dict = {} + for k, inner_v in maybe_node.items(): + results_dict[k] = _recursive_maybe_replace_node_with_obs( + inner_v, model, named_modules, graph + ) + return results_dict + elif maybe_node is None: + return None + else: + raise Exception( # noqa: TRY002 + "Unhandled type for returned node:", maybe_node + ) + + new_args = [ + _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph) + for old_arg in graph_output_node.args + ] + + graph_output_node.args = tuple(new_args) # type: ignore[assignment] + + +def _maybe_propagate_dtype_for_node( + node: Node, + target_dtype: Union[torch.dtype, type], + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], +) -> None: + """ + Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node` + is a general tensor shape op, also call this function recursively on + the first argument, to propagate the dtype to the caller. 
+ """ + node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None + node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None + # if this is a copy node, propagate to first arg + ( + _root_node, + _, + _pattern, + qhandler, + _qconfig, + ) = node_name_to_match_result_with_qconfig.get( + node.name, (None, None, None, None, None) + ) + # TODO: probably need to remove `is_general_tensor_value_op` + if qhandler is not None and qhandler.is_general_tensor_value_op(): + prev_node = node.args[0] + if isinstance(prev_node, Node): + _maybe_propagate_dtype_for_node( + prev_node, target_dtype, node_name_to_match_result_with_qconfig + ) + + +def propagate_dtypes_for_known_nodes( + graph: Graph, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], +) -> None: + """ + Currently we assume that inputs to the graph are either `torch.float` or + `torch.quint8`, which is not always correct. For ops such as + `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a + `BoolTensor`. Propagate this information throughout the graph. + + Note: not all dtypes in the graph will be correct after this pass, but a + higher percentage of them will be correct. Hopefully in the future we can + replace this with a better way to reason about dtypes of tensors. + """ + for node in graph.nodes: + non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node) + + for arg_type in non_observable_arg_dict: + non_observable_indices = non_observable_arg_dict[arg_type](node) + + for index in non_observable_indices: + arg = node.args[index] + + # when an argument is a tuple, it does not show up as another node so we need to go through + # all elements of the tuple manually + if isinstance(arg, (tuple, list)): + arg_list = list(arg) + else: + arg_list = [arg] + + for cur_arg in arg_list: + # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated + if isinstance(cur_arg, torch.fx.node.Node): + _maybe_propagate_dtype_for_node( + cur_arg, arg_type, node_name_to_match_result_with_qconfig + ) + + +def _maybe_make_input_output_share_observers( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], +) -> bool: + """ + Ensures that we share an observer + for all input arguments as well as the output argument. In detail, given + a graph of + + x0 -> obs0 -> op -> x2 + / + x1 -> obs1 / + + where node obs0 points to observer instance observer0, + obs1 points to observer1 and obs2 points to observer2, we make nodes obs1 + and ob2 point to observer0. 
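+    This is done by locating the observer module attached to the first input
+    (walking back through intermediate non-observed nodes if necessary) and
+    re-pointing the remaining input observer nodes (for list/tuple inputs
+    such as cat) and the output observer nodes at that same module instance.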
+ Returns: whether the operation succeeded or not + """ + first_arg = None + # find the first non-Tensor arg + for i in range(len(node.args)): + if isinstance(node.args[i], (Node, list, tuple)): + first_arg = node.args[i] + break + + # if there is no non-Tensor arg, return directly + if first_arg is None: + return False + + if isinstance(first_arg, (list, tuple)): + first_arg_arg = first_arg[0] + elif isinstance(first_arg, Node): + first_arg_arg = first_arg + else: + return False + + # if we have a graph such as + # observed_node -> non_observed_node -> cat + # we need to navigate up to the first observer + iteration_guard = 0 + while not _is_activation_post_process_node(first_arg_arg, named_modules): + if not isinstance(first_arg_arg, Node): + return False + # did not find an activation_post_process for the op + if first_arg_arg.op == "placeholder": + return False + # trace back the args until we found the first Tensor/Node + trace_back_node = None + for i in range(len(first_arg_arg.args)): + trace_back_node = first_arg_arg.args[i] + if isinstance(trace_back_node, Node): + break + if trace_back_node is None: + return False + first_arg_arg = trace_back_node + + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError("Unable to find observer of previous node") + + assert isinstance(first_arg_arg, Node) + target_to_use = first_arg_arg.target + assert isinstance(target_to_use, str) + obs_mod_to_use = named_modules[target_to_use] + + if isinstance(first_arg, (list, tuple)): + # set all other input observer nodes to use that module + for input_idx, input_arg in enumerate(first_arg): + if input_idx == 0: + continue + iteration_guard = 0 + while not _is_activation_post_process_node(input_arg, named_modules): + # failed to trace back since no input arg for the current node + if len(input_arg.args) < 1: + return False + input_arg = input_arg.args[0] + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError("Unable to find observer of previous node") + + parent_name, name = _parent_name(input_arg.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # set the output observer node to use that module + for output_obs_node in node.users.keys(): + assert _is_activation_post_process_node(output_obs_node, named_modules) + parent_name, name = _parent_name(output_obs_node.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # TODO(future PR): delete the orphaned observer modules + return True + + +def _remove_output_observer( + node: Node, model: torch.nn.Module, named_modules: dict[str, torch.nn.Module] +): + items = list(node.users.items()) + for output_obs_node, _ in items: + assert _is_activation_post_process_node(output_obs_node, named_modules) + output_obs_node.replace_all_uses_with(node) + model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] + + +def _swap_custom_module_to_observed( + node: Node, + qconfig: QConfigAny, + named_modules: dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, +): + custom_module = named_modules[node.target] # type: ignore[index] + custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping + observed_custom_module_class = get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig + ) + observed_custom_module = observed_custom_module_class.from_float(custom_module) + parent_name, name = _parent_name(node.target) + setattr(named_modules[parent_name], name, observed_custom_module) + + +def 
insert_observers_for_model( + model: GraphModule, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], + node_name_to_qconfig: dict[str, QConfigAny], + prepare_custom_config: PrepareCustomConfig, + equalization_config_map: dict[str, Any], + backend_config: BackendConfig, + observed_node_names: set[str], + is_qat: bool, +) -> Optional[Node]: + """ + Inserts observers, using the following high level algorithm: + + For each node in the graph: + 1. determine the target dtype of this node in the quantized graph, and save + it for future steps + 2. determine the target dtype or all args and kwargs of this node + 3. if any arg or kwarg's target dtype does not match the current node's + dtype, insert an observer + 4. if the current node needs an output observer, insert it + + For example: + + - starting graph: + x0 -> linear -> x1 + + - observed graph after processing x0: + x0(fp32) + + - observed graph after processing linear: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) + + - observed graph after processing x1: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1 + + After a node is processed, the naive observer placement is guaranteed to be + complete for that node and all of its predecessors. There can be future + passes which optimize the graph by deduplicating observers, etc. + """ + + # node.meta["target_dtype_info"] stores the target dtype information + # that's derived from qconfig for the Node, for example, if we have + # a conv2d node that has a qconfig + # qconfig = QConfig(activation=..., weight=...) + # # information for input and bias node omitted + # # for getattr node + # # weight = getattr(self, 'weight') + # weight.meta["target_dtype_info"] = { + # 'output_act_obs_or_fq_ctr': qconfig.weight, + # } + # # for conv2d node + # # conv2d = call_function[target=torch.nn.functional.conv2d]( + # # args=(input, weight, bias)) + # conv2d.meta["target_dtype_info"] = { + # 'input_act_obs_or_fq_ctr': qconfig.activation + # 'weight_obs_or_fq_ctr': qconfig.weight, + # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32), + # 'output_act_obs_or_fq_ctr': qconfig.activation, + # } + # + cache_for_no_tensor_check: dict[Node, bool] = {} + + # first, populate the dtype map based only on qconfig and qhandler + # this assumes: + # graph inputs are fp32 by default, and int8 where overridden + # other nodes output dtype is specified by the qconfig + named_modules = dict(model.named_modules(remove_duplicate=False)) + + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + processed_nodes: set[Node] = set() + # initialize target_dtype_info + for node in model.graph.nodes: + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + + inputs_seen_counter = 0 + outputs_seen_counter = 0 + placeholder_node_to_input_index: dict[Node, int] = {} + # TODO: we probably don't need this counter since each graph will only have + # one output node? 
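+    # Note: graphs produced by torch.fx symbolic tracing normally have exactly
+    # one output node, so this dict is expected to hold a single entry.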
+ output_node_to_output_index: dict[Node, int] = {} + for node in model.graph.nodes: + if node.op == "placeholder": + placeholder_node_to_input_index[node] = inputs_seen_counter + inputs_seen_counter += 1 + if node.op == "output": + output_node_to_output_index[node] = outputs_seen_counter + outputs_seen_counter += 1 + + # Step 1, set the observer or fake quantize module constructor for each node in the + # matched_node_pattern + + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = match_res_with_qconfig + assert qhandler is not None + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # Step 2. Special cases for some operators, we might be able to remove them + # in the future if we know dtype information of each node better + + # Step 2.1. some settings are not based on patterns, we need to process each node + # instead + for node in model.graph.nodes: + if ( + node.op == "placeholder" + and placeholder_node_to_input_index[node] in input_quantized_idxs + ): + # users are not supposed to call calculate_qparams on PlaceholderObserver, and + # this is OK because we are using this as a way to encode the dtypes of input + # tensor, we won't actually insert these observers in the graph and won't + # actually call calculate_qparams + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + elif node.op in ("call_module", "call_method", "call_function"): + args_have_no_tensors = all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check + ) + if args_have_no_tensors: + node.meta["target_dtype_info"] = { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + elif ( + node.op == "output" + and output_node_to_output_index[node] in output_quantized_idxs + ): + # TODO(future PR): update the output_quantized_idxs API to match + # arbitrary data structures. There is always a single output, and + # that output can have arbitrary nesting of values. List[int] is + # not the right data type for this. + + # TODO(future PR): support more dtypes in model outputs, if necessary + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + + # Step 2.2, for nodes with known input dtypes, propagate them throughout the + # graph. 
For example, if there is a call such as + # x1 = x0.masked_fill(mask, 1) + # we propagate the type of mask to be torch.bool + propagate_dtypes_for_known_nodes( + model.graph, node_name_to_match_result_with_qconfig + ) + + # Step 3, check if the requested target_dtype_info is supported by backend or not + # if not, we'll reset the target_dtye_info to use the default (float Tensor) + + # reset the counters and set of processed_nodes + processed_nodes: set[Node] = set() + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = match_res_with_qconfig + is_supported_by_backend = ( + _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config + ) + ) + assert qhandler is not None + + # get output_act_dtype so that we don't also reset the special typed nodes + # TODO: we might want to handle these more uniformly with the default path + # this can be improved if we can use node.meta["val"] + output_act_or_fq_ctr = node.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None + output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq) + if not is_supported_by_backend and output_act_dtype not in [ + None, + int, + float, + torch.bool, + ]: + # restore target_dtype_info to default if it is not supported by backend + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig, + None, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # After this point, the current node and all of its arguments + # have a target_dtype_info assigned. Now, we insert observers for inputs + # of this node (if needed for this node), and the output of this node + # (if needed for this node). + + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # Avoid duplicates custom module swaps for multiple nodes with same target. 
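+    # The set below records the targets of custom modules that were already
+    # swapped to their observed counterparts, so a module referenced by more
+    # than one call_module node is only swapped once.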
+ custom_module_names_already_swapped: set[str] = set() + + # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index + # reset inputs/outputs counters + inputs_seen_counter = 0 + outputs_seen_counter = 0 + results_node = None + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + + # TODO: change this to insert obs/fq by pattern instead of by node + for node in nodes_before_observation: + if node.op == "placeholder": + # if a graph input is in fp32, it does not need observation + # if a graph input is in int8, we assume the observation happens + # outside of the graph, and no additional observation is needed + pass + + elif node.op in ("call_module", "call_method", "call_function", "output"): + # check for matches + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = node_name_to_match_result_with_qconfig.get( # type: ignore[assignment] + node.name, (None, None, None, None, None) + ) + equalization_qconfig = equalization_config_map.get(node.name, None) + + this_node_dtype_info = node.meta["target_dtype_info"] + if "val" in node.meta: + output_is_a_tensor = this_node_dtype_info is not None and isinstance( + node.meta["val"], FakeTensor + ) + else: + output_is_a_tensor = this_node_dtype_info is not None + + skip_inserting_observers = ( + (qconfig is None) or not output_is_a_tensor + ) and (not node.op == "output") + + # TODO: take a closer look to see if we can remove this check + # right now it is here because of `observed_node_names`, we are using + # it as an indicator for swapping the modules to reference modules in + # convert + is_supported_by_backend = ( + _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config + ) + ) + + if not skip_inserting_observers and is_supported_by_backend: + named_modules = dict(model.named_modules(remove_duplicate=False)) + if node.op != "output": + assert matched_node_pattern is not None + # add matched nodes to the observed node name set + _add_matched_node_name_to_set( + matched_node_pattern, observed_node_names + ) + + # This is currently only used for equalization. + # Checks if the current node is in a branch in which the two + # first layers are both being quantized. + # + # ex. conv2 + # / + # x -> conv1 + # + # If this is the case, we will not apply equalization to the + # initial two layers. 
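+                # The check below walks the other users of this node's first
+                # input; if any of them also has a qconfig or is already an
+                # observer module, the input is treated as a quantized branch
+                # and equalization observers are not inserted for this node.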
+ is_quantized_branch = False + if ( + len(node.args) > 0 + and isinstance(node.args[0], Node) + and len(node.args[0].users) > 1 + ): + for user in node.args[0].users: + # Checks if there exists another user being quantized + is_user_quantized = node_name_to_qconfig.get( + user.name, None + ) is not None or ( + user.op == "call_module" + and isinstance( + named_modules[str(user.target)], ObserverBase + ) + ) + if user != node and is_user_quantized: + is_quantized_branch = True + + pattern_to_root_node_getter = ( + get_fusion_pattern_to_root_node_getter(backend_config) + ) + root_node_getter = pattern_to_root_node_getter.get( + pattern, _default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) + is_input_node_of_the_pattern = node is root_node + if is_input_node_of_the_pattern: + # this modifies node inplace + _maybe_insert_input_observers_for_node( + node, + qconfig, + model, + named_modules, + model.graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + ) + + # insert equalization input observers if needed + _maybe_insert_input_equalization_observers_for_node( + node, + equalization_qconfig, + model, + named_modules, + model.graph, + is_quantized_branch, + ) + + is_last_node_of_pattern = node is last_node + input_output_share_observers = node.meta["target_dtype_info"].get( + "input_output_share_observers", False + ) + reuse_input_obs_or_fq = node.meta["target_dtype_info"].get( + "reuse_input_obs_or_fq", False + ) + + if is_last_node_of_pattern: + if _is_custom_module_lstm( + node, named_modules, qconfig, qhandler + ): + # Currently custom module outputs are assumed to be already quantized, + # so we need to insert a DeQuantStub after the output. For custom module + # LSTM specifically, the outputs are also a nested tuple, so we must first + # break down the tuple to insert DeQuantStubs after the internal nodes. + + # TODO: This currently diverges from how custom modules are handled today, + # where we insert observers after the output instead of DeQuantStubs, and + # replace these observers with "dequantize" nodes during convert. Conceptually, + # these output observers are the same as DeQuantStubs. In the future, we + # should resolve this inconsistency by inserting DeQuantStubs for all custom + # modules, not just for LSTM. + _insert_dequant_stubs_for_custom_module_lstm_output( + node, model, named_modules, model.graph + ) + if node.target not in custom_module_names_already_swapped: + custom_module_names_already_swapped.add(node.target) + _swap_custom_module_to_observed( + node, qconfig, named_modules, prepare_custom_config + ) + else: + # this returns the new observer node if it was needed + maybe_output_obs_node = ( + _maybe_insert_output_observer_for_node( + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + ) + ) + + if maybe_output_obs_node is not None: + # Update users of original node to use the output observer + # instead. 
For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with( + node, maybe_output_obs_node + ) + + _is_observer_in_same_graph_ = ( + _is_observer_in_same_graph( + node, named_modules, obs_or_fq_map, is_qat + ) + ) + + # for ops whose inputs and outputs share observer/fqs, we modify the graph + # to make all inputs and outputs use the first input's + # observer/fq + if ( + input_output_share_observers + and _is_observer_in_same_graph_ + ) or reuse_input_obs_or_fq: + if not _maybe_make_input_output_share_observers( + node, model, named_modules + ): + _remove_output_observer( + node, model, named_modules + ) + + if qhandler is not None and qhandler.is_custom_module(): + if ( + node.target + not in custom_module_names_already_swapped + ): + custom_module_names_already_swapped.add( + node.target + ) + _swap_custom_module_to_observed( + node, + qconfig, + named_modules, + prepare_custom_config, + ) + + else: # output + _maybe_insert_observers_before_graph_output( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat + ) + + # + # After this point, the current node has input and output observers + # that it needs for itself inserted. + # + + # increment the counters, so future inputs and outputs are assigned + # correct dtypes + if node.op == "placeholder": + inputs_seen_counter += 1 + elif node.op == "output": + outputs_seen_counter += 1 + results_node = node + + return results_node + + +def _run_prepare_fx_on_standalone_modules( + model: torch.nn.Module, + is_qat: bool, + named_modules: dict[str, torch.nn.Module], + node_name_to_match_result_with_qconfig: Any, + prepare_custom_config: PrepareCustomConfig, + backend_config: BackendConfig, +) -> None: + """ + Runs prepare_fx on each standalone module. Note: this does + not modify the graph, it just replaces the unobserved modules with + their observed versions. 
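+    Each standalone module is prepared with its own qconfig mapping, example
+    inputs, prepare_custom_config and backend_config (looked up via
+    _get_standalone_module_configs), and the observed module is then set back
+    on its parent module.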
+ """ + for ( + root_node, + _, + _pattern, + qhandler, + qconfig, + ) in node_name_to_match_result_with_qconfig.values(): + if qhandler is None: + continue + elif not qhandler.is_standalone_module(): + continue + + ( + sm_qconfig_mapping, + sm_example_inputs, + sm_prepare_custom_config, + sm_backend_config, + ) = _get_standalone_module_configs( + root_node, named_modules, prepare_custom_config, qconfig, backend_config + ) + + standalone_module = named_modules[root_node.target] + prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined] + observed_standalone_module = prepare( + standalone_module, + sm_qconfig_mapping, + is_qat, + example_inputs=sm_example_inputs, + prepare_custom_config=sm_prepare_custom_config, + backend_config=sm_backend_config, + ) + parent_name, name = _parent_name(root_node.target) + setattr(named_modules[parent_name], name, observed_standalone_module) + named_modules[root_node.target] = observed_standalone_module + + +def _save_state( + observed: GraphModule, + node_name_to_qconfig: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + prepare_custom_config: PrepareCustomConfig, + equalization_node_name_to_qconfig: dict[str, Any], + qconfig_mapping: QConfigMapping, + is_qat: bool, + observed_node_names: set[str], +) -> None: + observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs( + node_name_to_qconfig=node_name_to_qconfig, + node_name_to_scope=node_name_to_scope, + prepare_custom_config=prepare_custom_config, + equalization_node_name_to_qconfig=equalization_node_name_to_qconfig, + qconfig_mapping=qconfig_mapping, + is_qat=is_qat, + observed_node_names=observed_node_names, + ) + + +def prepare( + model: GraphModule, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + is_qat: bool, + node_name_to_scope: dict[str, tuple[str, type]], + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + _equalization_config: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, +) -> GraphModule: + """standalone_module means it a submodule that is not inlined in + parent module, and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + Args: + node_name_to_scope: mapping from node name to the scope of the module which contains the node. 
+ The scope is a tuple of fully qualified path of the module and the type of the module + Returns: + model(GraphModule): prepared standalone module + attributes related to standalone module + in model.meta["_observed_graph_module_attrs"]: + is_observed_standalone_module (bool): boolean value that shows whether the + current model is a observed standalone module or not + standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to prepare is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) + + if isinstance(_equalization_config, dict): + warnings.warn( + "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " + "be supported in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + _equalization_config = QConfigMapping.from_dict(_equalization_config) + + if isinstance(prepare_custom_config, dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + assert isinstance(qconfig_mapping, QConfigMapping) + assert isinstance(_equalization_config, QConfigMapping) + qconfig_mapping = copy.deepcopy(qconfig_mapping) + _equalization_config = copy.deepcopy(_equalization_config) + + # mapping from a tuple of nodes in reverse order to uninitialized + # QuantizeHandler subclass. 
For example, + # { + # # match a single node + # (: + # ), + # # match multiple nodes in reverse order + # ((, ): + # ), + # } + + pattern_to_quantize_handler: dict[Pattern, QuantizeHandler] = {} + if backend_config is None: + backend_config = get_native_backend_config() + pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config) + pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler) + + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + + _update_qconfig_for_fusion(model, qconfig_mapping) + _update_qconfig_for_fusion(model, _equalization_config) + flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) + # TODO: support regex as well + propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) + + if is_qat: + module_to_qat_module = get_module_to_qat_module(backend_config) + _qat_swap_modules(model, module_to_qat_module) + _update_qconfig_for_qat(qconfig_mapping, backend_config) + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + named_modules = dict(model.named_modules(remove_duplicate=False)) + + # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches + equalization_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, named_modules, model.graph, _equalization_config, node_name_to_scope + ) + node_name_to_qconfig = _generate_node_name_to_qconfig( + model, named_modules, model.graph, qconfig_mapping, node_name_to_scope + ) + + # match the patterns that will get quantized + standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) + standalone_module_classes = list( + prepare_custom_config.standalone_module_classes.keys() + ) + + custom_module_classes = get_custom_module_class_keys( + prepare_custom_config.float_to_observed_mapping + ) + matches_without_qconfig = _find_matches( + model.graph, + named_modules, + pattern_to_quantize_handler, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + + # map qconfig instances to matches + node_name_to_match_result_with_qconfig = {} + for node_name, match_without_qconfig in matches_without_qconfig.items(): + match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name]) + node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig + + _run_prepare_fx_on_standalone_modules( + model, + is_qat, + named_modules, + node_name_to_match_result_with_qconfig, + prepare_custom_config, + backend_config, + ) + + # record names for the set of observed node, so that in convert step + # we know whether we need to convert a floating point module to reference + # quantized module or not + observed_node_names: set[str] = set() + + result_node = insert_observers_for_model( + model, + node_name_to_match_result_with_qconfig, + node_name_to_qconfig, + prepare_custom_config, + equalization_node_name_to_qconfig, + backend_config, + observed_node_names, + is_qat, + ) + model = GraphModule(model, model.graph) + + _save_state( + model, + node_name_to_qconfig, + node_name_to_scope, + prepare_custom_config, + equalization_node_name_to_qconfig, + qconfig_mapping, + is_qat, + observed_node_names, + ) + + if is_standalone_module: + assert result_node is not None + assert isinstance(result_node.args[0], Node), ( + "standalone module only supports returning 
simple value currently" + "(not tuple, dict etc.)" + ) + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = ( + prepare_custom_config.output_quantized_indexes + ) + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + # inplace modification + observed_graph_module_attrs.is_observed_standalone_module = True + observed_graph_module_attrs.standalone_module_input_quantized_idxs = ( + input_quantized_idxs + ) + observed_graph_module_attrs.standalone_module_output_quantized_idxs = ( + output_quantized_idxs + ) + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7dbba1dd33c52532f0e1421b8b4b20a5fa796afd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -0,0 +1,400 @@ +# mypy: allow-untyped-defs +import re +from collections import defaultdict, OrderedDict +from typing import Any, Callable, Union + +import torch +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization import QConfig +from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +from torch.ao.quantization.backend_config.utils import get_module_to_qat_module +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import ( + _add_module_to_qconfig_obs_ctr, + qconfig_equals, + QConfigAny, +) +from torch.ao.quantization.qconfig_mapping import ( + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, + QConfigMapping, +) +from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes +from torch.fx import GraphModule +from torch.fx.graph import Graph + + +__all__: list[str] = [] + + +def _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping: QConfigMapping, + cur_module_path: str, + cur_object_type: Callable, + cur_object_type_idx: int, + fallback_qconfig: QConfigAny, +) -> QConfigAny: + for ( + module_name, + object_type, + index, + ), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items(): + if ( + (module_name == cur_module_path) + and (object_type == cur_object_type) + and (index == cur_object_type_idx) + ): + return qconfig + return fallback_qconfig + + +def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping): + """ + Update the QConfigMapping to account for fused modules such as LinearReLU. + This assumes the QConfigMapping's attributes have already been converted to OrderedDicts. 
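+    For a fused module such as LinearReLU, the qconfig registered for the type
+    of its first child module (e.g. Linear) is copied over to the fused module
+    type; if the children were configured with conflicting qconfigs, a
+    LookupError is raised.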
+ """ + object_type_dict = qconfig_mapping.object_type_qconfigs + if len(object_type_dict) == 0: + return qconfig_mapping + + modules = dict(model.named_modules()) + + for node in model.graph.nodes: + if node.op == "call_module" and node.target in modules: + maybe_fused_module = modules[str(node.target)] + if not isinstance(maybe_fused_module, _FusedModule): + continue + + ops = list(maybe_fused_module._modules.values()) + fused_qconfig = object_type_dict.get(type(ops[0]), None) + + # Raise an error if the modules in the fused module have + # different qconfigs specified in the qconfig_dict + # TODO: currently it only works for modules, + # need to make this work for torch.nn.functional.relu + # TODO: currently it only works for object_type configurations, + # ideally it should work for different types of configurations, + # maybe we want to redesign this part + for op in ops[1:]: + if not qconfig_equals( + object_type_dict.get(type(op), None), fused_qconfig + ): + raise LookupError( + "During fusion, we need to specify the same " + + f"qconfigs for all module types in {type(maybe_fused_module)} " + + f"offending type: {type(op)}" + ) + + if fused_qconfig is not None: + object_type_dict[type(maybe_fused_module)] = fused_qconfig + + +def _generate_node_name_to_qconfig( + root: torch.nn.Module, + modules: dict[str, torch.nn.Module], + input_graph: Graph, + qconfig_mapping: QConfigMapping, + node_name_to_scope: dict[str, tuple[str, type]], +) -> dict[str, QConfigAny]: + global_qconfig = qconfig_mapping.global_qconfig + node_name_to_qconfig = {} + + # example: + # + # {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...} + # + # meaning in submodule 'foo.bar', we have seen 0 F.linear and + # 1 F.conv2d invocations so far. + submodule_to_object_type_to_cur_idx: dict[str, dict[Callable, int]] = defaultdict( + lambda: defaultdict(int) + ) + for node in input_graph.nodes: + qconfig = None + if node.op == "get_attr": + module_name, _ = _parent_name(node.target) + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[module_name]), module_name, global_qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + elif node.op == "call_function": + # precedence: module_name_qconfig + # > function_qconfig > global_qconfig + # module_name takes precedence over function qconfig + function_qconfig = _get_object_type_qconfig( + qconfig_mapping, node.target, global_qconfig + ) + module_path, module_type = node_name_to_scope[node.name] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, function_qconfig + ) + + cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][ + node.target + ] + submodule_to_object_type_to_cur_idx[module_path][node.target] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # first use node.target (string) to get the qconfig + # this is to support configs like + # "object_type": [("reshape", qconfig)] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, node.target, module_path, global_qconfig + ) + # if there is no special config for the method, we'll fall back to the + # config for the module that 
contains the call_method node + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, qconfig + ) + # currently call_method does not support modifying qconfig + # by order, we can add this later if it is needed. + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + elif node.op == "call_module": + # if the node is an observer, just continue - don't add it to the qconfig_map + if _is_activation_post_process(modules[node.target]): + continue + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[node.target]), node.target, global_qconfig + ) + + module_path, module_type = node_name_to_scope[node.name] + # Note: for call_module, the module_path is the current module's name. + # to meaningfully count invocations, we need to count them in the parent + # module. + parent_name, _ = _parent_name(module_path) + cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][ + module_type + ] + submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + # regex is not supported eager mode propagate_qconfig_, we'll + # need to set the qconfig explicitly here in case regex + # is used + modules[node.target].qconfig = qconfig_with_device_check + else: + qconfig_with_device_check = None + + node_name_to_qconfig[node.name] = qconfig_with_device_check + return node_name_to_qconfig + + +def _check_is_valid_config_dict( + config_dict: Any, allowed_keys: set[str], dict_name: str +) -> None: + r"""Checks if the given config_dict has the correct keys + + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict.keys(): + if k not in allowed_keys: + raise ValueError( + "Expected " + + dict_name + + " to have the following keys: " + + str(allowed_keys) + + ". But found '" + + k + + "' instead." 
+ ) + + +def _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping +): + r"""Compare the qconfig_mapping passed in convert to the one from prepare and check the values + + Args: + `prepare_qconfig_mapping`: configuration for prepare quantization step + `convert_qconfig_mapping`: configuration for convert quantization step + """ + assert qconfig_equals( + prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig + ), ( + "Expected global qconfigs to be the same in the prepare and convert quantization configs" + ) + prepare_dicts: list[OrderedDict] = [ + prepare_qconfig_mapping.object_type_qconfigs, + prepare_qconfig_mapping.module_name_qconfigs, + prepare_qconfig_mapping.module_name_regex_qconfigs, + ] + convert_dicts: list[OrderedDict] = [ + convert_qconfig_mapping.object_type_qconfigs, + convert_qconfig_mapping.module_name_qconfigs, + convert_qconfig_mapping.module_name_regex_qconfigs, + ] + dict_names = [ + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + ] + for i in range(len(prepare_dicts)): + for name in prepare_dicts[i].keys(): + assert name in convert_dicts[i], ( + f"Missing key {dict_names[i]} {name} in convert QConfigMapping \ + when it was present in prepare" + ) + assert convert_dicts[i][name] is None or qconfig_equals( + prepare_dicts[i][name], convert_dicts[i][name] + ), ( + f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \ + prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" + ) + + +def _is_qconfig_supported_by_dtype_configs( + qconfig: QConfig, dtype_configs: list[DTypeConfig] +): + for dtype_config in dtype_configs: + is_dynamic = dtype_config.is_dynamic + if is_dynamic is None: + is_dynamic = False + input_dtype = dtype_config.input_dtype or torch.float + weight_dtype = dtype_config.weight_dtype or torch.float + bias_dtype = dtype_config.bias_dtype or torch.float + output_dtype = dtype_config.output_dtype or torch.float + ( + qconfig_activation_dtype, + qconfig_weight_dtype, + qconfig_input_act_is_dynamic, + ) = get_qconfig_dtypes(qconfig) + qconfig_bias_dtype = ( + torch.float16 + if ( + qconfig_activation_dtype == torch.float16 + and qconfig_weight_dtype == torch.float16 + and not is_dynamic + ) + else torch.float + ) + + if is_dynamic: + is_match = ( + qconfig_input_act_is_dynamic + and input_dtype == qconfig_activation_dtype + and output_dtype == torch.float + and weight_dtype == qconfig_weight_dtype + ) + else: + is_match = ( + input_dtype == qconfig_activation_dtype + and output_dtype == qconfig_activation_dtype + and weight_dtype == qconfig_weight_dtype + and bias_dtype == qconfig_bias_dtype + ) + if is_match: + return True + return False + + +def _get_object_type_qconfig( + qconfig_mapping: QConfigMapping, + object_type: Union[Callable, str], + fallback_qconfig: QConfigAny, +) -> QConfigAny: + return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) + + +def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): + for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + + +def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): + if module_name == "": + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_mapping.module_name_qconfigs: + 
return qconfig_mapping.module_name_qconfigs[module_name] + else: + parent, _ = _parent_name(module_name) + return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) + + +def _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_name, global_qconfig +): + # get qconfig for module_name, + # fallback to module_name_regex_qconfig, module_type_qconfig, + # global_qconfig if necessary + module_type_qconfig = _get_object_type_qconfig( + qconfig_mapping, module_type, global_qconfig + ) + module_name_regex_qconfig = _get_module_name_regex_qconfig( + qconfig_mapping, module_name, module_type_qconfig + ) + module_name_qconfig = _get_module_name_qconfig( + qconfig_mapping, module_name, module_name_regex_qconfig + ) + return module_name_qconfig + + +def _get_flattened_qconfig_dict( + qconfig_mapping: QConfigMapping, +) -> dict[Union[Callable, str], QConfigAny]: + """flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened: dict[Union[Callable, str], QConfigAny] = { + "": qconfig_mapping.global_qconfig + } + flattened.update(qconfig_mapping.object_type_qconfigs) + flattened.update(qconfig_mapping.module_name_qconfigs) # type: ignore[arg-type] + return flattened + + +def _update_qconfig_for_qat( + qconfig_mapping: QConfigMapping, backend_config: BackendConfig +): + """ + Update the qconfig_mapping to account for module swaps during QAT. + During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. 
+ """ + module_to_qat_module_class = get_module_to_qat_module(backend_config) + object_type_dict = qconfig_mapping.object_type_qconfigs + new_object_type_dict = object_type_dict.copy() + for k, v in new_object_type_dict.items(): + if k in module_to_qat_module_class: + object_type_dict[module_to_qat_module_class[k]] = v diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/quantize_handler.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/quantize_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..101aca8895ec4fb4edc4c8433d5325649b54dbda --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/quantize_handler.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +from abc import ABC +from typing import Callable, Optional + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + ObservationType, +) +from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls +from torch.fx.graph import Node + +from .utils import all_node_args_have_no_tensors + + +__all__ = [ + "QuantizeHandler", + "BinaryOpQuantizeHandler", + "CatQuantizeHandler", + "ConvReluQuantizeHandler", + "LinearReLUQuantizeHandler", + "BatchNormQuantizeHandler", + "EmbeddingQuantizeHandler", + "RNNDynamicQuantizeHandler", + "DefaultNodeQuantizeHandler", + "FixedQParamsOpQuantizeHandler", + "CopyNodeQuantizeHandler", + "GeneralTensorShapeOpQuantizeHandler", + "CustomModuleQuantizeHandler", + "StandaloneModuleQuantizeHandler", +] + + +def _default_root_node_getter(node_pattern): + if node_pattern is None: + return node_pattern + while not isinstance(node_pattern, Node): + node_pattern = node_pattern[-1] + return node_pattern + + +# Base Pattern Handler +class QuantizeHandler(ABC): # noqa: B024 + """Base handler class for the quantizer patterns""" + + def __init__( + self, + node_pattern: NodePattern, + modules: dict[str, torch.nn.Module], + root_node_getter: Optional[Callable] = None, + is_custom_module=False, + is_standalone_module=False, + ): + """Records pattern information in __init__, which will be used + in convert + """ + self.node_pattern = node_pattern + self.modules = modules + if root_node_getter is None: + root_node_getter = _default_root_node_getter + self.root_node = root_node_getter(node_pattern) + self.is_custom_module_ = is_custom_module + self.is_standalone_module_ = is_standalone_module + self.num_tensor_args = 0 + # determine how many of the first two args are Tensors (versus scalars) + # this distinguishes things like "x + y" from "x + 2" or "2 + x" + if isinstance(self.root_node, Node): + cache_for_no_tensor_check: dict[Node, bool] = {} + for arg_idx in range(len(self.root_node.args)): + arg = self.root_node.args[arg_idx] + if isinstance(arg, Node) and ( + not all_node_args_have_no_tensors( + arg, self.modules, cache_for_no_tensor_check + ) + ): + self.num_tensor_args += 1 + + def is_general_tensor_value_op(self) -> bool: + """ + Returns True if the operator works for both floating point and + quantized input, and does some computation based on the input Tensor, + or the ops that only re-arranges the Tensor values or query some metadata + about the Tensor + so we need to insert observer/fake_quant for the output of the + operator (same observer instance as input) + since the distribution of values is different for input and output + Tensors (for HistogramObserver) while they share the same quantization + parameters + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: 
+ observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + return False + + def is_custom_module(self): + return self.is_custom_module_ + + def is_standalone_module(self): + return self.is_standalone_module_ + + +def _get_quantize_handler_cls( + observation_type: ObservationType, + dtype_configs: list[DTypeConfig], + num_tensor_args_to_observation_type: dict[int, ObservationType], +) -> type[QuantizeHandler]: + """ + Return a configurable QuantizeHandler that matches the given specifications from the backend. + """ + + class ConfigurableQuantizeHandler(QuantizeHandler): + def __init__( + self, + node_pattern: NodePattern, + modules: dict[str, torch.nn.Module], + root_node_getter: Optional[Callable] = None, + ): + super().__init__(node_pattern, modules, root_node_getter) + if num_tensor_args_to_observation_type: + assert self.num_tensor_args in num_tensor_args_to_observation_type, ( + f"Must provide observation_type config for tensor number {self.num_tensor_args}" + f" in num_tensor_args_to_observation_type for {node_pattern}" + ) + self.observation_type = num_tensor_args_to_observation_type[ + self.num_tensor_args + ] + else: + self.observation_type = observation_type + self.dtype_configs = dtype_configs + + def is_general_tensor_value_op(self) -> bool: + return ( + self.observation_type + == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + ) + + return ConfigurableQuantizeHandler + + +def _get_pattern_to_quantize_handlers( + backend_config: BackendConfig, +) -> dict[Pattern, QuantizerCls]: + """ + Note: Quantize handler is just a holder for some check methods like + (should_insert_observer_for_output), maybe this can be a enum as well, + we can refactor this after we convert the path for fbgemm/qnnpack fully to the + new path, this is not exposed to backend developers + """ + pattern_to_quantize_handlers = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + observation_type = config.observation_type + dtype_configs = config.dtype_configs + num_tensor_args_to_observation_type = ( + config._num_tensor_args_to_observation_type + ) + pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls( + observation_type, dtype_configs, num_tensor_args_to_observation_type + ) + return pattern_to_quantize_handlers + + +# TODO: remove this class, this is still exposed in torch.ao.quantization +# but we should be able to break bc +class BinaryOpQuantizeHandler(QuantizeHandler): + pass + + +class CatQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class ConvReluQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class LinearReLUQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class BatchNormQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class EmbeddingQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class RNNDynamicQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class DefaultNodeQuantizeHandler(QuantizeHandler): + """Common quantized op, first input and first output will be quantized""" + + +# TODO: remove this class +class FixedQParamsOpQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove +class CopyNodeQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove +class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler): + pass + + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class CustomModuleQuantizeHandler(QuantizeHandler): + pass + + +# TODO: 
not used, can be removed after torch.ao.quantization namespace is deprecated +class StandaloneModuleQuantizeHandler(QuantizeHandler): + pass diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/tracer.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..7aad1ef19400b44925bf778aa9cc8037939ce011 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/tracer.py @@ -0,0 +1,48 @@ +from typing import Callable + +import torch +from torch.ao.nn.intrinsic import _FusedModule +from torch.fx._symbolic_trace import Tracer +from torch.fx.proxy import Scope + + +__all__ = [ + "QuantizationTracer", +] + + +class ScopeContextManager(torch.fx.proxy.ScopeContextManager): + def __init__( + self, scope: Scope, current_module: torch.nn.Module, current_module_path: str + ): + super().__init__(scope, Scope(current_module_path, type(current_module))) + + +class QuantizationTracer(Tracer): + def __init__( + self, skipped_module_names: list[str], skipped_module_classes: list[Callable] + ): + super().__init__() + self.skipped_module_names = skipped_module_names + self.skipped_module_classes = skipped_module_classes + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.record_stack_traces = True + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + return ( + ( + ( + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) + and not isinstance(m, torch.nn.Sequential) + ) + or module_qualified_name in self.skipped_module_names + or type(m) in self.skipped_module_classes + or isinstance(m, _FusedModule) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/fx/utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57d6832a83b3175ddd6c619093b21edb6a6acb4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/fx/utils.py @@ -0,0 +1,963 @@ +# mypy: allow-untyped-defs +import copy +import operator +import warnings +from collections import namedtuple +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.ao.quantization import QConfigAny, QuantType +from torch.ao.quantization.backend_config import DTypeWithConstraints +from torch.ao.quantization.fake_quantize import ( + FakeQuantizeBase, + FixedQParamsFakeQuantize, +) +from torch.ao.quantization.observer import ( + _is_activation_post_process, + FixedQParamsObserver, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + float16_dynamic_qconfig, + float16_static_qconfig, + qconfig_equals, +) +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _assert_and_get_unique_device, + activation_is_statically_quantized, +) +from torch.fx import GraphModule, map_arg +from torch.fx.graph import Graph, Node + +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +from .custom_config import PrepareCustomConfig + + +# TODO: revisit this list. 
Many helper methods shouldn't be public +__all__ = [ + "all_node_args_except_first", + "all_node_args_have_no_tensors", + "assert_and_get_unique_device", + "collect_producer_nodes", + "create_getattr_from_value", + "create_node_from_old_node_preserve_meta", + "EMPTY_ARG_DICT", + "get_custom_module_class_keys", + "get_linear_prepack_op_for_dtype", + "get_new_attr_name_with_prefix", + "get_non_observable_arg_indexes_and_types", + "get_qconv_prepack_op", + "get_skipped_module_name_and_classes", + "graph_module_from_producer_nodes", + "maybe_get_next_module", + "NodeInfo", + "node_arg_is_bias", + "node_arg_is_weight", + "NON_OBSERVABLE_ARG_DICT", + "NON_QUANTIZABLE_WEIGHT_OPS", + "return_arg_list", + "ObservedGraphModuleAttrs", +] + +NON_QUANTIZABLE_WEIGHT_OPS = { + torch.nn.functional.layer_norm, + torch.nn.functional.group_norm, + torch.nn.functional.instance_norm, +} + + +@dataclass +class ObservedGraphModuleAttrs: + node_name_to_qconfig: dict[str, QConfigAny] + node_name_to_scope: dict[str, tuple[str, type]] + prepare_custom_config: PrepareCustomConfig + equalization_node_name_to_qconfig: dict[str, Any] + qconfig_mapping: QConfigMapping + is_qat: bool + observed_node_names: set[str] + is_observed_standalone_module: bool = False + standalone_module_input_quantized_idxs: Optional[list[int]] = None + standalone_module_output_quantized_idxs: Optional[list[int]] = None + + +def node_arg_is_weight(node: Node, arg: Any) -> bool: + """Returns if node arg is weight""" + weight_index = None + if "target_dtype_info" in node.meta: + weight_index = node.meta["target_dtype_info"].get("weight_index", None) + if ( + weight_index is not None + and weight_index < len(node.args) + and node.args[weight_index] is arg + ): + return True + return node.kwargs.get("weight") is arg + + +def node_arg_is_bias(node: Node, arg: Any) -> bool: + """Returns if node arg is bias""" + bias_index = None + if "target_dtype_info" in node.meta: + bias_index = node.meta["target_dtype_info"].get("bias_index", None) + if ( + bias_index is not None + and bias_index < len(node.args) + and node.args[bias_index] is arg + ): + return True + return node.kwargs.get("bias") is arg + + +def get_custom_module_class_keys( + custom_module_mapping: dict[QuantType, dict[type, type]], +) -> list[Any]: + r"""Get all the unique custom module keys in the custom config dict + e.g. 
+    Input:
+    {
+        QuantType.STATIC: {
+            CustomModule1: ObservedCustomModule
+        },
+        QuantType.DYNAMIC: {
+            CustomModule2: DynamicObservedCustomModule
+        },
+        QuantType.WEIGHT_ONLY: {
+            CustomModule3: WeightOnlyObservedCustomModule
+        },
+    }
+
+    Output:
+    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
+    [CustomModule1, CustomModule2, CustomModule3]
+    """
+    # using set to dedup
+    float_custom_module_classes: set[Any] = set()
+    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
+        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
+        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
+        float_custom_module_classes |= quant_mode_custom_module_classes
+    return list(float_custom_module_classes)
+
+
+def get_linear_prepack_op_for_dtype(dtype):
+    if dtype == torch.float16:
+        return torch.ops.quantized.linear_prepack_fp16
+    elif dtype == torch.qint8:
+        return torch.ops.quantized.linear_prepack
+    else:
+        raise Exception("can't get linear prepack op for dtype:", dtype)  # noqa: TRY002
+
+
+def get_qconv_prepack_op(conv_op: Callable) -> Callable:
+    prepack_ops = {
+        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
+        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
+        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
+        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
+        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
+        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
+    }
+    prepack_op = prepack_ops.get(conv_op, None)
+    assert prepack_op, f"Didn't find prepack op for {conv_op}"
+    return prepack_op
+
+
+# Returns a function that can get a new attribute name for module with given
+# prefix, for example,
+# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
+# >> new_name = get_new_observer_name(module)
+# new_name will be an unused attribute name on module, e.g. `_observer_1`
+def get_new_attr_name_with_prefix(prefix: str) -> Callable:
+    prefix = prefix.replace(".", "_")
+
+    def get_new_attr_name(module: torch.nn.Module):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+
+    return get_new_attr_name
+
+
+def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
+    r"""Starting from a target node, trace back until we hit an input or
+    getattr node. This is used to extract the chain of operators
+    starting from getattr to the target node, for example
+    def forward(self, x):
+        observed = self.observer(self.weight)
+        return F.linear(x, observed)
+    collect_producer_nodes(observed) will either return a list of nodes that
+    produce the observed node or None if we can't extract a self-contained
+    graph without free variables (inputs of the forward function).
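+    For the example above, the returned list starts at the observer node and walks
+    back to the ``get_attr`` node for ``self.weight``; if a placeholder (a free
+    input of ``forward``) is reached instead, None is returned.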
+ """ + nodes = [node] + frontier = [node] + while frontier: + node = frontier.pop() + all_args = list(node.args) + list(node.kwargs.values()) + for arg in all_args: + if not isinstance(arg, Node): + continue + if arg.op == "placeholder": + # hit input, can't fold in this case + return None + nodes.append(arg) + if not (arg.op == "call_function" and arg.target == getattr): + frontier.append(arg) + return nodes + + +def graph_module_from_producer_nodes( + root: GraphModule, producer_nodes: list[Node] +) -> GraphModule: + r"""Construct a graph module from extracted producer nodes + from `collect_producer_nodes` function + Args: + root: the root module for the original graph + producer_nodes: a list of nodes we use to construct the graph + Return: + A graph module constructed from the producer nodes + """ + assert len(producer_nodes) > 0, "list of producer nodes can not be empty" + # since we traced back from node to getattr + producer_nodes.reverse() + graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node]) + + for producer_node in producer_nodes: + env[producer_node] = graph.node_copy(producer_node, load_arg) + graph.output(load_arg(producer_nodes[-1])) + graph_module = GraphModule(root, graph) + return graph_module + + +# TODO: delete +def assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + return _assert_and_get_unique_device(module) + + +def create_getattr_from_value( + module: torch.nn.Module, graph: Graph, prefix: str, value: Any +) -> Node: + """ + Given a value of any type, creates a getattr node corresponding to the value and + registers the value as a buffer to the module. + """ + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + attr_name = get_new_attr_name(module) + device = assert_and_get_unique_device(module) + new_value = ( + value.detach().clone() + if isinstance(value, torch.Tensor) + else torch.tensor(value, device=device) + ) + module.register_buffer(attr_name, new_value) + # Create get_attr with value + attr_node = graph.create_node("get_attr", attr_name) + return attr_node + + +def all_node_args_have_no_tensors( + node: Node, modules: dict[str, torch.nn.Module], cache: dict[Node, bool] +) -> bool: + """ + If we know for sure that all of this node's args have no + tensors (are primitives), return True. If we either + find a tensor or are not sure, return False. Note: this + function is not exact. 
+ """ + if cache and node in cache: + return cache[node] + + result = False # will be overwritten + if not isinstance(node, Node): + result = True + elif node.op == "placeholder": + result = False + elif node.op == "call_module": + assert isinstance(node.target, str) + if _is_activation_post_process(modules[node.target]): + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == "call_module": + result = False + elif node.op == "call_function" and node.target is operator.getitem: + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == "get_attr": + result = False + elif node.target is getattr and node.args[1] in ["ndim", "shape"]: + # x1 = x0.ndim + result = True + elif node.op == "call_method" and node.target == "size": + # x1 = x0.size(0) + result = True + else: + found_one_tensor = False + for arg in node.args: + if isinstance(arg, list): + for list_el in arg: + if isinstance(list_el, Node): + this_list_el_args_have_no_tensors = ( + all_node_args_have_no_tensors(list_el, modules, cache) + ) + found_one_tensor = found_one_tensor or ( + not this_list_el_args_have_no_tensors + ) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. + if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + elif isinstance(arg, int): + pass + else: + if isinstance(arg, Node): + this_arg_args_have_no_tensors = all_node_args_have_no_tensors( + arg, modules, cache + ) + found_one_tensor = found_one_tensor or ( + not this_arg_args_have_no_tensors + ) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. 
+ if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + else: + found_one_tensor = True + result = not found_one_tensor + if cache: + cache[node] = result + return result + + +def all_node_args_except_first(node: Node) -> list[int]: + """ + Returns all node arg indices after first + """ + return list(range(1, len(node.args))) + + +def return_arg_list(arg_indices: list[int]) -> Callable[[Node], list[int]]: + """ + Constructs a function that takes a node as arg and returns the arg_indices + that are valid for node.args + """ + + def arg_indices_func(node: Node) -> list[int]: + return [i for i in arg_indices if i < len(node.args)] + + return arg_indices_func + + +NodeInfo = namedtuple("NodeInfo", "op target") + +# this dict identifies which indices of a node are non tensors +# so that they can be propagated correctly since inserting observers +# for them would cause errors + +NON_OBSERVABLE_ARG_DICT: dict[ + NodeInfo, dict[Union[type, torch.dtype], Callable[[Node], list[int]]] +] = { + NodeInfo("call_method", "masked_fill"): { + torch.bool: return_arg_list([1]), + float: return_arg_list([2]), + }, + NodeInfo("call_method", "permute"): {int: all_node_args_except_first}, + NodeInfo("call_method", "repeat"): {int: all_node_args_except_first}, + NodeInfo("call_method", "reshape"): {int: all_node_args_except_first}, + NodeInfo("call_method", "size"): {int: return_arg_list([1])}, + NodeInfo("call_method", "transpose"): {int: all_node_args_except_first}, + NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first}, + NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])}, + NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])}, + NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])}, + NodeInfo("call_method", "view"): {int: all_node_args_except_first}, +} + +EMPTY_ARG_DICT: dict[Union[type, torch.dtype], Callable[[Node], list[int]]] = {} + + +def get_non_observable_arg_indexes_and_types( + node: Node, +) -> dict[Union[type, torch.dtype], Callable[[Node], list[int]]]: + """ + Returns a dict with of non float tensor types as keys and values which correspond to a + function to retrieve the list (which takes the node as an argument) + """ + info = NodeInfo(node.op, node.target) + + return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT) + + +def maybe_get_next_module( + node: Node, + modules: dict[str, nn.Module], + target_module_type: Optional[type[nn.Module]] = None, + target_functional_type: Any = None, +) -> Optional[Node]: + """Gets the next module that matches what is needed in + is_target_module_type if it exists + + Args: + node: The node whose users we want to look at + target_module_type: Module type that we want to check + target_functional_type: Functional type that we want to check + """ + + for user in node.users.keys(): + if ( + user.op == "call_module" + and target_module_type is not None + and isinstance(modules[str(user.target)], target_module_type) + ): + return user + elif ( + user.op == "call_function" + and target_functional_type is not None + and user.target == target_functional_type + ): + return user + + return None + + +def create_node_from_old_node_preserve_meta( + quantized_graph: Graph, + create_node_args: tuple[Any, ...], + old_node: Node, +) -> Node: + """ + Creates `new_node` and copies the necessary metadata to it from `old_node`. 
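+    Currently the copied metadata is the stack trace, so a node created while
+    building the quantized graph still points back to the original user code for
+    debugging.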
+ """ + new_node = quantized_graph.create_node(*create_node_args) + new_node.stack_trace = old_node.stack_trace + return new_node + + +def get_skipped_module_name_and_classes( + prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool +) -> tuple[list[str], list[type[Any]]]: + skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names) + skipped_module_classes = copy.copy( + prepare_custom_config.non_traceable_module_classes + ) + if not is_standalone_module: + # standalone module and custom module config are applied in top level module + skipped_module_names += list( + prepare_custom_config.standalone_module_names.keys() + ) + skipped_module_classes += list( + prepare_custom_config.standalone_module_classes.keys() + ) + skipped_module_classes += get_custom_module_class_keys( + prepare_custom_config.float_to_observed_mapping + ) + + return skipped_module_names, skipped_module_classes + + +def _is_custom_module_lstm( + node: Node, + named_modules: dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Optional[Any] = None, +) -> bool: + """ + Return whether this refers to the custom module LSTM flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + assert isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ) # type: ignore[attr-defined] + return ( + isinstance(mod, torch.nn.LSTM) + and activation_is_statically_quantized(qconfig) + and qhandler.is_custom_module() + ) + else: + return isinstance(mod, torch.ao.nn.quantizable.LSTM) + + +def _is_custom_module_mha( + node: Node, + named_modules: dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Optional[Any] = None, +) -> bool: + """ + Return whether this refers to the custom module MultiheadAttention flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + assert isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ) # type: ignore[attr-defined] + return ( + isinstance(mod, torch.nn.MultiheadAttention) + and activation_is_statically_quantized(qconfig) + and qhandler.is_custom_module() + ) + else: + return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention) + + +def _get_module( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> Optional[torch.nn.Module]: + """ + If `node` refers to a call_module node, return the module, else None. + """ + if node.op == "call_module" and str(node.target) in named_modules: + return named_modules[str(node.target)] + else: + return None + + +def _insert_dequant_stub( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Attach a `DeQuantStub` to the model and create a node that calls this + `DeQuantStub` on the output of `node`, similar to how observers are inserted. 
+ """ + prefix = "dequant_stub_" + get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix) + dequant_stub_name = get_new_dequant_stub_name(model) + dequant_stub = DeQuantStub() + setattr(model, dequant_stub_name, dequant_stub) + named_modules[dequant_stub_name] = dequant_stub + with graph.inserting_after(node): + return graph.call_module(dequant_stub_name, (node,)) + + +def _insert_dequant_stubs_for_custom_module_lstm_output( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Insert DeQuantStubs after each internal output node of custom module LSTM. + + Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)), + Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its + components through `getitem`. This function transforms the graph as follows: + + (1) Split the LSTM node into (output, (hidden0, hidden1)) + (2) Insert a DeQuantStub after each internal node + (3) Recombine the DeQuantStubs into the same structure as before + (4) Reroute all consumers of the original LSTM node and its sub-nodes + (e.g. lstm[0]) + + Before: + lstm_output + | + v + original_user(s) + After: + lstm_output + / \\ + / (getitem) \\ + / \\ + v v + output hidden + | / \\ + (DeQuantStub) (getitem) + | / \\ + v v v + output_dq hidden0 hidden1 + | | | + | (DeQuantStub) (DeQuantStub) + | | | + | v v + | hidden0_dq hidden1_dq + | \\ / + | (tuple) + | \\ / + | v v + | hidden_dq + \\ / + \\ (tuple) / + v v + lstm_output_dq + | + v + original_user(s) + + For step (4), reroute all users of the original LSTM node(s) as follows: + lstm_output -> lstm_output_dq + lstm_output[0] -> output_dq + lstm_output[1] -> hidden_dq + lstm_output[1][0] -> hidden0_dq + lstm_output[1][1] -> hidden1_dq + + Return the node `lstm_output_dq`. + """ + # (1) Split the LSTM node into (output, (hidden0, hidden1)) + # (2) Insert a DeQuantStub after each internal node + with graph.inserting_after(node): + output = graph.call_function(operator.getitem, (node, 0)) + output_dq = _insert_dequant_stub(output, model, named_modules, graph) + with graph.inserting_after(output_dq): + hidden = graph.call_function(operator.getitem, (node, 1)) + with graph.inserting_after(hidden): + hidden0 = graph.call_function(operator.getitem, (hidden, 0)) + hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph) + with graph.inserting_after(hidden0_dq): + hidden1 = graph.call_function(operator.getitem, (hidden, 1)) + hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph) + + # (3) Recombine the DeQuantStubs into the same structure as before + with graph.inserting_after(hidden1_dq): + hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],)) + with graph.inserting_after(hidden_dq): + lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],)) + + # (4) Reroute all consumers of the original LSTM node and its sub-nodes + for user in list(node.users.keys()): + if user != output and user != hidden: + user.replace_input_with(node, lstm_output_dq) + # The getitem and tuple nodes we added here may interfere with reference quantized + # pattern matching, so we need to redirect the consumers of internal nodes to the + # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached, + # in order to preserve reference patterns like "dequantize - consumer - quantize". 
+ _reroute_tuple_getitem_pattern(graph) + return lstm_output_dq + + +def _maybe_get_custom_module_lstm_from_node_arg( + arg: Node, + named_modules: dict[str, torch.nn.Module], +) -> Optional[Node]: + """ + Given an argument of a node, if the argument refers to the path through which the node + is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise. + + This is used to determine whether a node is a consumer of custom module LSTM, and, if so, + skip inserting input observers for this node. This is because custom module LSTM produces + quantized outputs, so inserting an input observer for the consumer of custom module LSTM + would unnecessarily quantize the outputs again. + + lstm -> consumer + + In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with + DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + This tuple can be consumed in one of four ways: + + lstm -> getitem -> DeQuantStub -> consumer # consume lstm[0] + lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm[1] + lstm -> getitem -> getitem -> DeQuantStub -> consumer # consume lstm[1][0] or lstm[1][1] + lstm -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm + + Thus, we must match against the above patterns instead of simply checking the parent node + to determine whether this node is a consumer of a custom module LSTM. + """ + + def match_dq(a): + return isinstance(_get_module(a, named_modules), DeQuantStub) + + def match_lstm(a): + return _is_custom_module_lstm(a, named_modules) + + def match_getitem(a): + return a.op == "call_function" and a.target == operator.getitem + + def match_tuple(a): + return a.op == "call_function" and a.target == tuple + + def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]: + """ + Traverse up the graph and match the args one by one. + If there is a match, return the last matched node, or None otherwise. + """ + a = arg + for i, match in enumerate(match_pattern): + if not match(a): + return None + # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],) + if i < len(match_pattern) - 1: + if match == match_tuple: + a = a.args[0][0] # type: ignore[assignment,index] + else: + a = a.args[0] # type: ignore[assignment] + return a + + all_match_patterns = [ + [match_dq, match_getitem, match_lstm], + [match_tuple, match_dq, match_getitem, match_getitem, match_lstm], + [match_dq, match_getitem, match_getitem, match_lstm], + [match_tuple, match_dq, match_getitem, match_lstm], + ] + + for p in all_match_patterns: + matched_node = _match_pattern(p) + if matched_node is not None: + return matched_node + return None + + +def _reroute_tuple_getitem_pattern(graph: Graph): + """ + Search for patterns where N consecutive `tuple` call_function nodes are followed by + N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes. + If we find this pattern, reroute the consumers of the last `getitem` to skip these + N `tuple` and `getitem` nodes. + + Before: + + a b c + | \\ / + \\ tuple + \\ / + tuple + | + getitem(1) + | + getitem(0) + | + d + + After: + + b + | + d + """ + + def find_patterns( + node: Node, + index_stack: list[int], + current_pattern: list[Node], + matched_patterns: list[list[Node]], + seen: set[tuple[Node, tuple[int, ...]]], + ): + """ + Traverse the graph recursively to match for the N-tuple - N-getitem patterns, + starting at the given node. 
+ + We use a stack to keep track of the expected `getitem` indices, since these are + reversed from the `tuple` indices. In the above example, the stack after + (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first + and then by getitem(0). + + TODO: traverse upwards from the output and handle the case when tuple is not a + separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c))) + """ + if len(index_stack) == 0 and len(current_pattern) > 0: + matched_patterns.append(copy.copy(current_pattern)) + current_pattern.clear() + + # Avoid duplicating work + state = (node, tuple(index_stack)) + if state in seen: + return + seen.add(state) + + # Iterate through users of this node to find tuple/getitem nodes to match + for user in node.users: + if user.op == "call_function" and user.target == tuple: + for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type] + if user_arg == node: + index_stack.append(i) + current_pattern.append(user) + find_patterns( + user, index_stack, current_pattern, matched_patterns, seen + ) + elif user.op == "call_function" and user.target == operator.getitem: + if len(index_stack) > 0: + if user.args[1] == index_stack[-1]: + index_stack.pop() + current_pattern.append(user) + find_patterns( + user, index_stack, current_pattern, matched_patterns, seen + ) + return matched_patterns + + # Collect all matched patterns + matched_patterns: list[list[Node]] = [] + seen: set[tuple[Node, tuple[int, ...]]] = set() # (node, index_stack) + for node in graph.nodes: + find_patterns(node, [], [], matched_patterns, seen) + + # For each pattern, redirect all consumers of the last getitem node to the correct input + # of the first tuple node + for pattern in matched_patterns: + first_tuple = pattern[0] + last_getitem = pattern[-1] + assert first_tuple.op == "call_function" and first_tuple.target == tuple + assert ( + last_getitem.op == "call_function" + and last_getitem.target == operator.getitem + ) + last_getitem_index = last_getitem.args[1] + new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index] + for user in list(last_getitem.users.keys()): + user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type] + + +def _get_observer_from_activation_post_process( + activation_post_process: Union[ObserverBase, FakeQuantizeBase], +) -> ObserverBase: + """ + If `activation_post_process` is an observer, return the observer. + If `activation_post_process` is a fake quantize, return the internal observer. + """ + if isinstance(activation_post_process, ObserverBase): + return activation_post_process + else: + assert isinstance(activation_post_process, FakeQuantizeBase) + return activation_post_process.activation_post_process # type: ignore[return-value] + + +def _qconfig_satisfies_dtype_config_constraints( + qconfig: QConfigAny, + dtype_with_constraints: DTypeWithConstraints, + is_activation: bool = True, +) -> bool: + """ + Return whether `qconfig` satisfies the following constraints from the backend, + specified through the activation and weight DTypeWithConstraints. + + 1. QConfig specified a quantization range that falls within the backend's, if any + 2. QConfig specified a min scale value that is >= the backend's, if any + 3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has + scale and zero point that match the backend's, if any + + If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`. 
+ If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True. + """ + + # TODO: log warnings only when the user enabled a debug flag + def _activation_post_process_satisfies_dtype_config_constraints( + activation_post_process: Union[ObserverBase, FakeQuantizeBase], + dtype_with_constraints: DTypeWithConstraints, + debug_string: str, + ) -> bool: + observer = _get_observer_from_activation_post_process(activation_post_process) + app_quant_min = getattr(observer, "quant_min", None) + app_quant_max = getattr(observer, "quant_max", None) + # TODO: for now, just use the existing eps value as scale_min. In the future, we should + # resolve the differences between the two, either by renaming eps or some other way + app_scale_min = getattr(observer, "eps", None) + backend_quant_min = dtype_with_constraints.quant_min_lower_bound + backend_quant_max = dtype_with_constraints.quant_max_upper_bound + backend_scale_min = dtype_with_constraints.scale_min_lower_bound + backend_scale_exact_match = dtype_with_constraints.scale_exact_match + backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match + # check quantization ranges + if backend_quant_min is not None and backend_quant_max is not None: + if app_quant_min is None or app_quant_max is None: + warnings.warn( + f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}" + ) + return False + elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max: + warnings.warn( + f"QConfig {debug_string} quantization range must fall within the backend's:\n" + f"QConfig range = ({app_quant_min}, {app_quant_max}), " + f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), " + f"ignoring {qconfig}" + ) + return False + # check scale min + if backend_scale_min is not None: + if app_scale_min is None: + warnings.warn( + f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}" + ) + return False + if app_scale_min < backend_scale_min: + warnings.warn( + f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to " + f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}" + ) + return False + # check fixed scale and zero point + if ( + backend_scale_exact_match is not None + and backend_zero_point_exact_match is not None + ): + # For tests only, accept the following qconfigs for now + # TODO: handle fp16 qconfigs properly + for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]: + if qconfig_equals(qconfig, accepted_qconfig): + return True + suggestion_str = ( + "Please use torch.ao.quantization.get_default_qconfig_mapping or " + "torch.ao.quantization.get_default_qat_qconfig_mapping. 
Example:\n" + ' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n' + " model = prepare_fx(model, qconfig_mapping, example_inputs)" + ) + if not isinstance( + activation_post_process, FixedQParamsObserver + ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize): + warnings.warn( + f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize " + f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}" + ) + return False + if ( + observer.scale != backend_scale_exact_match + or observer.zero_point != backend_zero_point_exact_match + ): + warnings.warn( + f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) " + f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), " + f"ignoring {qconfig}.\n{suggestion_str}" + ) + return False + return True + + if qconfig is None or dtype_with_constraints.dtype is None: + return True + + activation_post_process_ctr = ( + qconfig.activation if is_activation else qconfig.weight + ) + debug_string = "activation" if is_activation else "weight" + satisfies_constraints = True + if activation_post_process_ctr is not None: + activation_post_process = activation_post_process_ctr() + assert _is_activation_post_process(activation_post_process) + # If dtypes don't match, don't check the activation_post_process and return True early + if activation_post_process.dtype != dtype_with_constraints.dtype: + return True + satisfies_constraints = ( + _activation_post_process_satisfies_dtype_config_constraints( + activation_post_process, dtype_with_constraints, debug_string + ) + ) + return satisfies_constraints diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/observer.py b/phivenv/Lib/site-packages/torch/ao/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..133df5ac2843d91792ed8d8c369631a24847d291 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/observer.py @@ -0,0 +1,2131 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# temporarily skip RUF for this file for now, we can re-enable +# after move the affine quantization related things to torchao +# noqa: RUF +""" +This module implements observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). 
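+
+In a ``QConfig`` these observers are normally used through factories built with
+``with_args`` (see ``_with_args`` below), e.g. ``MinMaxObserver.with_args(dtype=torch.qint8)``;
+each call to such a factory creates a fresh observer instance.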
+""" + +import operator +import re +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from functools import partial +from typing import Any, Optional + +import torch +import torch.nn as nn +from torch.ao.quantization.utils import ( + calculate_qmin_qmax, + check_min_max_valid, + is_per_channel, + is_per_tensor, + validate_qmin_qmax, +) +from torch.fx import Node + + +__all__ = [ + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_quant_observer", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_observer", + "default_weight_observer", + "get_observer_state_dict", + "load_observer_state_dict", + "per_channel_weight_observer_range_neg_127_to_127", + "weight_observer_range_neg_127_to_127", + "FixedQParamsObserver", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +class _PartialWrapper: + def __init__(self, p): + self.p = p + self.callable_args = {} + + def __call__(self, *args, **keywords): + # call each arg in callable_args and add them partial, then run with keywords + # skip if arg_name in keywords so its possible to overwrite + for arg_name in self.callable_args: + if arg_name not in keywords: + keywords = {**keywords, arg_name: self.callable_args[arg_name]()} + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + self.callable_args.__repr__() + + def with_args(self, **kwargs): + return _with_args(self, **kwargs) + + def with_callable_args(self, **kwargs): + result = _PartialWrapper(p=self.p) + result.callable_args = {**self.callable_args, **kwargs} + return result + + +def _with_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. Can be used in conjunction with + _callable_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, **kwargs)) + return r + + +def _with_callable_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories args that need to be + called at construction time. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances and those arguments should only + be calculated at construction time. 
Can be used in conjunction with _with_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_callable_args = classmethod(_with_callable_args) + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") + >>> foo_instance1 = foo_builder() + >>> # wait 50 + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) + False + """ + r = _PartialWrapper(partial(cls_or_self)) + return r.with_callable_args(**kwargs) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + + +class ObserverBase(ABC, nn.Module): + r"""Base observer Module. + Any observer implementation should derive from this class. + + Concrete observers should follow the same API. In forward, they will update + the statistics of the observed Tensor. And they should provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization + or static quantization + """ + + def __init__(self, dtype, is_dynamic: bool = False): + super().__init__() + self.dtype = dtype + self.is_dynamic = is_dynamic + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + with_args = classmethod(_with_args) + with_callable_args = classmethod(_with_callable_args) + + +class UniformQuantizationObserverBase(ObserverBase): + r"""Common base for all observers using uniform quantization to calculate + scale and zero_point. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + .. warning:: + + :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + or `torch.int8` or `torch.uint8` + + .. warning:: + + :attr:`qscheme` can only take one of the following options: + + - ``torch.per_tensor_affine`` + - ``torch.per_tensor_symmetric`` + - ``torch.per_channel_affine`` + - ``torch.per_channel_symmetric`` + """ + + # Note: the version is shared by all observer types + # + # Version 1/None + # self + # + # Version 2 (base class only, does not include child class buffers) + # self + # |--- eps : Tensor + # + # Version 3 + # for HistogramObserver only, changed the shape of uninitialized + # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) + # for PerChannelObservers, changed the name of the buffers from min_vals + # to min_val and from max_vals to max_val. 
+    _version = 3
+
+    eps: torch.Tensor
+
+    def __init__(
+        self,
+        dtype=torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
+        self.qscheme = qscheme
+        if reduce_range:
+            warnings.warn(
+                "Please use quant_min and quant_max to specify the range for observers. \
+                    reduce_range will be deprecated in a future release of PyTorch."
+            )
+        self.reduce_range = reduce_range
+        self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
+        assert self.qscheme in (
+            torch.per_tensor_affine,
+            torch.per_tensor_symmetric,
+            torch.per_channel_affine,
+            torch.per_channel_symmetric,
+            torch.per_channel_affine_float_qparams,
+        ), (
+            "Default Observer only works for per_tensor_affine, \
+                per_tensor_symmetric, per_channel_affine, \
+                per_channel_symmetric and per_channel_float_qparams quantization scheme"
+        )
+
+        _ALLOWED_DTYPES = (
+            torch.qint8,
+            torch.quint8,
+            torch.quint4x2,
+            torch.qint32,
+            torch.int8,
+            torch.uint8,
+            torch.int16,
+            torch.int32,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+            torch.uint16,
+        )
+
+        assert self.dtype in _ALLOWED_DTYPES, (
+            f"Default Observer only works for {_ALLOWED_DTYPES} data type"
+        )
+        self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
+        if self.has_customized_qrange:
+            validate_qmin_qmax(quant_min, quant_max)
+        self.quant_min, self.quant_max = calculate_qmin_qmax(
+            quant_min,
+            quant_max,
+            self.has_customized_qrange,
+            self.dtype,
+            self.reduce_range,
+        )
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version == 1:
+            # eps was moved to a buffer in version 2
+            eps = torch.tensor([torch.finfo(torch.float32).eps])
+            state_dict[prefix + "eps"] = eps
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    @torch.jit.export
+    def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
+        r"""Validates that the user-specified quantization range is properly initialized
+        and within the given bound supported by the observer dtype.
+
+        To accommodate lower-bit quantization with respect to the existing torch.qint8 and
+        torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
+        in a tuple of initial qmin and qmax values. One use case is that these customized qmin and qmax
+        values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
+        fake quantization. These estimates are compared against parameters learned through backpropagation.
+        The related literature for scale and zero point via backpropagation is as follows:
+
+        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
+        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
+        """
+        # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
+        # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
+        assert quant_min <= 0 <= quant_max, (
+            "User-specified quantization range must include 0."
+        )
+        assert quant_min < quant_max, (
+            "qmin must be strictly less than qmax for user-specified quantization range."
+        )
+
+    @torch.jit.export
+    def _calculate_qparams(
+        self, min_val: torch.Tensor, max_val: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        r"""Calculates the quantization parameters, given min and max
+        value tensors. Works for both per tensor and per channel cases.
+
+        Args:
+            min_val: Minimum values per channel
+            max_val: Maximum values per channel
+
+        Returns:
+            scales: Scales tensor of shape (#channels,)
+            zero_points: Zero points tensor of shape (#channels,)
+        """
+        # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable, however, and qscheme
+        # as far as I can tell is not allowed to be passed as a parameter in torchscript functions. This makes refactoring observer
+        # to use this utility a massive pain and very gross. For now I'm opting just to duplicate as this code
+        # seems unlikely to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor.
+        # TODO(jakeszwe, jerryzh168)
+        if not check_min_max_valid(min_val, max_val):
+            return torch.tensor([1.0], device=min_val.device.type), torch.tensor(
+                [0], device=min_val.device.type
+            )
+
+        quant_min, quant_max = self.quant_min, self.quant_max
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+        device = min_val_neg.device
+        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
+        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+        if (
+            self.qscheme == torch.per_tensor_symmetric
+            or self.qscheme == torch.per_channel_symmetric
+        ):
+            max_val_pos = torch.max(-min_val_neg, max_val_pos)
+            scale = max_val_pos / (float(quant_max - quant_min) / 2)
+            scale = torch.max(scale, self.eps)
+            if self.dtype in [torch.quint8, torch.uint8]:
+                if self.has_customized_qrange:
+                    # When customized quantization range is used, down-rounded midpoint of the range is chosen.
+                    zero_point = zero_point.new_full(
+                        zero_point.size(), (quant_min + quant_max) // 2
+                    )
+                else:
+                    zero_point = zero_point.new_full(zero_point.size(), 128)
+            elif self.dtype in [torch.uint16]:
+                zero_point = zero_point.new_full(zero_point.size(), 2**15)
+        elif self.qscheme == torch.per_channel_affine_float_qparams:
+            scale = (max_val - min_val) / float(quant_max - quant_min)
+            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
+            # We use the quantize function
+            # xq = Round(Xf * inv_scale + zero_point),
+            # setting zero_point to (-1 * min * inv_scale) we get
+            # Xq = Round((Xf - min) * inv_scale)
+            zero_point = -1 * min_val / scale
+        else:
+            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+            scale = torch.max(scale, self.eps)
+            zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+
+        # For scalar values, cast them to Tensors of size 1 to keep the shape
+        # consistent with default values in FakeQuantize.
+ if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if self.qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale, zero_point + + @torch.jit.export + def reset_min_max_vals(self): + raise NotImplementedError("Cannot reset min/max values in the given observer.") + + +# Originally, this class was called `_ObserverBase`. Keeping the old name around +# for backwards compatibility. +# TODO(after v1.13): delete this +_ObserverBase = UniformQuantizationObserverBase + + +class MinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running min and max values. + + This observer uses the tensor min/max statistics to compute the quantization + parameters. The module records the running minimum and maximum of incoming + tensors, and uses this statistic to compute the quantization parameters. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, + scale :math:`s` and zero point :math:`z` are computed as: + + The running minimum/maximum :math:`x_\text{min/max}` is computed as: + + .. math:: + + \begin{array}{ll} + x_\text{min} &= \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} + \end{cases}\\ + x_\text{max} &= \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`X` is the observed tensor. + + The scale :math:`s` and zero point :math:`z` are then computed as: + + .. math:: + + \begin{aligned} + \text{if Symmetric:}&\\ + &s = 2 \max(|x_\text{min}|, x_\text{max}) / + \left( Q_\text{max} - Q_\text{min} \right) \\ + &z = \begin{cases} + 0 & \text{if dtype is qint8} \\ + 128 & \text{otherwise} + \end{cases}\\ + \text{Otherwise:}&\\ + &s = \left( x_\text{max} - x_\text{min} \right ) / + \left( Q_\text{max} - Q_\text{min} \right ) \\ + &z = Q_\text{min} - \text{round}(x_\text{min} / s) + \end{aligned} + + where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and + maximum of the quantized data type. + + .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. 
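+
+    Example (a minimal sketch using the default ``quint8`` / ``per_tensor_affine``
+    settings; the numbers shown are approximate)::
+
+        >>> # xdoctest: +SKIP("illustrative only")
+        >>> obs = MinMaxObserver()
+        >>> _ = obs(torch.tensor([-1.0, 0.0, 1.0]))
+        >>> scale, zero_point = obs.calculate_qparams()
+        >>> # scale is roughly 2 / 255 and zero_point is 128 for this input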
+ """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but + # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it + # supports dynamic quantization, we may need to better error checking here + + # For x86 quantized kernels, we need to ensure that the vpmaddubsw + # instruction does not overflow. We allow for a reduce_range argument to + # observers that reduces the quantized range to (0,127) or (-64, 63). + # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp + # This is not an optimal choice for non x86 backends as it loses a bit + # of precision for activations. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + if ( + self.qscheme == torch.per_tensor_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric \ + quantization for quint8" + ) + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + r"""Calculates the quantization parameters.""" + return self._calculate_qparams(self.min_val, self.max_val) + + @torch.jit.export + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float("inf"))) + self.max_val.copy_(torch.tensor(float("-inf"))) + + +class MovingAverageMinMaxObserver(MinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + moving average of the min and max values. + + This observer computes the quantization parameters based on the moving + averages of minimums and maximums of the incoming tensors. The module + records the average minimum and maximum of incoming tensors, and uses this + statistic to compute the quantization parameters. + + Args: + averaging_constant: Averaging constant for min/max. + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. 
If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The moving average min/max is computed as follows + + .. math:: + + \begin{array}{ll} + x_\text{min} = \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + (1 - c) x_\text{min} + c \min(X) & \text{otherwise} + \end{cases}\\ + x_\text{max} = \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + (1 - c) x_\text{max} + c \max(X) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is + is the incoming tensor, and :math:`c` is the ``averaging_constant``. + + The scale and zero point are then computed as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`. + + .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + f"MovingAverageMinMaxObserver's qscheme only support \ + torch.per_tensor_symmetric and torch.per_tensor_affine. \ + but got: {qscheme}" + ) + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + if min_val == float("inf") and max_val == float("-inf"): + min_val, max_val = torch.aminmax(x) + else: + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class PerChannelMinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + ch_axis: Channel axis + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 
+ + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference + that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "PerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "PerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.ch_axis = ch_axis + self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) + self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) + if ( + self.qscheme == torch.per_channel_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) + + def forward(self, x_orig): + return self._forward(x_orig) + + def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self._calculate_qparams(self.min_val, self.max_val) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + def _load_from_state_dict( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + version = local_metadata.get("version", None) + if version is not None and version < 3: + local_state = ["min_vals", "max_vals"] + expected_min_name = "min_vals" + expected_max_name = "max_vals" + else: + local_state = ["min_val", "max_val"] + expected_min_name = "min_val" + expected_max_name = "max_val" + for name in local_state: + key = prefix + name + if key in state_dict: + val = 
state_dict[key] + # Custom handling to allow loading min_val or max_val + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == expected_min_name: + self.min_val.resize_(val.shape) + elif name == expected_max_name: + self.max_val.resize_(val.shape) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}" + ) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == expected_min_name: + self.min_val.copy_(val) + elif name == expected_max_name: + self.max_val.copy_(val) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}" + ) + elif strict: + missing_keys.append(key) + + if not torch.jit.is_scripting(): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _load_from_state_dict_script( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + self._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + # This used to be torch.ones but that does not work because + # JIT compiler can optimize it via common subexpression elimination + # in which case both min_val and max_val point to the same tensor. + self.min_val = torch.rand( + 0, + ) + self.max_val = torch.rand( + 0, + ) + + +class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + averaging_constant: Averaging constant for min/max. + ch_axis: Channel axis + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the + difference that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. 
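A short sketch of the per-channel variant; it assumes the same import path and shows that one (scale, zero_point) pair is produced per slice along `ch_axis`:

```python
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver

weight = torch.randn(16, 3, 3, 3)  # e.g. a conv weight with 16 output channels
obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                               qscheme=torch.per_channel_symmetric)
obs(weight)
scale, zero_point = obs.calculate_qparams()
print(scale.shape, zero_point.shape)  # torch.Size([16]) torch.Size([16])
```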
+ """ + + def __init__( + self, + averaging_constant=0.01, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + self.averaging_constant = averaging_constant + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class HistogramObserver(UniformQuantizationObserverBase): + r""" + The module records the running histogram of tensor values along with + min/max values. ``calculate_qparams`` will calculate scale and zero_point. + + Args: + bins: Number of bins to use for the histogram + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The scale and zero point are computed as follows: + + 1. Create the histogram of the incoming inputs. + The histogram is computed continuously, and the ranges per bin change + with every new tensor observed. + 2. Search the distribution in the histogram for optimal min/max values. + The search for the min/max values ensures the minimization of the + quantization error with respect to the floating point model. + 3. Compute the scale and zero point the same way as in the + :class:`~torch.ao.quantization.MinMaxObserver` + """ + + histogram: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + bins: int = 2048, + dtype: torch.dtype = torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "HistogramObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + if is_dynamic: + raise NotImplementedError( + "HistogramObserver doesn't support dynamic quantization" + ) + # bins: The number of bins used for histogram calculation. 
+ super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.bins = bins + self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits + self.upsample_rate = ( + 16 # used to reduce quantization errors when upscaling histogram + ) + + def _get_norm( + self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor + ) -> torch.Tensor: + r""" + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + Currently only L2 norm is supported. + + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + norm = ( + delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): + r""" + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + bin_width = (self.max_val.item() - self.min_val.item()) / self.bins + + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins + if dst_bin_width == 0.0: + return 0.0 + + src_bin = torch.arange(self.bins, device=self.histogram.device) + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? + dst_bin_of_begin = torch.clamp( + torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width + + dst_bin_of_end = torch.clamp( + torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + density = self.histogram / bin_width + + norm = torch.zeros(self.bins, device=self.histogram.device) + + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm += self._get_norm( + delta_begin, + torch.ones(self.bins, device=self.histogram.device) * delta_end, + density, + ) + + norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( + torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density + ) + + dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) + + return norm.sum().item() + + def _non_linear_param_search(self) -> tuple[torch.Tensor, torch.Tensor]: + r"""Non-linear parameter search. + + An approximation for L2 error minimization for selecting min/max. + By selecting new min/max, we filter out outliers in input distribution. 
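The closed form used by `_get_norm` above is just the integral of `density * x**2` over a bin; a quick numerical sanity check of that identity, with arbitrarily chosen values:

```python
import torch

density, begin, end = 2.5, -0.3, 0.7
closed_form = density * (end**3 - begin**3) / 3         # what _get_norm computes

xs = torch.linspace(begin, end, 100001)
numeric = torch.trapezoid(density * xs**2, xs).item()   # integral of density * x^2
print(closed_form, numeric)                             # both ~0.308
```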
+ This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in + caffe2/quantization/server/norm_minimization.cc + """ + assert self.histogram.size()[0] == self.bins, "bins mismatch" + bin_width = (self.max_val - self.min_val) / self.bins + + # cumulative sum + total = torch.sum(self.histogram).item() + cSum = torch.cumsum(self.histogram, dim=0) + + stepsize = 1e-5 # granularity + alpha = 0.0 # lower bound + beta = 1.0 # upper bound + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float("inf") + + while alpha < beta: + # Find the next step + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + l = start_bin + r = end_bin + while l < end_bin and cSum[l] < next_alpha * total: + l = l + 1 + while r > start_bin and cSum[r] > next_beta * total: + r = r - 1 + + # decide the next move + next_start_bin = start_bin + next_end_bin = end_bin + if (l - start_bin) > (end_bin - r): + # move the start bin + next_start_bin = l + alpha = next_alpha + else: + # move the end bin + next_end_bin = r + beta = next_beta + + if next_start_bin == start_bin and next_end_bin == end_bin: + continue + + # calculate the quantization error using next_start_bin and next_end_bin + norm = self._compute_quantization_error(next_start_bin, next_end_bin) + + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin + + new_min = self.min_val + bin_width * start_bin + new_max = self.min_val + bin_width * (end_bin + 1) + return new_min, new_max + + def _upscale_histogram( + self, + histogram: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ): + # this turns the histogram into a more fine-coarsed histogram to reduce + # bin quantization errors + histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate + bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate) + mid_points_histogram = ( + torch.linspace( + orig_min, + orig_max, + self.bins * self.upsample_rate + 1, + device=orig_min.device, + )[:-1].to(histogram.device) + + 0.5 * bin_size + ) + boundaries_new_histogram = torch.linspace( + update_min, update_max, self.bins + 1, device=update_min.device + ).to(histogram.device) + # this maps the mid-poits of the histogram to the new histogram's space + bucket_assignments = ( + torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True) + - 1 + ) + # this then maps the histogram mid-points in the new space, weighted by the original histogram's values + # this is just the old histogram in the new histogram's space + + # In case due to numerical issues the values land higher/lower than the maximum/minimum + bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1 + bucket_assignments[bucket_assignments < 0] = 0 + + update_histogram = torch.bincount( + bucket_assignments, weights=histogram, minlength=self.bins + ) + return update_histogram + + def _combine_histograms( + self, + orig_hist: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_hist: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ) -> torch.Tensor: + # If the new min and max are the same as the current min and max, + # we can just add the new histogram to the original histogram + if update_min == orig_min and update_max == orig_max: + return orig_hist + update_hist + + # If the orig hist only has one value (i.e., the min and max are the same) + # 
we can just add it into new histogram + if orig_min == orig_max: + bin_value = torch.sum(update_hist) + transformed_orig_hist = ( + torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] + * bin_value + ) + return transformed_orig_hist + update_hist + + # We assume the update_hist is already in the target range, we will map the orig_max to it + assert update_min <= orig_min + assert update_max >= orig_max + + # Now we need to turn the old_histogram, into the range of the new histogram + transformed_orig_hist = self._upscale_histogram( + orig_hist, + orig_min, + orig_max, + update_min, + update_max, + ) + + return update_hist + transformed_orig_hist + + def reset_histogram( + self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor + ) -> None: + self.min_val.resize_(min_val.shape) + self.min_val.copy_(min_val) + self.max_val.resize_(max_val.shape) + self.max_val.copy_(max_val) + assert min_val.numel() == 1 and max_val.numel() == 1, ( + "histogram min/max values must be scalar." + ) + new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] + self.histogram.detach_().resize_(new_histogram.shape) + self.histogram.copy_(new_histogram) + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() + x_min, x_max = torch.aminmax(x) + # want to ignore torch.inf since we don't actually + # want to make our quantization range infinite + # and in practice those values will be clamped + if x_min == -torch.inf or x_max == torch.inf: + warnings.warn("torch.inf detected in input tensor, ignoring input") + x = x[x.abs() != torch.inf] + if x.numel() == 0: + return x_orig + x_min, x_max = torch.aminmax(x) + + current_min = self.min_val + current_max = self.max_val + + is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf") + if is_uninitialized: + self.reset_histogram(x, x_min, x_max) + else: + update_min, update_max = x_min, x_max + new_min = torch.min(current_min, update_min) + new_max = torch.max(current_max, update_max) + + # TODO: For some reason, this is required for it to pass torchscript test + # new_min and new_max should already have requires_grad set to False + new_min, new_max = new_min.detach(), new_max.detach() + update_histogram = torch.histc( + x, + self.bins, + min=new_min, # type: ignore[arg-type] + max=new_max, # type: ignore[arg-type] + ).to(self.histogram.device) + if new_min == current_min and new_max == current_max: + combined_histogram = self.histogram + update_histogram + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + else: + combined_histogram = self._combine_histograms( + self.histogram, + current_min, + current_max, + update_histogram, + new_min, + new_max, + ) + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + self.min_val.detach_().resize_(new_min.shape) + self.min_val.copy_(new_min) + self.max_val.detach_().resize_(new_max.shape) + self.max_val.copy_(new_max) + + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + is_uninitialized = self.min_val == float("inf") and self.max_val == float( + "-inf" + ) + if is_uninitialized: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) + return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor( + [0], 
device=self.min_val.device.type + ) + assert self.bins == len(self.histogram), ( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) + + new_min, new_max = self._non_linear_param_search() + + return self._calculate_qparams(new_min, new_max) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "min_val"] = self.min_val + destination[prefix + "max_val"] = self.max_val + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 3: + # if min_val and max_val are not initialized, update their shape + # to account for the differences between v2 and v3 + min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" + if min_val_name in state_dict: + if state_dict[min_val_name].shape == torch.Size([0]): + state_dict[min_val_name] = torch.tensor(float("inf")) + if max_val_name in state_dict: + if state_dict[max_val_name].shape == torch.Size([0]): + state_dict[max_val_name] = torch.tensor(float("-inf")) + + local_state = ["min_val", "max_val"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + setattr(self, name, val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + +class FixedQParamsObserver(ObserverBase): + r""" + Observer that simulates quantize and dequantize with fixed + quantization parameters in training time. Only per tensor + quantization is supported. + + Args: + `scale` (float): fixed scale for the observer + `zero_point` (int): fixed zero point for the observer + `dtype`, `qscheme`, `quant_min`, `quant_max` + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + scale, + zero_point, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=0, + quant_max=255, + is_dynamic=False, + **kwargs, + ): + if is_dynamic: + raise NotImplementedError( + "FixedQParamsObserver doesn't support dynamic quantization" + ) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.quant_min = quant_min + self.quant_max = quant_max + self.register_buffer("scale", torch.tensor([scale], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int)) + self.dtype = dtype + self.qscheme = qscheme + + def forward(self, X): + return X + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.scale, self.zero_point + + +class PlaceholderObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Can be used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_max: maximum value in quantized domain + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). 
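For `FixedQParamsObserver`, the statistics never change; the sketch below pins the qparams typically used for a sigmoid-style output (range [0, 1) in quint8), matching the `default_fixed_qparams_range_0to1_observer` defined later in this file:

```python
import torch
from torch.ao.quantization.observer import FixedQParamsObserver

obs = FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0,
                           dtype=torch.quint8, quant_min=0, quant_max=255)
obs(torch.rand(8))                  # forward is a pass-through; nothing is recorded
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)            # tensor([0.0039]) tensor([0], dtype=torch.int32)
```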
+ compute_dtype (deprecated): if set, marks the future quantize function to use + dynamic quantization instead of static quantization. + This field is deprecated, use `is_dynamic=True` instead. + is_dynamic: if True, the `quantize` function in the reference model + representation taking stats from this observer instance will + use dynamic quantization. + """ + + def __init__( + self, + dtype=torch.float32, + custom_op_name="", + compute_dtype=None, + quant_min=None, + quant_max=None, + qscheme=None, + eps=None, + is_dynamic=False, + ) -> None: + super().__init__(dtype=dtype, is_dynamic=is_dynamic) + if qscheme is None: + qscheme = torch.per_tensor_affine + if eps is None: + eps = torch.finfo(torch.float32).eps + + # dtype of input of the target operator, e.g. for dynamic quantization + # ops, the dtype will be float32 + self.dtype = dtype + self.qscheme = qscheme + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.custom_op = custom_op_name + # used for configuration of computation type for dynamic quantization + if compute_dtype: + is_dynamic = True + warnings.warn( + "Please use `is_dynamic` instead of `compute_dtype`. \ + `compute_dtype` will be deprecated in a future release \ + of PyTorch." + ) + + def forward(self, x): + return x + + @torch.jit.export + def extra_repr(self): + return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) + + +class RecordingObserver(ObserverBase): + r""" + The module is mainly for debug and records the tensor values during runtime. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + """ + + __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]} + + def __init__(self, dtype=torch.quint8): + super().__init__(dtype=dtype, is_dynamic=False) + self.tensor_val = [] + + def forward(self, x): + self.tensor_val.append(x.clone()) + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for RecordingObserver" + ) + + @torch.jit.export + def get_tensor_value(self): + return self.tensor_val + + +class NoopObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Primarily used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: Quantized data type + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + """ + + def __init__(self, dtype=torch.float16, custom_op_name="") -> None: + super().__init__(dtype=dtype, is_dynamic=False) + self.dtype = dtype + self.custom_op = custom_op_name + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for NoopObserver" + ) + + +class ReuseInputObserver(ObserverBase): + r"""This observer is used when we want to reuse the observer from the operator + that produces the input Tensor, typically used for operators like reshape, e.g. + ``` + x0 = ... 
+ x1 = x0.reshape() + ``` + if we configure x0 to be observed by some observer, let's say MinMaxObserver, + and reshape is configured with ReuseInputObserver, we'll reuse the observer instance + for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1. + + Note: this is only enabled in FX Graph Mode Quantization + """ + + def __init__(self) -> None: + super().__init__(torch.quint8, is_dynamic=False) + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ReuseInputObserver" + ) + + +""" +# Experimental Affine Quantization Feature START +We plan to merge the following with torchao repo after we move pt2e flow to torchao +copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +""" +from dataclasses import dataclass +from enum import auto, Enum + + +class MappingType(Enum): + """How floating point number is mapped to integer number + + symmetric mapping means floating point range is symmetrically mapped to integer range + let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) + we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) + e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) + + SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin + and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating + smin and smax individually, there can be less round error on negative values, and no out-of-range + of all floating point values. + + asymmetric mapping means we just directly map the floating point range to integer range, + for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter + based on this mapping + e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) + """ + + SYMMETRIC = auto() + SYMMETRIC_NO_CLIPPING_ERR = auto() + ASYMMETRIC = auto() + + +class ZeroPointDomain(Enum): + """Enum that indicate whether zero_point is in integer domain or floating point domain + + integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) + float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + none domain: quantized_val = (float_val / scale) + """ + + INT = auto() + FLOAT = auto() + NONE = auto() + + +class TorchAODType(Enum): + """ + Placeholder for dtypes that do not exist in PyTorch core yet. + """ + + # torch.int1 to torch.int7 will be added to PyTorch 2.6 + # These will remain here for BC with older PyTorch versions + INT1 = auto() + INT2 = auto() + INT3 = auto() + INT4 = auto() + INT5 = auto() + INT6 = auto() + INT7 = auto() + + +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. + """ + + +@dataclass(frozen=True) +class PerBlock(Granularity): + """ + Represents per-block granularity in quantization. See + :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for + `block_size` + + Attributes: + block_size (Tuple[int, ...]): The size of each quantization group + """ + + block_size: tuple[int, ...] + + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. 
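Working the `MappingType` docstring numbers through by hand (float range (-3.5, 10.2), int4 range (-8, 7)) makes the symmetric/asymmetric difference concrete:

```python
min_val, max_val = -3.5, 10.2
quant_min, quant_max = -8, 7

# SYMMETRIC: widen to (-10.2, 10.2) so that 0.0 maps to the integer mid-point
sym_scale = 2 * max(abs(min_val), abs(max_val)) / (quant_max - quant_min)
# ASYMMETRIC: map the observed range directly onto the integer range
asym_scale = (max_val - min_val) / (quant_max - quant_min)
print(sym_scale, asym_scale)  # 1.36 and ~0.913
```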
+ + This granularity type calculates the quantization parameters + based off the entire tensor. + + """ + + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. + + This granularity type calculates different quantization parameters + along a specified axis of the tensor. + + For example if the input tensor is shape [8, 16] and axis=0, then + the quantization parameters are calculated for each row of the tensor. + Giving a total of 8 quantization parameters. + + Attributes: + axis (int): The axis along which reduction is performed. + """ + + axis: int + + +@dataclass(frozen=True) +class PerGroup(Granularity): + """ + Represents per-channel group granularity in quantization. + + This granularity type calculates different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + + group_size: int + + +class PerRow(Granularity): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + + +class PerToken(Granularity): + """ + Represents per-token granularity in quantization. + + This granularity type calculates a different set of quantization parameters + for each token, which is represented as the last dimension of the tensor. + + For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens + with 4 elements each, and we will calculate 6 sets of quantization parameters, + one for each token. + + If the input tensor has only two dimensions, e.g. [8, 16], then this is + equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. + """ + + +def get_block_size( + input_shape: tuple[int, ...], granularity: Granularity +) -> tuple[int, ...]: + """Get the block size based on the input shape and granularity type. 
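The granularity classes map onto `block_size` tuples via `get_block_size`; a small sketch for an [8, 16] tensor (the names below are defined in this module, so the import assumes the file is installed as `torch.ao.quantization.observer`):

```python
from torch.ao.quantization.observer import (
    get_block_size, PerTensor, PerAxis, PerGroup, PerToken,
)

shape = (8, 16)
print(get_block_size(shape, PerTensor()))      # (8, 16): one set of qparams
print(get_block_size(shape, PerAxis(axis=0)))  # (1, 16): 8 sets, one per row
print(get_block_size(shape, PerGroup(4)))      # (1, 4):  64 groups of 4 elements
print(get_block_size(shape, PerToken()))       # (1, 16): one set per token
```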
+ + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + assert isinstance(granularity, Granularity), ( + "Please provide an instance of Granularity, not subclass of it" + ) + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, PerRow): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert len(input_shape) == 2, ( + f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + ) + return (1, granularity.group_size) + elif isinstance(granularity, PerToken): + block_size = [1] * len(input_shape) + block_size[-1] = input_shape[-1] + return tuple(block_size) + raise ValueError(f"Unsupported Granularity: {granularity}") + + +class AffineQuantizedObserverBase(ABC, torch.nn.Module): + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) + + Args: + `granularity` and `block_size`: The granularity of the quantization, + must specify at least one, if both are specified `block_size` takes precedence + Current supported granularity type are `PerTensor` and `PerAxis` + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` + """ + + with_args = classmethod(_with_args) + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + super().__init__() + assert granularity is not None, "granularity is None" + + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.granularity = granularity + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + # populatd during forward + self.block_size = None + self.original_dtype = None + + @abstractmethod + def forward(self, input: torch.Tensor) -> torch.Tensor: + """forward function should take the input tensor + and updates internal stats and return the original input Tensor + """ + + @abstractmethod + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate quantization parameter based on the stats attached to the observer module + and returns a tuple of scale and zero_point Tensor + """ + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + """ + Converts the observer node in the graph into its quantized representation + + Args: + model: graph module to conver the observer node in + observer_node: the observer node to convert + """ + from torch.ao.quantization.fx.utils import create_getattr_from_value + + with model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert self.original_dtype is not None, ( + "Expecting original_dtype to be populated" + ) + if hasattr(self, "is_dynamic") and self.is_dynamic: + 
choose_qparams_affine = model.graph.call_function( + torch.ops.pt2e_quant.choose_qparams_affine, + ( + observer_node.args[0], + self.mapping_type.name, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain.name, + ), + ) + scale_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 0) + ) + zero_point_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 1) + ) + else: + scale, zero_point = self.calculate_qparams() + scale_node = create_getattr_from_value( + model, model.graph, "_scale", scale + ) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + + q_node = model.graph.call_function( + torch.ops.pt2e_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.pt2e_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + + +def _is_observer_script_module(mod, obs_type_name): + """Returns true if given mod is an instance of Observer script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return obs_type_name in name + return False + + +# Experimental Affine Quantization Feature END + + +def _is_activation_post_process(module): + return isinstance( + module, + ( + torch.ao.quantization.ObserverBase, + torch.ao.quantization.FakeQuantizeBase, + AffineQuantizedObserverBase, + ), + ) or _is_observer_script_module(module, "quantization.observer") + + +def _is_per_channel_script_obs_instance(module): + if isinstance(module, torch.jit.RecursiveScriptModule): + return _is_observer_script_module( + module, "quantization.observer.PerChannelMinMaxObserver" + ) or _is_observer_script_module( + module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" + ) + return False + + +def get_observer_state_dict(mod): + r""" + Returns the state dict corresponding to the observer stats. + Traverse the model state_dict and extract out the stats. + """ + od = OrderedDict() + if isinstance(mod, torch.jit.RecursiveScriptModule): + for k, v in mod.state_dict().items(): + if "observer" in k: + od[k] = v + else: + # path for GraphModule and nn.Module (eager mode) + for k, v in mod.state_dict().items(): + if "activation_post_process" in k: + od[k] = v + od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] + return od + + +def load_observer_state_dict(mod, obs_dict): + r""" + Given input model and a state_dict containing model observer stats, + load the stats back into the model. The observer state_dict can be saved + using torch.ao.quantization.get_observer_state_dict + """ + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + for name, module in mod.named_modules(): + prefix = name + "." 
+ if _is_activation_post_process(module): + if _is_per_channel_script_obs_instance(module): + # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. + # However this is not called when the module is scripted and we end up calling the default one in module.py + module._load_from_state_dict_script( + obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] + ) + else: + module._load_from_state_dict( + obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] + ) + for k in missing_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Missing keys for observer {k} in state_dict" + ) + for k in unexpected_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Unexpected keys for observer {k} in state_dict" + ) + + +# Restrict activations to be in the range (0,127) +default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) +""" +Default observer for static quantization, usually used for debugging. +""" + +default_placeholder_observer = PlaceholderObserver +""" +Default placeholder observer, usually used for quantization to torch.float16. +""" + +default_debug_observer = RecordingObserver +""" +Default debug-only observer. +""" + +default_weight_observer = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric +) +""" +Default weight observer. +""" + +weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) +""" +Default histogram observer, usually used for PTQ. +""" + +default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +""" +Default per-channel weight observer, usually used on backends where per-channel +weight quantization is supported, such as `fbgemm`. +""" + +per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_dynamic_quant_observer = PlaceholderObserver.with_args( + dtype=torch.quint8, + quant_min=0, + quant_max=255, + is_dynamic=True, +) +""" +Default observer for dynamic quantization. +""" + +default_float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point. +""" + +default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( + dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point and 4 bit activations. 
+""" + +# TODO(future PR): remove these defaults and enforce activation functions +# to explicitly specify their output range +default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( + scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255 +) +default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( + scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255 +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer +default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer + +""" +Default observers for fixed qparams operations. +""" + +default_reuse_input_observer = ReuseInputObserver +""" +Default observer for operators like reshape that reuses the observer of input to +the operator +""" diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__init__.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12c997846360fc7df3872ed8a6d46c65f44373e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4178a24c67726775f1ff9930a2c24c3e74f0bb8d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97af183f7d24a17ac581bd2b9cb725e1f72a6333 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff1fe1cee633f1f420a66d37d1ad6a3c10d80351 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b850408bf6c38c5b1694bb13372e345890bfe61d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da66b3ed53fe2e02d539306a64c44c5b858cad66 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b7141d4c32432eb227ab4ef3cd16dc2db9dccae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8880b25e14babef51fc38db5a644f2354957198 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c083b5963fb172715540f49390a53cbb8268d7be Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a6c43c0b154c4fbb1c436efcfc5857f35ba1ecf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aec2442eaaff1be22b9926582aaf600f86fe82ac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..d1145cd35deded37388a5a7dc1c95c482d445e43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py @@ -0,0 +1,866 @@ +# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py +# PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +import logging +from abc import ABCMeta +from typing import Any, Optional, Union + +import torch +from torch.ao.quantization.observer import ( + AffineQuantizedObserverBase, + get_block_size, + Granularity, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + +logger = logging.getLogger(__name__) + +FP8_TYPES = { + torch.float8_e4m3fn, + 
torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +} +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} + +""" +Map from dtype to the bound value of integers +TODO: maybe can replace this with call to torch.iinfo +""" +_DTYPE_TO_QVALUE_BOUNDS: dict[Union[torch.dtype, TorchAODType], tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1), +} +_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) + + +def _is_float8_type(dtype: torch.dtype) -> bool: + fp8_types = { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + return dtype in fp8_types + + +# TODO: decide on if we want to allow custom quant_min/quant_max here +def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): + """Get quant_min and quant_max args based on dtype and also + verify that they are within the range of possible quant_min/quant_max + for dtype + """ + if dtype in FP8_TYPES: + quant_min_lower_bound, quant_max_upper_bound = ( + torch.finfo(dtype).min, + torch.finfo(dtype).max, + ) + elif dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + else: + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + if quant_min is None: + quant_min = quant_min_lower_bound + if quant_max is None: + quant_max = quant_max_upper_bound + + assert quant_min >= quant_min_lower_bound, ( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + assert quant_max <= quant_max_upper_bound, ( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + return quant_min, quant_max + + +def _get_reduction_params(block_size, input_size): + """Given block_size and input size find the parameters for reduction: + + Output: + shape_for_reduction: the shape we use to `view` input to prepare it for reduction + reduction_dims: the dims we'll do reduction over + + Example:: + Input: + block_size: (3, 3, 2, 10) + input_size: (3, 3, 10, 10) + + Output: + shape_for_reduction: (3, 3, 5, 2, 10) + reduction_dim: [0, 1, 3, 4] + """ + assert len(block_size) == len(input_size) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if block_size[i] != input_size[i] and block_size[i] > 1: + assert input_size[i] % block_size[i] == 0, ( + f"Expecting input size at {i} dimension: " + f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}" + ) + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + + +def _register_custom_op(lib): + """This decorator is used to preserve some high level operators for torch.export.export + while still allow them to be decomposed for inductor path + + 
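The reduction that `_get_reduction_params` describes can be reproduced with plain tensor ops; this sketch reuses the exact numbers from its docstring (input (3, 3, 10, 10), block_size (3, 3, 2, 10)):

```python
import torch

x = torch.randn(3, 3, 10, 10)
x_r = x.view(3, 3, 5, 2, 10)                   # shape_for_reduction
min_val = torch.amin(x_r, dim=[0, 1, 3, 4])    # reduce over reduction_dims
max_val = torch.amax(x_r, dim=[0, 1, 3, 4])
print(min_val.shape, max_val.shape)            # torch.Size([5]) twice: one qparam per block
```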
requirement: make sure `fn.__name__[1:]` is the operator name you want to register + + NOTE: This should be applied at the top, after all other decorators have been applied + NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input, + e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make + sense for downstream system (like executorch) to accept as well + + Example: + lib = torch.library.Library("my_namespace', "FRAGMENT") + + register_custom_op = _register_custom_op(lib) + + @register_custom_op + def _the_op_that_needs_to_be_preserved(...) + ... + + # after this, `_the_op_that_needs_to_be_preserved` will be preserved as + # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after + # torch.export.export / torch._export.export_for_training + + """ + from torch._inductor.decomposition import register_decomposition + + def decorator(fn): + from torch._library.infer_schema import infer_schema + + # expecting fn.__name__ starts with `_` and we want to take the rest + # to be the name of the custom op + assert fn.__name__[0] == "_", ( + f"Expecting function name starts with `_`, got {fn.__name__}" + ) + assert not any(c in fn.__name__ for c in ".<>"), ( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) + op_name = fn.__name__[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, "CompositeImplicitAutograd") + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + register_decomposition([op])(fn) + return op + + return decorator + + +quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT") # noqa: TOR901 + +register_custom_op = _register_custom_op(quant_lib) + + +def choose_qparams_affine_with_min_max( + min_val: torch.Tensor, + max_val: torch.Tensor, + mapping_type: MappingType, + block_size: tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, +) -> tuple[torch.Tensor, torch.Tensor]: + """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` + operator that pass in min_val and max_val directly instead of deriving these from a single input. + This is used for observers in static quantization where min_val and max_val may be obtained through + tracking all the data in calibration data set. + + Args: + Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. 
with one + difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val + and then scale/zero_point, we pass in min_val/max_val directly + """ + return _choose_qparams_affine( + None, + mapping_type.name, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain.name if zero_point_domain is not None else None, + min_val, + max_val, + ) + + +@register_custom_op +def _choose_qparams_affine( + input: Optional[torch.Tensor], + mapping_type: str, + block_size: list[int], + target_dtype: torch.dtype, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[str] = "INT", + min_val: Optional[torch.Tensor] = None, + max_val: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """op definition that has compatible signatures with custom op library + + The op does the following: + 1. figure out the dimension for reduction based on block_size + 2. find min_val/max_val based on the dimension for reduction + 3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero` + and `zero_point_domain` + """ + quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + assert mapping_type in [ + MappingType.SYMMETRIC.name, + MappingType.SYMMETRIC_NO_CLIPPING_ERR.name, + MappingType.ASYMMETRIC.name, + ], f"Unsupported mapping type: {mapping_type}" + if target_dtype in FP8_TYPES: + assert mapping_type == MappingType.SYMMETRIC.name, ( + f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + ) + + if input is not None: + if scale_dtype is None: + scale_dtype = input.dtype + if zero_point_dtype is None: + zero_point_dtype = input.dtype + if eps is None: + eps = torch.finfo(input.dtype).eps + + assert len(block_size) == input.dim(), ( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + input = input.view(shape_for_reduction) + + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + else: + assert min_val is not None and max_val is not None, ( + "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + ) + assert min_val.dtype == max_val.dtype, ( + "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + ) + + if scale_dtype is None: + scale_dtype = min_val.dtype + if zero_point_dtype is None: + zero_point_dtype = min_val.dtype + if eps is None: + eps = torch.finfo(min_val.dtype).eps + + if preserve_zero: + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + else: + min_val_neg = min_val + max_val_pos = max_val + + if ( + mapping_type == MappingType.SYMMETRIC.name + or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name + ): + # scales + if mapping_type == MappingType.SYMMETRIC.name: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + else: + assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name + # calculate smin and smax individually and choose the larger one. 
For example, if quant_min = -8 and + # quant_max = 7. + # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding + # error than the existing SYMMETRIC case. + # - If smax is bigger: it covers the positive values up to 7. The round + # error may be bigger than the existing SYMMETRIC case. Either way, there's no out-of-range fp values after + # quantization. + smin = min_val_neg / float(quant_min) + smax = max_val_pos / float(quant_max) + mask = smin > smax + scale = torch.where(mask, smin, smax) + # zeros + if not preserve_zero: + raise ValueError( + "preserve_zero == False is not supported for symmetric quantization" + ) + if ( + zero_point_domain is not None + and zero_point_domain != ZeroPointDomain.INT.name + ): + raise ValueError( + "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" + ) + scale = torch.clamp(scale, min=eps) + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) + else: + assert mapping_type == MappingType.ASYMMETRIC.name + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.clamp(scale, min=eps) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + if preserve_zero: + zero_point = quant_min - torch.round(min_val_neg / scale) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + else: + assert zero_point_domain == ZeroPointDomain.FLOAT.name, ( + "if not preserve_zero, zero_point must be in FLOAT domain" + ) + mid_point = (quant_max + quant_min + 1) / 2 + zero_point = min_val_neg + scale * mid_point + + if zero_point is not None: + zero_point = zero_point.to(dtype=zero_point_dtype) + return scale.to(dtype=scale_dtype), zero_point + + +@torch.no_grad() +def quantize_affine( + input: torch.Tensor, + block_size: tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): original float32, float16 or bfloat16 Tensor + block_size: (Tuple[int, ...]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype + quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Note: + How can block_size represent different granularities? 
+ let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different + granularities: + + granularity type | block_size + per_tensor | (3, 3, 10, 10) + per_axis (axis=0) | (1, 3, 10, 10) + per_axis (axis=1) | (3, 1, 10, 10) + per_group (groupsize=2) | (3, 3, 10, 2) + per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + + + Output: + quantized tensor with requested dtype + """ + return _quantize_affine( + input, + block_size, + scale, + zero_point, + output_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + ) + + +@register_custom_op +def _quantize_affine( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library + + Note: + zero_point_domain is optional specifies how we quantize the floating point to quantized data: + INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization + Where we do not want to round values to nearest integer and instead scale and cast. + """ + quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) + # workaround for uintx dtypes, since we don't have native Uintx dtype connected with + # torch.uintx dtypes yet + if output_dtype in _SUB_BYTE_UINT_BOUNDS: + output_dtype = torch.uint8 + return _quantize_affine_no_dtype_cast( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + ).to(output_dtype) + + +def _quantize_affine_no_dtype_cast( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """ + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. 
reshape the quantized result to origianl shape + """ + # TODO: validations + # TODO: validate scale/zero_point dimensions are compatible with block_size + assert input.dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], f"Unsupported input dtype: {input.dtype}" + assert len(block_size) == input.dim(), ( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + quant = torch.clamp( + torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max + ) + elif zero_point_domain == ZeroPointDomain.NONE.name: + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is NONE" + ) + quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) + elif zero_point_domain is None: + # This case handles quantization for float8 we expect no zero point and no zero point domain + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is None" + ) + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + assert zero_point_domain == ZeroPointDomain.FLOAT.name + mid_point = (quant_max + quant_min + 1) / 2 + min_val = zero_point - scale * mid_point + quant = torch.clamp( + torch.round((input - min_val) / scale), quant_min, quant_max + ) + quant = quant.view(original_shape) + + return quant + + +def dequantize_affine( + input: torch.Tensor, + block_size: tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument + block_size: (List[int]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (Tensor): quantization parameter for affine quantization + zero_point (Tensor): quantization parameter for affine quantization + input_dtype (torch.dtype): requested dtype (e.g. 
torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for input Tensor + quant_max (Optional[int]): maximum quantized value for input Tensor + output_dtype (torch.dtype): dtype for output Tensor, default is fp32 + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Output: + dequantized Tensor, with requested dtype or fp32 + """ + return _dequantize_affine( + input, + block_size, + scale, + zero_point, + input_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + output_dtype=output_dtype, + ) + + +@register_custom_op +def _dequantize_affine( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library""" + # TODO: validate scale/zero_point dimensions are compatible with block_size + if input_dtype not in _SUB_BYTE_UINT_BOUNDS: + assert input.dtype == input_dtype, ( + f"Expected: {input_dtype}, got: {input.dtype}" + ) + assert output_dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], f"Unsupported output dtype: {output_dtype}" + quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) + return _dequantize_affine_no_dtype_check( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + output_dtype, + ) + + +def _dequantize_affine_no_dtype_check( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """This function converts AQT tensors to their high precision floating point representation + + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + """ + assert len(block_size) == input.dim(), ( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + # Force a copy to avoid input modification due + # to upcoming in-place operations. 
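Putting the primitives above together, a per-tensor round trip might look roughly like this (illustrative; reuses the helpers defined earlier in this file):

    x = torch.randn(3, 3, 10, 10)
    block_size = (3, 3, 10, 10)  # per-tensor, see the granularity table above
    scale, zero_point = choose_qparams_affine_with_min_max(
        x.amin().reshape(1), x.amax().reshape(1), MappingType.ASYMMETRIC, [], torch.int8
    )
    q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    x_hat = dequantize_affine(q, block_size, scale, zero_point, torch.int8)  # float32 output by default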
+ dequant = input.to(torch.int32, copy=True) + if zero_point is not None: + dequant = dequant - zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain == ZeroPointDomain.NONE.name: + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is NONE" + ) + dequant = input.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain is None: + # This case handles dequantization for float8 we expect no zero point and no zero point domain + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is None" + ) + assert _is_float8_type(input.dtype), ( + f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + ) + dequant = input.to(output_dtype) + dequant = dequant * scale + else: + assert zero_point_domain == ZeroPointDomain.FLOAT.name, ( + f"Unexpected zero point domain: {zero_point_domain}" + ) + # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) + mid_point = (quant_max + quant_min + 1) / 2 + # This should allocate new memory and avoid input modification + dequant = input - mid_point + dequant = dequant.to(output_dtype) + dequant *= scale + if zero_point is not None: + dequant += zero_point + + return dequant.view(original_shape).to(output_dtype) + + +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + assert self.granularity is not None, "granularity is None" + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + assert self.min_val.shape == min_val.shape, ( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + assert self.max_val.shape == max_val.shape, ( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr(self, "max_val"), ( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + averaging_constant=0.01, + quant_min: Optional[int] = 
None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + is_dynamic=False, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + assert self.granularity is not None, "granularity is None" + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + assert self.min_val.shape == min_val.shape, ( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + assert self.max_val.shape == max_val.shape, ( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) + min_val = self.min_val + self.averaging_constant * (min_val - self.min_val) + max_val = self.max_val + self.averaging_constant * (max_val - self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr(self, "max_val"), ( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedPlaceholderObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + is_dynamic=False, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + 
granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input): + self.block_size = get_block_size(input.shape, self.granularity) + self.original_dtype = input.dtype + return input + + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..a479f0cf7849199768f555f8fa19d5baa7c387cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -0,0 +1,342 @@ +import copy +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from torch.ao.ns.fx.utils import compute_sqnr +from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.export import ExportedProgram +from torch.fx import GraphModule, Node +from torch.nn import functional as F + + +NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" +CUSTOM_KEY = "custom" + +log = logging.getLogger(__name__) + + +def generate_numeric_debug_handle(ep: ExportedProgram) -> None: + """ + Attach numeric_debug_handle_id for all nodes in the graph module of the given + ExportedProgram, like conv2d, squeeze, conv1d, etc, except for placeholder. + Notice that nodes like getattr are out of scope since they are not in the graph. + + The graph nodes of input exported program are modified inplace. + + Here's an example of using debug handle quantize flow:: + + ep = export_for_training(eager_model, example_inputs) + generate_numeric_debug_handle(ep) + + m = ep.module() + quantizer = XNNPACKQuantizer() + m = prepare_pt2e(m, quantizer) + m = convert_pt2e(m) + """ + + # Sanity check the input data type + if not isinstance(ep, ExportedProgram): + raise ValueError( + f"Expected ep to be ExportedProgram, got {type(ExportedProgram)}" + ) + + unique_id = 0 + + def _find_max_id(node: torch.fx.Node) -> None: + nonlocal unique_id + unique_id = max( + unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0) + ) + + def _assign_debug_handle(node: torch.fx.Node) -> None: + nonlocal unique_id + if CUSTOM_KEY not in node.meta: + node.meta[CUSTOM_KEY] = {} + + if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]: + node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id + unique_id += 1 + + # Find the max ID that exists in the graph first, in case part of the graph + # has already been annotated. This way we guarantee there are no duplicate + # handle IDs. + bfs_trace_with_node_process(ep, _find_max_id) + + unique_id += 1 + + # Assign debug handles to all nodes in the graph that don't have one based on the + # max ID found in the previous step. 
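Concretely, after this pass each non-placeholder node ends up with metadata of roughly this shape (handle value illustrative):

    node.meta["custom"] == {"numeric_debug_handle": 7}  # i.e. node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]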
+ bfs_trace_with_node_process(ep, _assign_debug_handle) + + +def _detach(x: object) -> object: + detached: object = None + if isinstance(x, torch.Tensor): + detached = x.detach() + elif isinstance(x, (list, tuple)): + detached = type(x)([_detach(e) for e in x]) + elif isinstance(x, dict): + detached = {k: _detach(e) for k, e in x.items()} + else: + detached = x + return detached + + +def _tensor_shape_equals(x: object, y: object) -> bool: + if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + return x.shape == y.shape + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y)) + elif isinstance(x, dict) and isinstance(y, dict): + all_equal = True + for k in x: + all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k])) + return all_equal + else: + log.debug("Comparing non Tensors: %s and %s, they must be equal", x, y) + return type(x) == type(y) and x == y + + +def _loss_fn( + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object +) -> object: + """The returned loss will have the same structure as `x` and `y`, e.g. + if both are Tensor, we'll return a Tensor + if both are list, we'll return a list of Tensors + if both are dict, we'll return a dict with the same key, and value being the loss between the + two Tensors + """ + if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + return loss(x.to(torch.float32), y.to(torch.float32)) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)]) + elif isinstance(x, dict) and isinstance(y, dict): + return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()} + else: + return None + + +class OutputLogger(torch.nn.Module): + """ + Base class for capturing output values for nodes in a GraphModule, it only captures + Tensor output currently, but we can extend it to work for other types of inputs later if needed + """ + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + debug_handle: int, + node_name: Optional[str] = None, + nn_module_stack: Optional[object] = None, + ) -> None: + super().__init__() + self.node_name = node_name + self.nn_module_stack = nn_module_stack + self.debug_handle = debug_handle + self.stats: list[object] = [] + + def forward(self, x: object) -> object: + self.stats.append(_detach(x)) + return x + + def __extra_repr__(self) -> str: + return ( + f"debug_handle={self.debug_handle}, node_name={self.node_name}, " + "nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})" + ) + + +def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: + """For a given node, adds an OutputLogger that observes the output of that node, + and all its users use the OutputLogger output instead. 
+ The OutputLogger will contain the debug_handle which can be used to compare + graphs after transforms""" + + # to avoid circular dep + from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix + + # add a logger after the node + with model.graph.inserting_after(node): + get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger") + logger_name = get_new_attr_name(model) + setattr( + model, + logger_name, + OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")), + ) + logger_node = model.graph.call_module(logger_name, (node,), {}) + + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is logger_node: + continue + user_node.replace_input_with(node, logger_node) + + return logger_node + + +def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: + """Add output loggers to node that has numeric_debug_handle + + Args: + model (GraphModule): original model + Returns: + a model with output loggers for all nodes that has numeric_debug_handle_id + """ + # don't change the original model + model = copy.deepcopy(model) + for n in model.graph.nodes: + if ( + CUSTOM_KEY not in n.meta + or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY] + ): + continue + numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] + _insert_logger(model, n, numeric_debug_handle) + + model.recompile() + return model + + +@dataclass(frozen=True) +class QuantizationComparisonResult: + actual: torch.Tensor + ref: torch.Tensor + + @property + def mse_loss(self) -> object: + return self.loss(F.mse_loss) + + @property + def sqnr(self) -> object: + return self.loss(compute_sqnr) + + def loss( + self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> object: + return _loss_fn(loss_function, self.actual, self.ref) + + def __repr__(self) -> str: + # Don't include the tensors themselves as they are quite large to print + # out. + return ( + f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})" + ) + + def __post_init__(self) -> None: + if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}" + ) + + if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}" + ) + + if not _tensor_shape_equals(self.ref, self.actual): + raise ValueError( + f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}" + ) + + +@dataclass(frozen=True) +class NodeAccuracySummary: + handle: int + actual_node_name: str + actual_module_stack: str + ref_node_name: str + ref_module_stack: str + results: Sequence[QuantizationComparisonResult] + + +def _module_stack_to_str(module_stack: object) -> str: + """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear") + to "mod.foo.0.linear" + """ + if not isinstance(module_stack, dict): + return str(module_stack) + module_values_list = list(module_stack.values()) + if len(module_values_list) > 0: + owning_module = module_values_list[-1][0] + return str(owning_module) + else: + return str(module_stack) + + +def extract_results_from_loggers( + model: GraphModule, +) -> dict[int, tuple[Optional[str], object, list[object]]]: + """For a given model, extract the tensors stats and related information for each debug handle. 
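For example (illustrative; `m_ref` and `m_quant` are two GraphModules annotated with the same debug handles and `example_inputs` is a calibration batch):

    m_ref_logged = prepare_for_propagation_comparison(m_ref)
    m_quant_logged = prepare_for_propagation_comparison(m_quant)
    m_ref_logged(*example_inputs)
    m_quant_logged(*example_inputs)
    ref_results = extract_results_from_loggers(m_ref_logged)
    actual_results = extract_results_from_loggers(m_quant_logged)

The two resulting dicts are what `compare_results` below consumes.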
+ The reason we have a list of object, instead of Tensor is because the output of node may not be + a Tensor, it could be (nested) list, tuple or dict as well. + + Returns: + A dict is keyed by the debug_handle id and the values are a list of object recorded + in loggers + + """ + # Results maps debug handle to a tensor list for each model being compared. + handles: dict[int, tuple[Optional[str], object, list[object]]] = {} + for _name, module in model.named_children(): + if isinstance(module, OutputLogger) and len(module.stats) > 0: + handles[module.debug_handle] = ( + module.node_name, + module.nn_module_stack, + module.stats, + ) + + return handles + + +def compare_results( + ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], + actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], +) -> dict[int, NodeAccuracySummary]: + """Given two dict mapping from `debug_handle_id` (int) to list of tensors + return a map from `debug_handle_id` to `NodeAccuracySummary` that contains + comparison information like SQNR, MSE etc. + + Args: + ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id + actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id + + Returns: + Dict[int, NodeAccuracySummary] + """ + comparisons = {} + for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items(): + if debug_handle not in actual_results: + log.debug( + "Cannot compare for handle %s because it wasn't found in the transformed model", + debug_handle, + ) + continue + actual_name, actual_stack, actual_stats = actual_results[debug_handle] + try: + results = [ + QuantizationComparisonResult(actual=a, ref=b) + for a, b in zip(actual_stats, ref_stats) + ] + except Exception as e: + # Add extra information for an exception from QuantizationComparisonResult + # if the shapes didn't match, to include the handle and the node names. 
+ raise ValueError( + f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" + ) from e + + comparisons[debug_handle] = NodeAccuracySummary( + handle=debug_handle, + actual_node_name=actual_name or "", + actual_module_stack=_module_stack_to_str(actual_stack), + ref_node_name=ref_name or "", + ref_module_stack=_module_stack_to_str(ref_stack), + results=results, + ) + + return comparisons diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..9f77378248659bc2822409ac2cb8bb3918596281 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs +import logging +import operator + +import torch +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _is_valid_annotation, +) +from torch.fx.node import map_arg +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["DuplicateDQPass"] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _maybe_duplicate_dq( + gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node +): + annotation = user.meta.get("quantization_annotation", None) + if not _is_valid_annotation(annotation): # type: ignore[arg-type] + return + with gm.graph.inserting_after(dq_node): + new_node = gm.graph.node_copy(dq_node) + + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == dq_node: + return new_node + else: + return n + + new_args = map_arg(user.args, maybe_replace_node) + new_kwargs = map_arg(user.kwargs, maybe_replace_node) + user.args = new_args # type: ignore[assignment] + user.kwargs = new_kwargs # type: ignore[assignment] + + +class DuplicateDQPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: + dq_users = _filter_sym_size_users(node) + if len(dq_users) <= 1: + continue + # Do not duplicate dq for dynamic quantization + # Pattern: choose_qparam - getitem - q - dq + q_node = node.args[0] + if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: + getitem_node = q_node.args[1] + if ( + isinstance(getitem_node, torch.fx.node.Node) + and getitem_node.op == "call_function" + and getitem_node.target == operator.getitem + ): + choose_qparam_node = getitem_node.args[0] + if ( + isinstance(choose_qparam_node, torch.fx.node.Node) + and choose_qparam_node.op == "call_function" + and choose_qparam_node.target + == torch.ops.quantized_decomposed.choose_qparams.tensor + ): + continue + for user in dq_users: + _maybe_duplicate_dq(graph_module, node, user) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/export_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/export_utils.py new file 
mode 100644 index 0000000000000000000000000000000000000000..7206b9e8792efbbb8fa5324c0785cf98db415e09 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/export_utils.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import types + +import torch +import torch.nn.functional as F +from torch.ao.quantization.utils import _assert_and_get_unique_device + + +__all__ = [ + "model_is_exported", +] + +_EXPORTED_TRAINING_ATTR = "_exported_training" + + +class _WrapperModule(torch.nn.Module): + """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you + are trying to export a callable. + """ + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`.""" + return self.fn(*args, **kwargs) + + +def model_is_exported(m: torch.nn.Module) -> bool: + """ + Return True if the `torch.nn.Module` was exported, False otherwise + (e.g. if the model was FX symbolically traced or not traced at all). + """ + return isinstance(m, torch.fx.GraphModule) and any( + "val" in n.meta for n in m.graph.nodes + ) + + +def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch dropout patterns in the model between train and eval modes. + + Dropout has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the dropout behavior between the two modes, so here we need to rewrite the aten + dropout patterns manually to achieve the same effect. + + See https://github.com/pytorch/pytorch/issues/103681. + """ + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + for inplace in [False, True]: + + def dropout_train(x): + return F.dropout(x, p=0.5, training=True, inplace=inplace) + + def dropout_eval(x): + return F.dropout(x, p=0.5, training=False, inplace=inplace) + + example_inputs = (torch.randn(1),) + if train_to_eval: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + else: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch batchnorm patterns in the model between train and eval modes. + + Batchnorm has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the batchnorm behavior between the two modes, so here we need to rewrite the aten + batchnorm patterns manually to achieve the same effect. + """ + # TODO(Leslie): This function still fails to support custom momentum and eps value. + # Enable this support in future updates. 
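A small sketch of how `model_is_exported` above behaves (illustrative; assumes `model` and `example_inputs` are defined):

    exported = torch.export.export_for_training(model, example_inputs).module()
    model_is_exported(exported)               # True: GraphModule with `val` node metadata
    model_is_exported(torch.nn.Linear(4, 4))  # False: plain eager module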
+ + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + def bn_train( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + + def bn_eval( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False + ) + + example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + device = _assert_and_get_unique_device(m) + is_cuda = device is not None and device.type == "cuda" + bn_train_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_train), + example_inputs, + is_cuda, + ) + bn_eval_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_eval), + example_inputs, + is_cuda, + ) + + if train_to_eval: + match_pattern = bn_train_aten + replacement_pattern = bn_eval_aten + else: + match_pattern = bn_eval_aten + replacement_pattern = bn_train_aten + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +# TODO: expose these under this namespace? +def _move_exported_model_to_eval(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to eval mode. + + This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing inference on the model. + + This call is idempotent; if the model is already in eval mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True) + if not is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, False) + _replace_dropout(model, train_to_eval=True) + _replace_batchnorm(model, train_to_eval=True) + return model + + +def _move_exported_model_to_train(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to train mode. + + This is equivalent to model.train() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing training on the model. + + This call is idempotent; if the model is already in train mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False) + if is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, True) + _replace_dropout(model, train_to_eval=False) + _replace_batchnorm(model, train_to_eval=False) + return model + + +def _allow_exported_model_train_eval(model: torch.fx.GraphModule): + """ + Allow users to call `model.train()` and `model.eval()` on an exported model, + but with the effect of changing behavior between the two modes limited to special + ops only, which are currently dropout and batchnorm. + + Note: This does not achieve the same effect as what `model.train()` and `model.eval()` + does in eager models, but only provides an approximation. In particular, user code + branching on `training` flag will not function correctly in general because the branch + is already specialized at export time. 
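A typical use, for illustration (assumes an exported GraphModule `m`):

    m = _allow_exported_model_train_eval(m)
    m.eval()   # rewrites only the dropout/batchnorm patterns described above
    m.train()  # switches them back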
Additionally, other ops beyond dropout and batchnorm + that have different train/eval behavior will also not be converted properly. + """ + + def _train(self, mode: bool = True): + if mode: + _move_exported_model_to_train(self) + else: + _move_exported_model_to_eval(self) + + def _eval(self): + _move_exported_model_to_eval(self) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b158fb758e585fff3a6bf59fdac97da982ab0f21 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from collections import OrderedDict +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import torch +from torch.export import ExportedProgram +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + check_subgraphs_connected, + get_source_partitions, + SourcePartition, +) + + +__all__ = [ + "find_sequential_partitions", + "get_equivalent_types", + "update_equivalent_types_dict", + "bfs_trace_with_node_process", +] + +_EQUIVALENT_TYPES: list[set] = [ + {torch.nn.Conv1d, torch.nn.functional.conv1d}, + {torch.nn.Conv2d, torch.nn.functional.conv2d}, + {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d}, + {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_}, + {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm}, + {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_}, + {torch.add, operator.add, operator.iadd, "add", "add_"}, + {torch.mul, operator.mul, operator.imul, "mul", "mul_"}, +] + + +def _create_equivalent_types_dict(): + _DICT = {} + for values in _EQUIVALENT_TYPES: + for v in values: + _DICT[v] = list(values) + return _DICT + + +_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def get_equivalent_types() -> list[set]: + return _EQUIVALENT_TYPES + + +def update_equivalent_types_dict(customized_equivalent_types=None): + """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + When customized_equivalent_types passes in, + re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. 
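For example, to also treat SiLU's module and functional forms as equivalent (illustrative):

    custom_types = get_equivalent_types()
    custom_types.append({torch.nn.SiLU, torch.nn.functional.silu})
    update_equivalent_types_dict(custom_types)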
+ """ + if customized_equivalent_types is None: + raise ValueError("customized_equivalent_types should not be None") + global _EQUIVALENT_TYPES + global _EQUIVALENT_TYPES_DICT + _EQUIVALENT_TYPES = customized_equivalent_types + _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def _partitions_sequential(partitions: Sequence[SourcePartition]): + prev_partition = None + for partition in partitions: + if prev_partition is not None and not check_subgraphs_connected( + prev_partition, partition + ): + return False + prev_partition = partition + return True + + +def _get_matching_types(partition_type): + matching_types = [partition_type] + if partition_type in _EQUIVALENT_TYPES_DICT: + matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type]) + return matching_types + + +def _valid_type_sequence(partition_types: list[Any]): + partition_types_set = set() # type: ignore[var-annotated] + for partition_type in partition_types: + matching_types = _get_matching_types(partition_type) + matching_types_set = set(matching_types) + if len(partition_types_set & matching_types_set) > 0: + return False + partition_types_set |= matching_types_set + return True + + +def find_sequential_partitions( + gm: torch.fx.GraphModule, + partition_types: list[Any], + include_functional_equivalent=True, + filter_fn: Optional[Callable[[Node], bool]] = None, +): + if not _valid_type_sequence(partition_types): + raise ValueError( + f"Invalid partition types: {partition_types}. Each type in the sequence must be unique" + ) + + typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict() + for partition_type in partition_types: + types_to_match = _get_matching_types(partition_type) + partitions = get_source_partitions(gm.graph, types_to_match, filter_fn) + typed_partitions[partition_type] = list( + itertools.chain.from_iterable(partitions.values()) + ) + + typed_partitions_list = list(typed_partitions.values()) + fusion_candidates = itertools.product(*typed_partitions_list) + fused_partitions = [ + candidate + for candidate in fusion_candidates + if _partitions_sequential(candidate) + ] + return fused_partitions + + +def _get_submodule( + graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int +) -> tuple[str, torch.nn.Module, torch.fx.Node]: + submod_node = node.args[arg_index] + assert isinstance(submod_node, torch.fx.Node) + assert submod_node.op == "get_attr" + assert isinstance(submod_node.target, str) + submodule = graph_module.get_submodule(submod_node.target) + # pyre-ignore + return submod_node.target, submodule, node + + +def _get_control_flow_submodules( + graph_module: torch.fx.GraphModule, +) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing a + tuple of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). 
+ """ + control_flow_submodules = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target is torch.ops.higher_order.cond: + control_flow_submodules.append(_get_submodule(graph_module, node, 1)) + control_flow_submodules.append(_get_submodule(graph_module, node, 2)) + if node.target is torch.ops.higher_order.map_impl: + control_flow_submodules.append(_get_submodule(graph_module, node, 0)) + + return control_flow_submodules + + +def bfs_trace_with_node_process( + model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable +) -> None: + """Traverse the graph module and apply node_op to each node.""" + + assert isinstance(model, (ExportedProgram, torch.fx.GraphModule)), ( + f"Expected GraphModule or ExportedProgram, got {type(model)}" + ) + gm = model.graph_module if isinstance(model, ExportedProgram) else model + queue = [gm] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + if node.op in ["output", "placeholder"]: + continue + + node_op(node) + + control_flow_submodules = [ + submodule + for _, submodule, _ in _get_control_flow_submodules(current_graph_module) + ] + queue.extend(control_flow_submodules) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/lowering.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..073ee82f4fd2ad26f973f47f3801e0c98788d123 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/lowering.py @@ -0,0 +1,60 @@ +import torch +from torch._inductor.constant_folding import constant_fold +from torch._inductor.fx_passes.freezing_patterns import freezing_passes + + +__all__ = [ + "lower_pt2e_quantized_to_x86", +] + + +def lower_pt2e_quantized_to_x86( + model: torch.fx.GraphModule, + example_inputs: tuple[torch.Tensor, ...], +) -> torch.fx.GraphModule: + """Lower a PT2E-qantized model to x86 backend. + + Args: + * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. + * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. + + Return: + A GraphModule lowered to x86 backend. + """ + + def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] + decomp_table = torch.export.default_decompositions() + + # if we are post-autograd, we shouldn't + # decomp prim ops. 
+ for k in list(decomp_table.keys()): + if not torch._export.utils._is_cia_op(k): + del decomp_table[k] + + return decomp_table + + def _node_replace(m): # type: ignore[no-untyped-def] + # Replace aten.t(x) with aten.permute(x, [1, 0]) + aten = torch.ops.aten + g = m.graph + for node in g.nodes: + if node.target == aten.t.default: + with g.inserting_before(node): + x = node.args[0] + dims = [1, 0] + perm_node = g.call_function(aten.permute.default, args=(x, dims)) + node.replace_all_uses_with(perm_node) + g.erase_node(node) + + g.lint() + m.recompile() + + lowered_model = ( + torch.export.export_for_training(model, example_inputs, strict=True) + .run_decompositions(_post_autograd_decomp_table()) + .module() + ) + _node_replace(lowered_model) + freezing_passes(lowered_model, example_inputs) + constant_fold(lowered_model) + return lowered_model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..f26fbd8bb9aaa8799d6e3a05f96f426113189b8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -0,0 +1,218 @@ +# mypy: allow-untyped-defs +import logging +from typing import Optional + +import torch +from torch._export.error import InternalError +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _find_q_dq_node_for_user, + _is_valid_annotation, +) +from torch.ao.quantization.quantizer import QuantizationSpecBase +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.ERROR) + +__all__ = ["PortNodeMetaForQDQ"] + +_METADATA_TO_PORT = [ + "stack_trace", + "quantization_tag", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.dequantize_affine, +] + +_CHOOSE_QPARAMS_OPS = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.pt2e_quant.choose_qparams_affine, +] + + +def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: + from_meta = from_node.meta + for meta_name in _METADATA_TO_PORT: + if meta_name in from_meta: + to_node.meta[meta_name] = from_meta[meta_name] + + +def _has_quant_annotation(node: torch.fx.Node) -> bool: + return "quantization_annotation" in node.meta + + +def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: + # BFS to look for choose qparams + from collections import deque + + queue = deque(list(node.users.keys())) + while len(queue): + n = queue.popleft() + if n.op == "output": + continue + if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: + return n + for k in n.users.keys(): + queue.append(k) + return None + + +def _port_metadata_for_input_quant_nodes( + input_node: torch.fx.Node, + node: torch.fx.Node, + qspec: Optional[QuantizationSpecBase], +): + if qspec is None: + return + + is_dynamic_quant = getattr(qspec, "is_dynamic", None) + if is_dynamic_quant is not None and is_dynamic_quant 
is True: + choose_qparams_node = _find_choose_qparams_node(input_node) + if choose_qparams_node is None: + raise ValueError(f"No chose qparams node found for {node}") + choose_qparam_users = _filter_sym_size_users(choose_qparams_node) + if len(choose_qparam_users) != 2: + raise InternalError(f"Expecting exactly two user for {choose_qparams_node}") + scale_node = choose_qparam_users.pop() + dynamic_q_node = next(iter(scale_node.users.keys())) + dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node) + if len(dynamic_q_node_users) > 1: + raise InternalError(f"Expecting single user for {dynamic_q_node}") + dynamic_dq_node = dynamic_q_node_users.pop() + _add_metadata(choose_qparams_node, node) + _add_metadata(dynamic_q_node, node) + _add_metadata(dynamic_dq_node, node) + else: + q_node, dq_node = _find_q_dq_node_for_user(input_node, node) + if q_node is None or dq_node is None: + return + # add metadata for all the node between q_node and get_attr node + # if the q_node can be traced back to get_attr node + q_to_get_attr_nodes = [q_node] + q_node_input = q_node.args[0] + while ( + isinstance(q_node_input, torch.fx.Node) + and q_node_input.op == "call_function" + and q_node_input.target + in [ + torch.ops.aten.flatten.using_ints, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten._mkldnn_transpose, + ] + ): + q_to_get_attr_nodes.append(q_node_input) + q_node_input = q_node_input.args[0] + if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": + for n in q_to_get_attr_nodes: + _add_metadata(n, q_node_input) + _add_metadata(dq_node, node) + + +def _port_metadata_for_output_quant_nodes( + node: torch.fx.Node, qspec: Optional[QuantizationSpecBase] +): + if qspec is None: + return + + node_users = _filter_sym_size_users(node) + if len(node.users) == 0: + return + if len(node_users) != 1: + logger.warning(f"Expecting {node} to have single user") # noqa: G004 + q_node = node_users.pop() + if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: + logger.warning( + f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004 + ) # noqa: G004 + return + + _add_metadata(q_node, node) + + +class PortNodeMetaForQDQ(PassBase): + """ + Port metadata for nodes added by quantization flow. + For static quant these are: + - quantizer_per_tensor.default, dequantize_per_tensor.default + - quantizer_per_channel.default, dequantize_per_channel.default + For dynamic quant these are: + - choose_qparams.tensor + - quantizer_per_tensor.tensor, dequantize_per_tensor.tensor + - quantizer_per_channel.default, dequantize_per_channel.default + + Rules of porting metadata: + - Metadata to be ported: + - nn_module_stack + - stack_trace + - quantization_tag + - Metadata to NOT be ported: + - Everything else + - Rules: + - Statically quantized patterns: + - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node. + - Quantize nodes on the outputs inherit metadata of the producer node. 
+ - Example 1: + - Original: [Conv -> AvgPool -> Linear] + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] + - Inner brackets specify which nodes Q/DQ inherit metdata from + - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ] + - Note first Q and last DQ do not inherit metadata from any nodes + - Example 2: + - Original: [Conv -> AvgPool -> Linear] + - AvgPool is not quantized + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] + - Inner brackets specify which nodes Q/DQ inherit metdata from + - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] + - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because + AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation + on the nodes (in this case AvgPool node) to conclude if the node or patter was + supposed to be quantized. And subsequntly decide if the preceding Q, if any, should + inherit metadata from AvgPool. + - Dynamically quantized patterns: + - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes + - For example, below linear is dynamically quantized while rest statically: + - Original: [Conv -> AvgPool -> Linear] + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear] + - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]] + - Note first Q does not inherit metadata from any nodes + NB: + - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely + knows which quantization spec is converted to q/dq and thus from where the metadata should be ported. + However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit. + Doing it via a separate pass, helps readability of the code. Once we are able to refactor PT2E quant + code, this pass should like to be integrated in the refactored variant of "convert" step. 
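Typical invocation, for illustration (PassBase instances are callable and return a PassResult):

    result = PortNodeMetaForQDQ()(quantized_gm)
    quantized_gm = result.graph_module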
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + annotation = node.meta.get("quantization_annotation", None) + if _is_valid_annotation(annotation): + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + output_qspec = node.meta["quantization_annotation"].output_qspec + for input_node, qspec in input_qspec_map.items(): + _port_metadata_for_input_quant_nodes(input_node, node, qspec) + _port_metadata_for_output_quant_nodes(node, output_qspec) + return PassResult(graph_module, True) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1bc425363a8f53c9aba719b5e10bfe79251bbb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py @@ -0,0 +1,574 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional, Union + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import ( + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + ObserverOrFakeQuantize, + QConfigMapping, +) +from torch.ao.quantization.fx.custom_config import PrepareCustomConfig +from torch.ao.quantization.fx.prepare import ( + _create_obs_or_fq_from_qspec, + _insert_obs_or_fq, + _is_activation_post_process_node, + _save_state, +) +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.quantizer import ( + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Argument + + +# TODO: make pt2e folder private? +__all__ = [ + "prepare", +] + + +def _find_root_edge_or_node( + edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode] +) -> EdgeOrNode: + """Find the root node for the sharing tree + Args: + edge_or_node: edge/node that we want to find the root + shared_with_map: each edge/node points to the parent, the root node will points to itself + + Returns: + root edge/node + """ + parent = shared_with_map[edge_or_node] + if parent == edge_or_node: + return edge_or_node + root = _find_root_edge_or_node(parent, shared_with_map) + # path compression + shared_with_map[edge_or_node] = root + return root + + +def _union( + parent: EdgeOrNode, + child: EdgeOrNode, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> None: + """Merge the subtree for `child` with `parent`, the order is important here""" + root_parent = _find_root_edge_or_node(parent, shared_with_map) + root_child = _find_root_edge_or_node(child, shared_with_map) + # union the two trees by pointing the root of child to root of parent + shared_with_map[root_child] = root_parent + + +def _update_shared_with( + child: EdgeOrNode, + qspec: QuantizationSpecBase, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +): + """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec` + configuration and established the relationship between `edge_or_node` with the edge/node that it + is pointing to, we'll use this information in the end to get the group id + """ + if isinstance(qspec, SharedQuantizationSpec): + parent = qspec.edge_or_node + # we point from edge_or_node to the node that it is sharing_with, e.g. 
+ # qspec for a = SharedQuantizationSpec(b) means `a` points to `b` + _union(parent, child, shared_with_map) + + +def _unwrap_shared_qspec( + qspec: QuantizationSpecBase, + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> QuantizationSpecBase: + """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec) + if qspec is SharedQuantizationSpec + (1). tries to find the root edge or node for the node that the qspec points to + (2). recursively find the root qspec based on the qspec for the root node + """ + if isinstance(qspec, SharedQuantizationSpec): + sharing_with = qspec.edge_or_node + root = _find_root_edge_or_node(sharing_with, shared_with_map) + qspec = edge_or_node_to_qspec[root] + return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + return qspec + + +def _has_same_attr( + qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str +): + return ( + hasattr(qspec_a, attr_name) + and hasattr(qspec_b, attr_name) + and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name) + ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name)) + + +def _get_edge_or_node_to_qspec( + model: torch.fx.GraphModule, +) -> dict[EdgeOrNode, QuantizationSpecBase]: + """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes""" + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {} + for n in model.graph.nodes: + if hasattr(n, "meta") and "quantization_annotation" in n.meta: + qa = n.meta["quantization_annotation"] + for input_to_n, qspec in qa.input_qspec_map.items(): + input_edge = (input_to_n, n) + edge_or_node_to_qspec[input_edge] = qspec + if qa.output_qspec is not None: + output_node = n + qspec = qa.output_qspec + edge_or_node_to_qspec[output_node] = qspec + return edge_or_node_to_qspec + + +def _union_input_edge_with( + input_edge, + input_edge_root_qspec, + edge_or_node, + edge_or_node_to_qspec, + shared_with_map, +): + """Union input edge with another edge or node, used in implicit sharing to point the current input + edge to other user edges of the producer node, or the output of producer node since these are + referring to the same Tensor + """ + root_qspec = None + if edge_or_node in edge_or_node_to_qspec: + qspec = edge_or_node_to_qspec[edge_or_node] + root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + # TODO: add assertions for types of root qspecs + if root_qspec is not None and all( + _has_same_attr(root_qspec, input_edge_root_qspec, attr) + for attr in [ + "dtype", + "is_dynamic", + "quant_min", + "quant_max", + "qscheme", + "ch_axis", + "scale", + "zero_point", + ] + ): + # the input arg to the node should reuse the existing output observer for arg + # since dtype is the same (we may want to extend this to be a more strict check + # in the future) + # so we point from `input_edge` to `arg` (output of the argument) + _union(edge_or_node, input_edge, shared_with_map) + + +def _get_edge_or_node_to_group_id( + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], +) -> dict[EdgeOrNode, int]: + """Map from edge/node to the group ID, generated from quantization annotations, + edge/node with the same group ID should use the same observer/fake_quant instance + + This is applying SharedQuantizationSpec configuration and map each edge/node to a group + There is another implicit sharing that's built in the quantization, when we have the following: + * op1 -> op2 + * output of op1: 
int8_qspec + * (op1 -> op2) input edge: int8_qspec + we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor. + + Figuring out the correct group ID for all edge/node is a standard union find problem: + https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/ + + Args: + edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations + Returns: + edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that + belongs to the same group should have the same id + + Example: + op2 -> cat1 -> cat2 + op1 / / + op3 + edge_or_node_to_qspec: { + op1: int8_qspec, + op2: int8_qspec, + (op1, cat1): int8_qspc, + (op2, cat1): SharedQuantizationSpec((op1, cat1)), + cat1: SharedQuantizationSpec((op1, cat1)), + (op3, cat2): int8_qspec, + (cat1, cat2): SharedQuantizationSpec((op3, cat2)), + cat2: SharedQuantizationSpec((op3, cat2)), + } + + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + edge_or_node_to_group_id: { + op1: 1, + op2: 1, + (op1, cat1): 1, + (op2, cat1): 1, + cat1: 1, + (op3, cat2): 1, + (cat1, cat2): 1, + cat2: 1, + } + # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which + # connects the two sharing group around cat1 and cat2 op due to transitive sharing + """ + # means the observer of key should be shared with observer with value, by default it will + # be shared with itself + shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { + k: k for k in edge_or_node_to_qspec.keys() + } + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + if isinstance(edge_or_node, torch.fx.Node): + output_node = edge_or_node + _update_shared_with(output_node, qspec, shared_with_map) + else: + input_edge = edge_or_node + input_edge_root_qspec = _unwrap_shared_qspec( + qspec, edge_or_node_to_qspec, shared_with_map + ) + + assert isinstance(input_edge, tuple) + arg, n = input_edge + if n.meta["quantization_annotation"].allow_implicit_sharing: + # NOTE: the order is important here, we first share with other users and then share with previous + # output because the reverse order could cause circular dependency + # e.g node1 -> node2 + # \ -> node3 + # when processing (node1, node2), if we first point (node1, node2) to node1 + # Step 1. shared_map = {(node1, node2): node1} + # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) , + # which means shared_map = {(node1, node2): node1, node1: (node1, node3)} + # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3) + # Step 3. 
and when we process (node1, node3), it can try to point to node1 as well, then we'll + # have a circular dependency + # the following order works around this issue, but this does not allow arbitrary configuration + # of sharing so it might break in a different case in the future, when it breaks + # quantizer writer can check the notes here to debug the issue + + # sharing with other users of the producer node + # (arg, user) + if not isinstance(arg, Node) or not isinstance(n, Node): + raise Exception( # noqa: TRY002 + f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}" + ) + for user in arg.users: + if user is n: + continue + arg_to_user_edge = (arg, user) + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg_to_user_edge, + edge_or_node_to_qspec, + shared_with_map, + ) + + # sharing with output of producer node + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg, + edge_or_node_to_qspec, + shared_with_map, + ) + + _update_shared_with(input_edge, qspec, shared_with_map) + + # now that we get the sharing relations between all edges and nodes, we can assingn group ids + cur_group_id = 0 + edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} + for edge_or_node in shared_with_map.keys(): + root = _find_root_edge_or_node(edge_or_node, shared_with_map) + if root not in edge_or_node_to_group_id: + edge_or_node_to_group_id[root] = cur_group_id + cur_group_id += 1 + edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root] + + return edge_or_node_to_group_id + + +def _get_obs_or_fq_map( + edge_or_node_to_group_id: dict[EdgeOrNode, int], + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + is_qat: bool, +) -> dict[EdgeOrNode, ObserverOrFakeQuantize]: + """Generates the EdgeOrNode to observer/fake_quant instances + Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant + instances + """ + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {} + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + group_id = edge_or_node_to_group_id[edge_or_node] + if group_id not in group_id_to_obs_or_fq: + # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify + # the implementation for _create_obs_or_fq_from_qspec + group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec( + qspec, obs_or_fq_map, is_qat + ) + obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id] + return obs_or_fq_map + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Union[Node, Any], + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. 
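# Minimal sketch of the union-find grouping used by the sharing logic above, with toy
# hashable keys standing in for real EdgeOrNode values (assumes the _union and
# _find_root_edge_or_node helpers defined earlier in this module are in scope):
toy_map = {"op1_out": "op1_out", ("op1", "cat"): ("op1", "cat"), ("op2", "cat"): ("op2", "cat")}
_union("op1_out", ("op1", "cat"), toy_map)       # input edge shares with the producer's output
_union(("op1", "cat"), ("op2", "cat"), toy_map)  # e.g. SharedQuantizationSpec between cat inputs
# all three keys now resolve to the same root, i.e. they form one observer/fake-quant group
assert _find_root_edge_or_node(("op2", "cat"), toy_map) == "op1_out"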
+ """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + assert isinstance(arg, Node) + # default (no observer) + new_arg = arg + + # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes + original_arg = arg + while _is_activation_post_process_node(original_arg, named_modules): + original_arg = original_arg.args[0] # type: ignore[assignment] + assert isinstance(original_arg, Node), ( + f"expect original argument to be a Node, but got: {type(original_arg)}" + ) + + input_edge = (original_arg, node) + if input_edge not in obs_or_fq_map: + return new_arg + # input_edge needs to be observed + input_edge_obs_or_fq = obs_or_fq_map[input_edge] + if input_edge_obs_or_fq is None: + return new_arg + + arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None) + # the arg is observed as the output and is using the same instance as the input_edge + # we'll reuse the inserted observer/fake_quant + if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id( + input_edge_obs_or_fq + ): + return new_arg + + # otherwise, we'll insert a new observer/fake_quant node + + # skip inserting new observers if the same observer instance is inserted before for another user + # Example: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + # + # instead of inserting new observers we will have: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + for maybe_obs_node in arg.users.keys(): + if not _is_activation_post_process_node(maybe_obs_node, named_modules): + continue + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if id(maybe_obs_mod) == id(input_edge_obs_or_fq): + return maybe_obs_node + + assert isinstance(model.graph, Graph) + new_arg = _insert_obs_or_fq( + arg, input_edge_obs_or_fq, model, named_modules, model.graph + ) + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + new_args.append(new_arg) + + # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and + # gelu has a has an approximate kwarg that persist in exported graph. + # This is just a work around for these. 
+ assert ( + node.target == torch.ops.aten.clone.default + or node.target == torch.ops.aten.zeros_like.default + or node.target == torch.ops.aten.gelu.default + or len(node.kwargs) == 0 + ), " expecting kwargs for aten op IR to be empty" + + # assign the new args to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[Node]: + if node in obs_or_fq_map: + output_act_obs_or_fq = obs_or_fq_map[node] + new_output = _insert_obs_or_fq( + node, output_act_obs_or_fq, model, named_modules, graph + ) + # propagate numeric debug handle from original node to observer/fake_quant node + if ( + isinstance(node, Node) + and isinstance(new_output, Node) + and CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in new_output.meta: + new_output.meta[CUSTOM_KEY] = {} + new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + return new_output + return None + + +def _maybe_insert_input_and_output_observers_for_node( + node: Node, + model: torch.fx.GraphModule, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + this_node_quantization_annotation = ( + node.meta["quantization_annotation"] + if "quantization_annotation" in node.meta + else None + ) + if this_node_quantization_annotation is None: + return + + named_modules = dict(model.named_modules(remove_duplicate=False)) + _maybe_insert_input_observers_for_node( + node, + None, # qconfig + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + + output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) + if not output_is_a_tensor: + return + + # this returns the new observer node if it was needed + maybe_output_obs_node = _maybe_insert_output_observer_for_node( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat + ) + + if maybe_output_obs_node is None: + return + # Update users of original node to use the output observer + # instead. For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with(node, maybe_output_obs_node) + + +def prepare( + model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + is_qat: bool, + obs_or_fq_callback=None, +) -> GraphModule: + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. 
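# Hedged sketch of how this prepare() step is typically reached from user code; the public
# entry points and quantizer names below are assumptions that vary between releases:
#
#   import torch
#   from torch.ao.quantization.quantize_pt2e import prepare_pt2e
#   from torch.ao.quantization.quantizer.xnnpack_quantizer import (
#       XNNPACKQuantizer,
#       get_symmetric_quantization_config,
#   )
#
#   m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
#   example_inputs = (torch.randn(1, 3, 16, 16),)
#   exported = torch.export.export_for_training(m, example_inputs).module()
#   quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#   prepared = prepare_pt2e(exported, quantizer)  # annotates, then runs prepare() above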
+ nodes_before_observation = list(model.graph.nodes) + + # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance + # all edge/nodes that belongs to the same group will use the same instance + # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant + # instance + edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model) + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + obs_or_fq_map = _get_obs_or_fq_map( + edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat + ) + if obs_or_fq_callback: + obs_or_fq_callback(model, obs_or_fq_map) + + for node in nodes_before_observation: + # TODO: simplify logic for inserting observers + _maybe_insert_input_and_output_observers_for_node( + node, model, obs_or_fq_map, is_qat + ) + + model = GraphModule(model, model.graph) + + _save_state( + model, + {}, # node_name_to_qconfig + node_name_to_scope, + PrepareCustomConfig(), + {}, # equalization_node_name_to_qconfig + QConfigMapping(), + is_qat, + set(), # observed_node_names + ) + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0285a1d20b0b596d7bf3c9468c0b34f2fbfa1f9d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py @@ -0,0 +1,995 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import operator +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.fx import Graph, GraphModule, Node +from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns + +from .utils import ( + _get_aten_graph_module_for_pattern, + _is_bn_node, + _is_conv_or_conv_transpose_node, + _is_conv_transpose_fn, + fold_bn_weights_into_conv_node, +) + + +if TYPE_CHECKING: + from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = [] # type: ignore[var-annotated] + + +def _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + is_cuda: bool, +) -> dict[str, Any]: + """ + Optional example inputs for quantized and folded conv-bn patterns + used in convert, expressed as kwargs. 
+ """ + kwargs = {} + # Per tensor quantization uses literals to represent scale and zero + # point, so there is no need to include them here as kwargs + if is_per_channel: + kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias and bias_is_quantized: + kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias: + kwargs["conv_bias"] = torch.randn(1) + if is_cuda: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.cuda() + return kwargs + + +def _get_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + x = conv_fn(x, conv_weight, conv_bias) + x = F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + return x + + return _WrapperModule(_conv_bn_pattern) + + +# TODO: merge this with the `no_conv_bias` case +def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std. + This is based on `nniqat.ConvBn2d._forward_approximate`. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype) + x = conv_fn(x, scaled_weight, zero_bias) + x = x / scale_factor.reshape(bias_shape) + x = x + conv_bias.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern) + + +def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern_no_conv_bias( + x: torch.Tensor, + conv_weight: torch.Tensor, + # Not used, only for matching convenience + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias. 
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias) + + +def _append_qdq(x, is_per_channel, is_bias, kwargs): + """ + Helper function to append q-dq ops after `x`, using dummy values for the qparams + and qmin/qmax. We use dummy values here because we match with `ignore_literals=True` + and will manually replace these values after subgraph rewriting. + + Return the dq node. + """ + # Dummy args to be passed into q-dq ops + per_channel_axis = 0 + scale_key = "bias_scale" if is_bias else "weight_scale" + zp_key = "bias_zero_point" if is_bias else "weight_zero_point" + scale = kwargs[scale_key] if is_per_channel else 1.0 + zp = kwargs[zp_key] if is_per_channel else 0 + qmin = -127 + qmax = 127 + dtype = torch.int8 + + qd = torch.ops.quantized_decomposed + if is_per_channel: + x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + else: + x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + return x + + +def _get_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Return the quantized version of QAT conv + BN pattern. + This is based on `nniqat.ConvBn2d._forward_approximate`, + used in QAT convert. We first match this pattern and replace + it with the normal [conv - bn] pattern, then fold the BN + weights into conv. 
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + scaled_weight = _append_qdq( + scaled_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype) + if bias_is_quantized: + zero_bias = _append_qdq( + zero_bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + x = conv_fn(x, scaled_weight, zero_bias) + else: + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + if has_bias: + x = x + kwargs["conv_bias"].reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_quantized_qat_conv_bn_pattern) + + +def _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Quantized QAT conv - bn pattern with bn weights being folded into conv. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _folded_quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + conv_weight = _append_qdq( + conv_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + bias = kwargs["conv_bias"] + if bias_is_quantized: + bias = _append_qdq( + bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + else: + bias = None + x = conv_fn(x, conv_weight, bias) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_folded_quantized_qat_conv_bn_pattern) + + +def _has_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph has bias. + """ + for n in match.nodes_map.values(): + if _is_conv_or_conv_transpose_node(n): + return len(n.args) > 2 and n.args[2] is not None + raise ValueError("Could not find conv node in matched conv + bn pattern") + + +def _no_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph does NOT have bias. 
+ """ + return not _has_conv_bias_filter(match, original_graph, pattern_graph) + + +def _is_quantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + + +def _is_dequantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + + +def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]: + """ + Helper function to extract the nodes in the conv-bn fusion pattern after + subgraph rewriting, in the form of a map: + + {name: (original_node, replacement_node)} + + The following names must exist in the map: + + "conv", "conv_weight", "conv_input", "bn", "getitem" + + The following names may exist in the map: + + "conv_weight_q", "conv_weight_dq", "conv_bias", + "conv_bias_q", "conv_bias_dq" + """ + + def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Optional[Node]]: + """ + Return a 3-tuple of (conv_node, bn_node, getitem_node). + This asserts that the match contains exactly one of each node. + """ + conv_node, bn_node, getitem_node = None, None, None + for n in nodes: + if n.op != "call_function": + continue + if _is_conv_or_conv_transpose_node(n): + assert conv_node is None + conv_node = n + if _is_bn_node(n): + assert bn_node is None + bn_node = n + if n.target == operator.getitem: + assert getitem_node is None + getitem_node = n + assert conv_node is not None + assert bn_node is not None + return (conv_node, bn_node, getitem_node) + + def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]: + """ + Return a 3-tuple of (orig_node, q_node, dq_node). 
+ """ + assert _is_dequantize(n) + q_node = n.args[0] + assert isinstance(q_node, Node) + assert _is_quantize(q_node) + orig_node = q_node.args[0] + assert isinstance(orig_node, Node) + return (orig_node, q_node, n) + + original_nodes = list(_filter_nodes_map(r.nodes_map).values()) + o_conv, o_bn, o_getitem = _get_nodes(original_nodes) + r_conv, r_bn, r_getitem = _get_nodes(r.replacements) + + # Create the mapping from original node to replacement node + assert o_getitem is None + assert r_getitem is None + mapping = { + "conv": (o_conv, r_conv), + "bn": (o_bn, r_bn), + } + + # Extract conv input and weight + # Note: here we extract the original nodes indirectly through the pattern nodes + # because the args of the original nodes are no longer available after replacement + (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys())) + (p_conv_input, p_conv_weight, *_) = p_conv.args + (r_conv_input, r_conv_weight, *_) = r_conv.args + assert isinstance(p_conv_input, Node) + assert isinstance(p_conv_weight, Node) + assert isinstance(r_conv_input, Node) + assert isinstance(r_conv_weight, Node) + o_conv_input = r.nodes_map[p_conv_input] + o_conv_weight = r.nodes_map[p_conv_weight] + + # If conv weight is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_weight): + p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes( + p_conv_weight + ) + r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes( + r_conv_weight + ) + o_conv_weight = r.nodes_map[p_conv_weight] + o_conv_weight_q = r.nodes_map[p_conv_weight_q] + o_conv_weight_dq = r.nodes_map[p_conv_weight_dq] + mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q) + mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq) + mapping["conv_input"] = (o_conv_input, r_conv_input) + mapping["conv_weight"] = (o_conv_weight, r_conv_weight) + + # Extract conv bias + if len(p_conv.args) > 2 and len(r_conv.args) > 2: + p_conv_bias = p_conv.args[2] + r_conv_bias = r_conv.args[2] + assert isinstance(p_conv_bias, Node) + assert isinstance(r_conv_bias, Node) + o_conv_bias = r.nodes_map[p_conv_bias] + + # If conv bias is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_bias): + p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias) + r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias) + o_conv_bias = r.nodes_map[p_conv_bias] + o_conv_bias_q = r.nodes_map[p_conv_bias_q] + o_conv_bias_dq = r.nodes_map[p_conv_bias_dq] + mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q) + mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq) + mapping["conv_bias"] = (o_conv_bias, r_conv_bias) + return mapping + + +def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]: + """ + Return a filtered `nodes_map` returned from the subgraph rewriter. + The filtered `nodes_map` will contain only nodes that are actually + matched in the pattern, excluding None or placeholder nodes. 
+ """ + new_nodes_map: dict[Node, Node] = {} + for pattern_node, graph_node in nodes_map.items(): + # bias can be None + if graph_node is None: + continue + # skip pattern placeholder nodes + if pattern_node.op == "placeholder": + continue + new_nodes_map[pattern_node] = graph_node + return new_nodes_map + + +# TODO: this is error prone, use the replace_literals_with_placeholders hack instead +def _copy_over_literal_conv_args(original_node: Node, new_node: Node): + """ + Copy over literal args in conv, such as stride and padding, from the matched node + in the original graph to its replacement in the new graph. + + This is needed due to the following limitation in the subgraph rewriter when used + with dynamo export: literal (non-tensor) args are not supported in the match and + replacement patterns. This is because dynamo export automatically inlines these + literal args, making them dead placeholder nodes. In the future, we should check + if dynamo export can optionally disable this inlining, or if subgraph rewriter + can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419. + + Note: Unlike other tensor args like conv weights and biases, literal args are + preserved in the original nodes after replacement, so we can access them here. + """ + assert _is_conv_or_conv_transpose_node(original_node) + assert _is_conv_or_conv_transpose_node(new_node) + # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups] + new_args = list(new_node.args) + if len(new_args) < 3: + # bias is optional, when it is not present, it means it is None + new_args.append(None) + new_node.args = tuple(new_args[:3]) + original_node.args[3:] + + +def _update_conv_input_qspec_map_after_replacement( + original_node: Node, replacement_node: Node +): + """ + Update the `input_qspec_map` in the annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the keys in the `input_qspec_map` will need to be updated to reflect + the corresponding nodes in the replacement graph. + """ + assert _is_conv_or_conv_transpose_node(original_node) + assert _is_conv_or_conv_transpose_node(replacement_node) + if "quantization_annotation" not in original_node.meta: + return + original_input_qspec_map = original_node.meta[ + "quantization_annotation" + ].input_qspec_map + input_qspec_map = {} + # get the list of configs, it should be ordered as input, weight, bias + # note: this is really hacky, we need a better solution, hopefully + # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820 + all_configs = list(original_input_qspec_map.items()) + # input activation + input_qspec_map[replacement_node.args[0]] = all_configs[0][1] + # weight + input_qspec_map[replacement_node.args[1]] = all_configs[1][1] + # bias + if len(replacement_node.args) > 2 and len(all_configs) > 2: + input_qspec_map[replacement_node.args[2]] = all_configs[2][1] + replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map + + +def _update_special_qspecs_after_replacement( + node: Node, + original_to_replacement_node: dict[Node, Node], +): + """ + Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s + used in `node`'s quantization annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the nodes used in these special quantization specs will need to + be updated to the corresponding nodes in the replacement graph. 
+ """ + + def _get_new_edge_or_node(edge_or_node: EdgeOrNode): + if isinstance(edge_or_node, Node): + _node = edge_or_node + return original_to_replacement_node.get(_node, _node) + elif ( + isinstance(edge_or_node, tuple) + and len(edge_or_node) == 2 + and all(isinstance(x, Node) for x in edge_or_node) + ): + src, dest = edge_or_node + return ( + original_to_replacement_node.get(src, src), + original_to_replacement_node.get(dest, dest), + ) + else: + raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node)) + + def _get_new_qspec(qspec: QuantizationSpecBase): + if isinstance(qspec, SharedQuantizationSpec): + new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node) + return SharedQuantizationSpec(new_edge_or_node) + elif isinstance(qspec, DerivedQuantizationSpec): + new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from] + return dataclasses.replace(qspec, derived_from=new_derived_from) + else: + return qspec + + if "quantization_annotation" not in node.meta: + return + annotation = node.meta["quantization_annotation"] + for input_node, qspec in annotation.input_qspec_map.items(): + annotation.input_qspec_map[input_node] = _get_new_qspec(qspec) + annotation.output_qspec = _get_new_qspec(annotation.output_qspec) + + +def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fuse_conv_bn_qat_helper( + m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + return m + + +def _fuse_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Given a graph of decomposed aten ops, replace the (conv + bn) pattern with + the fused QAT subgraph equivalent. The input graph should already be annotated. + The annotations in the original nodes will be preserved in the corresponding + nodes in the new subgraph. + + Note: This also handles the (conv + bn + relu) pattern. + """ + m.graph.eliminate_dead_code() + m.recompile() + + conv_bn_pattern = _get_conv_bn_pattern(conv_fn) + match_pattern = _get_aten_graph_module_for_pattern( + conv_bn_pattern, + example_inputs, + is_cuda, + ) + + # Step (1): Replace patterns with conv bias + # + # Here we do replacement separately for cases with and without conv bias, since + # the replacement patterns for these two cases are substantially different. 
+ # TODO: use the public replace_pattern API once it also returns replacement nodes + + qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn) + replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern, + example_inputs, + is_cuda, + ) + replacements_with_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_with_conv_bias, + match_filters=[_has_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (2): Replace patterns without conv bias + + qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn) + replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern_no_conv_bias, + example_inputs, + is_cuda, + ) + replacements_no_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_no_conv_bias, + match_filters=[_no_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (3): Post processing + # + # Due to limited functionality in the subgraph rewriter, here we manually + # update the replacement graph as follows: + # + # (a) Copy over metadata from original subgraph. This ensures the stack traces + # and annotations are preserved in the new subgraph + # + # (b) Copy over literal args for conv from the original subgraph + # TODO: do this for literal args for batchnorm as well + # + # (c) Update all references of the old nodes in the original subgraph to refer + # to the corresponding nodes in the new subgraph in the annotations + # + # In the future, we should try to push as much of this functionality into the + # subgraph rewriter as possible, so we don't have to manually copy anything over. + # For more detail, see https://github.com/pytorch/pytorch/issues/100419. + + all_original_to_replacement_nodes = {} + for r in replacements_with_conv_bias + replacements_no_conv_bias: + replacement_dict = _get_conv_bn_pattern_nodes(r) + # The original conv node's "nn_module_stack" + conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None) + for k, node_tuple in replacement_dict.items(): + original_node, replacement_node = node_tuple + # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem] + replacement_node.meta = original_node.meta + # If original_node is a get_attr node, it doesn't have nn_module_stack. + # In this case, we copy nn_module_stack from the original conv node. + if ( + k in ["conv_input", "conv_weight"] + and conv_nn_module + and "nn_module_stack" not in replacement_node.meta + ): + replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module) + if _is_conv_or_conv_transpose_node(original_node): + # Step (3b): Copy over conv literal args + _copy_over_literal_conv_args(original_node, replacement_node) + # Step (3c): Update old references in the conv node's input_qspec_map + _update_conv_input_qspec_map_after_replacement( + original_node, replacement_node + ) + all_original_to_replacement_nodes[original_node] = replacement_node + + # Step (3c): Update old references in the special qspecs for all nodes in the graph + for n in m.graph.nodes: + _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes) + + return m + + +def _duplicate_dequantize_node(m: GraphModule): + """ + Helper function to duplicate all dequantize nodes in the graph if the + node has more than one user. 
For example: + + Before: + quantize -> dequantize -> a + \\--> b + \\--> c + + After: + quantize -> dequantize_1 -> a + \\--> dequantize_2 -> b + \\--> dequantize_3 -> c + + This is useful for subgraph rewriting. E.g. if we wish to match the + pattern [dequantize - a] above, subgraph matching would fail because + the dequantize node has users outside the matched portion of the graph. + Instead, we match [dequantize_1 - a], which is safe. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + if n.op != "call_function" or n.target != dq_op or len(n.users) == 1: + continue + for user in list(n.users): + with m.graph.inserting_before(n): + new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs) + user.replace_input_with(n, new_node) + m.graph.erase_node(n) + m.recompile() + + +def _remove_extra_dequantize(m: GraphModule): + """ + Removes duplicate dequant nodes in the graph, for an operator that has + multiple dequant nodes as a user, replace them with a single dequant node + that can be shared across all the uses. This should be seen as the "reverse" + of `_duplicate_dequantize_node`. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + dq_users = [ + user + for user in n.users + if user.op == "call_function" and user.target == dq_op + ] + if len(dq_users) > 1: + with m.graph.inserting_after(dq_users[0]): + new_node = m.graph.create_node( + "call_function", dq_op, dq_users[0].args, {} + ) + for dq_user in dq_users: + dq_user.replace_all_uses_with(new_node) + m.graph.erase_node(dq_user) + m.recompile() + + +def _copy_over_q_dq_args(original_node: Node, replacement_node: Node): + """ + Given a pair of quantize or dequantize nodes, copy over all literal args + from the original node to the replacement node. 
+ """ + # For quantize_per_tensor, scale and zp are literals and need to be copied + # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped + assert original_node.target == replacement_node.target + if original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ): + # Args: input, [scale, zp, qmin, qmax, dtype] + start_copy_arg_index = 1 + elif original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ): + # Args: input, scale, zp, [axis, qmin, qmax, dtype] + start_copy_arg_index = 3 + else: + raise ValueError( + f"Expected quantize/dequantize nodes, got '{original_node.target}'" + ) + replacement_node.args = ( + replacement_node.args[:start_copy_arg_index] + + original_node.args[start_copy_arg_index:] + ) + + +def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for quantized and folded conv-bn1d patterns used in convert + _quantized_conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for quantized and folded conv-bn2d patterns used in convert + _quantized_conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fold_conv_bn_qat_helper( + m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + + # remove in place add from batchnorm tracking traning stats + for node in m.graph.nodes: + if ( + node.target == torch.ops.aten.add_.Tensor + and node.args[0].op == "get_attr" + and node.args[1] == 1 + and ( + torch.nn.modules.batchnorm.BatchNorm2d + in [val[1] for val in node.meta["source_fn_stack"]] + or torch.nn.modules.batchnorm.BatchNorm1d + in [val[1] for val in node.meta["source_fn_stack"]] + ) + ): + m.graph.erase_node(node) + + m.graph.eliminate_dead_code() + m.recompile() + + return m + + +def _fold_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv. 
+ """ + + m.graph.eliminate_dead_code() + m.recompile() + _duplicate_dequantize_node(m) + + # Step (1): Replace QAT pattern with simple [conv - bn] pattern + replacements = [] + replacement_options = itertools.product( + [True, False], # is_per_channel + [True, False], # has_bias + [True, False], # bias_is_quantized + [True, False], # bn_is_training + ) + for ( + is_per_channel, + has_bias, + bias_is_quantized, + bn_is_training, + ) in replacement_options: + # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily + # filter out one of the values for this flag to avoid having duplicate patterns + if not has_bias and bias_is_quantized: + continue + kwargs = _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel, has_bias, bias_is_quantized, is_cuda + ) + match_pattern = _get_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + match_pattern = _get_aten_graph_module_for_pattern( + match_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + replacement_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacements.extend( + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + ignore_literals=True, + ) + ) + m.recompile() + _remove_extra_dequantize(m) + + for r in replacements: + node_map = _get_conv_bn_pattern_nodes(r) + + # Step (2): Copy over metadata from original subgraph + for original_node, replacement_node in node_map.values(): + replacement_node.meta = original_node.meta + + # Step (3): Copy over args for weight (and optionally bias) q - dq nodes + _copy_over_q_dq_args(*node_map["conv_weight_q"]) + _copy_over_q_dq_args(*node_map["conv_weight_dq"]) + if "conv_bias_q" in node_map: + assert "conv_bias_dq" in node_map + _copy_over_q_dq_args(*node_map["conv_bias_q"]) + _copy_over_q_dq_args(*node_map["conv_bias_dq"]) + + # Step (4): Fold BN weights into conv + conv_bias = None + (_, conv_node) = node_map["conv"] + (_, bn_node) = node_map["bn"] + (_, conv_weight) = node_map["conv_weight"] + if "conv_bias" in node_map: + (_, conv_bias) = node_map["conv_bias"] + fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m) + + # Copy over literal args for conv + for original_node in _filter_nodes_map(r.nodes_map).values(): + if _is_conv_or_conv_transpose_node(original_node): + _copy_over_literal_conv_args(original_node, conv_node) + + m.graph.eliminate_dead_code() + m.recompile() + return m diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__init__.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b56f80fe94193da8f7c9f714db7c2aebb33c45c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__init__.py @@ -0,0 +1,6 @@ +from .rewrite import reference_representation_rewrite + + +__all__ = [ + "reference_representation_rewrite", +] diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..259cf313c84eba47ebbb398f32c5e127039075b9 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6cdd96ffbf44b712afbad00f2500e6a83e8d810 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bf80e2439aab9a53836d1acc72a7f2883151c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py @@ -0,0 +1,824 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Optional + +import torch +from torch._export.utils import _disable_aten_to_metadata_assertions +from torch._higher_order_ops.out_dtype import out_dtype +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _replace_literals_with_existing_placeholders, + _replace_literals_with_new_placeholders, + remove_tensor_overload_for_qdq_ops, +) +from torch.fx import GraphModule +from torch.fx.subgraph_rewriter import replace_pattern + + +__all__ = [ + "reference_representation_rewrite", +] + + +def _qdq_quantized_linear( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_linear( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. 
+ # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + ) + # TODO: change to mul.Scalar + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values + acc_i32 = ( + out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + acc_i32, + x_scale * weight_scale / out_scale, + ) + + out_zero_point + ) + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +def _qdq_dynamic_quantized_linear( + x_fp32, + x_quant_min, + x_quant_max, + x_eps, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams( + x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8 + ) + x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + return out_fp32 + + +def _reference_dynamic_quantized_linear( + x_fp32, + x_quant_min, + x_quant_max, + x_eps, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams( + x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8 + ) + # decomposed representation for quantize_per_tensor + # TODO: use out_dtype(mul, ...) 
here when the op is ready + x_fp32 = x_fp32 / x_scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x_fp32 = torch.round(x_fp32) # fp32 + x_i32 = x_fp32.to(dtype=torch.int32) # int32 + x_i32 = x_i32 + x_zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32 + x_i8 = x_i32.to(dtype=torch.int8) + + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + ) + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + out_fp32 = acc_i32 * (x_scale * weight_scale) + return out_fp32 + + +def _qdq_quantized_conv2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.convolution.default( + x_fp32, + weight_fp32, + bias_fp32, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_conv2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. 
+ # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.convolution.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to: + # Take linear calculation for example + # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32 + # Represent X, W fp32 as their dequant transforms + # A_fp32 = (A_q - A_zero_point)/A_scale + # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_fp32 - X_zp) * X_scale * (W_(i, k)_fp32 - W_zp) * W_scale] + bias_(i)_fp32 + # Factor out X_scale and W_scale + # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32 + # In order to addition of bias_(i)_fp32 inside, we must do + # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale # noqa: B950 + # Note we had to multiply bias_fp32 qith X_scale * W_scale = bias_scale + # Thus bias quantization to int32 must be with X_scale * W_scale + + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + # Unsqueeze to match broadcast dims + # Unfortnuately I cannot do bias_i32.unsqueeze(0) due to literal matching nightmare + # in graph pattern replacement + bias_i32 = bias_i32.unsqueeze(-1) + bias_i32 = bias_i32.unsqueeze(-1) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values + acc_i32 = ( + out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + acc_i32, + x_scale * weight_scale / out_scale, + ) + + out_zero_point + ) + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_add_relu( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8 + ) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8 + ) + out_fp32 = x_fp32 + y_fp32 + out_fp32 = torch.ops.aten.relu(out_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_add_relu( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + """ + See comments for `_reference_quantized_add` for more information on + how to derive the formula for out_i8 based on x_i8 and y_i8 + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: change this to mul.Scalar? 
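+    # Each operand is shifted by its zero point and rescaled by (scale / out_scale)
+    # while staying in int32 via out_dtype, so the sum below is already expressed in
+    # the output quantization domain. The final clamp uses out_zero_point as the
+    # lower bound, which is what fuses the ReLU: out_zero_point is the quantized
+    # value representing 0.0.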
+ x_i32 = out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + (x_i32 - x_zero_point), + (x_scale / out_scale), + ) + y_i32 = out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + (y_i32 - y_zero_point), + (y_scale / out_scale), + ) + out_i32 = x_i32 + y_i32 + out_zero_point + # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point) + out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_add( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8 + ) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8 + ) + out_fp32 = x_fp32 + y_fp32 + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_add( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + """ + # How to Derive the formula for out_i8 based on x_i8 and y_i8 + # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8) + + # out_i8 is quantized output, we can write down the formula for it first: + out_i8 = out_f32 / out_scale + out_zero_point (1) + + # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8 + out_f32 = x_f32 + y_f32 (2) + x_fp32 = (x_i8 - x_zero_point) * x_scale (3) + y_fp32 = (y_i8 - y_zero_point) * y_scale (4) + + # applying the above fomula to the out_i8 equation we can get the following: + out_i8 = out_fp32 / out_scale + out_zero_point # (1) + = (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32 + = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4) + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: use out_dtype op + x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32) + y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32) + out_i32 = x_i32 + y_i32 + out_zero_point + quant_min = -128 + quant_max = 127 + out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_max_pool2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_fp32, kernel_size, stride, padding, dilation, ceil_mode + ) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_max_pool2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + # to preserve x_quant_min, x_quant_max in the graph for 
pattern matching + x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max) + x_i32 = x_i8.to(torch.int32) + out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode + ) + out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point + out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max) + out_i8 = out_fp32.to(torch.int8) + return out_i8 + + +def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max): + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, scale, zero_point, quant_min, quant_max, torch.int8 + ) + return x + + +def _reference_quantize_per_tensor_int8( + x_fp32, scale, zero_point, quant_min, quant_max +): + # TODO: use out_dtype(mul, ...) here when the op is ready + x = x_fp32 / scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x = torch.round(x) # fp32 + x = x.to(dtype=torch.int32) # int32 + x = x + zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x = torch.clamp(x, quant_min, quant_max) # int32 + x = x.to(dtype=torch.int8) + return x + + +def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, scale, zero_point, quant_min, quant_max, torch.int8 + ) + return x_fp32 + + +def _reference_dequantize_per_tensor_int8( + x_i8, scale, zero_point, quant_min, quant_max +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + # TODO: use out_dtype op + # note: x_i8.to(torch.int32) does not work here + # TODO: debug the implementation later when torchdynamo time out issue is resolved + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + +def _quantize_per_channel_int8( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max +): + out_i8 = torch.ops.quantized_decomposed.quantize_per_channel( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantize_per_channel_int8( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max +): + x_fp32 = torch.transpose(x_fp32, ch_axis, -1) + out_i32 = torch.ops.aten.clamp( + torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max + ) + out_i32 = torch.transpose(out_i32, ch_axis, -1) + return out_i32.to(torch.int8) + + +def _dequantize_per_channel_int8( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max +): + # the following will be replaced as placeholders + out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_fp32 + + +def _reference_dequantize_per_channel_int8( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max +): + # the following will be replaced as placeholders + # in order to preserve the quant_min/quant_max args for pattern matching (e.g. 
matching for int4 quantized ops) + # we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + x_i8 = torch.transpose(x_i8, ch_axis, -1) + x_i32 = x_i8.to(torch.int32) + out_fp32 = (x_i32 - zero_points).to(torch.float) * scales + out_fp32 = torch.transpose(out_fp32, ch_axis, -1) + return out_fp32 + + +def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule): + return _replace_literals_with_existing_placeholders( + gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5} + ) + + +@dataclass +class _RewriteInfo: + """Data needed for rewrite, this includes example inputs, pattern and replacement functions + and post transformation functions for the exported pattern and replacement GraphModule + """ + + # example inputs used for exporting the pattern into GraphModule + example_inputs: tuple[Any, ...] + pattern: Callable + replacement: Callable + # post transformation on the exported pattern and replacement GraphModule + pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + + +def reference_representation_rewrite(model: GraphModule) -> GraphModule: + _QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (2, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randn((2, 5), dtype=torch.float), + -128, + 127, + torch.finfo(torch.float32).eps, + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + ) + + _QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, 
dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _REWRITE_INFO_LIST = [ + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_dynamic_quantized_linear), + _WrapperModule(_reference_dynamic_quantized_linear), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + ), + _RewriteInfo( + _QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_linear), + _WrapperModule(_reference_quantized_linear), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZED_CONV2d_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_conv2d), + _WrapperModule(_reference_quantized_conv2d), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add_relu), + _WrapperModule(_reference_quantized_add_relu), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add), + _WrapperModule(_reference_quantized_add), + ), + _RewriteInfo( + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_max_pool2d), + _WrapperModule(_reference_quantized_max_pool2d), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_tensor_int8), + _WrapperModule(_reference_quantize_per_tensor_int8), + ), + _RewriteInfo( + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_tensor_int8), + _WrapperModule(_reference_dequantize_per_tensor_int8), + ), + _RewriteInfo( + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_channel_int8), + _WrapperModule(_reference_quantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + _RewriteInfo( + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_channel_int8), + _WrapperModule(_reference_dequantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + 
_replace_ph_qdq_per_channel_replacement, + ), + ] + + remove_tensor_overload_for_qdq_ops(model) + + with _disable_aten_to_metadata_assertions(): + for rewrite_info in _REWRITE_INFO_LIST: + example_inputs = rewrite_info.example_inputs + pattern = rewrite_info.pattern + replacement = rewrite_info.replacement + pattern_post_trans = rewrite_info.pattern_post_trans + replacement_post_trans = rewrite_info.replacement_post_trans + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = _get_aten_graph_module_for_pattern( # type: ignore[assignment] + replacement, + example_inputs, # type: ignore[arg-type] + ) + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] + if pattern_post_trans: + pattern = pattern_post_trans(pattern) + if replacement_post_trans: + replacement = replacement_post_trans(replacement) + pattern.recompile() # type: ignore[attr-defined] + replacement.recompile() # type: ignore[attr-defined] + replace_pattern(model, pattern, replacement) + + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00bc183ed52b18f0464487cc106bec9b17ad05a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/pt2e/utils.py @@ -0,0 +1,616 @@ +# mypy: allow-untyped-defs +import operator +import types +from typing import Any, Callable, Optional, Union + +import torch +import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 +import torch.nn.functional as F + +# Makes sure that quantized_decomposed ops are registered +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.quantizer import QuantizationAnnotation +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.nn.utils.fusion import fuse_conv_bn_weights +from torch.utils._pytree import LeafSpec + + +__all__ = [ + "fold_bn_weights_into_conv_node", + "remove_tensor_overload_for_qdq_ops", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: + """ + Assuming dest is one of the ops inserted by quant workflow, this function + finds if source and dest are connected. 
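+    It does this by walking upward from dest through args[0] while the visited
+    nodes are quantize/dequantize/choose_qparams ops, then checking whether the
+    node it stops at is source.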
Assumption is that only quant workflow + inserted ops exist between source and dest + """ + quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS + quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) + while dest.target in quant_workflow_ops: + if not isinstance(dest.args[0], torch.fx.Node): + raise ValueError( + f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}" + ) + dest = dest.args[0] + return dest == source + + +def _find_q_dq_node_for_user( + produer: torch.fx.Node, user: torch.fx.Node +) -> tuple[Any, Any]: + """ + Find q, dq pair corresponding to [producer -> q -> dq -> user] + Utils works by finding dq arg of user and ensuring it is connected to + producer + """ + dq_node = None + for n in user.args: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + for n in user.kwargs: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + return (None, None) + + q_node = None + if ( + isinstance(arg := dq_node.args[0], torch.fx.Node) + and arg.op == "call_function" + and arg.target in _QUANTIZE_OPS + ): + q_node = arg + return (q_node, dq_node) + + +def _is_sym_size_node(node: Node): + return ( + node.op == "call_function" + and node.target == torch.ops.aten.sym_size.default + or node.target == torch.ops.aten.sym_numel.default + or node.target == torch.ops.aten.sym_numel + or node.target == torch.ops.aten.sym_size + ) + + +def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]: + node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users)) + return node_users + + +def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool: + if annotation is None: + return False + input_qspec_map = annotation.input_qspec_map + output_qspec = annotation.output_qspec + if len(input_qspec_map) == 0 and output_qspec is None: + return False + return True + + +def _get_tensor_constant_from_node(node, m): + if node is None: + return None + assert node.op == "get_attr" + target_atoms = node.target.split(".") + attr_itr = m + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def _get_all_arguments(orig_args, orig_kwargs, args_schema): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + +def _is_supported_batch_norm_for_training(node: Node): + """ + Return True if the given node refers to an aten batch norm op QAT supports. 
+ """ + supported_ops = [ + torch.ops.aten.batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + # Note: we won't need this op anymore after batch norm consolidation + # For now, we need to continue to support it because it gives better + # training numerics than `_native_batch_norm_legit` + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.miopen_batch_norm.default, + ] + return node.target in supported_ops + + +# TODO: move this to torch/ao/quantization/utils.py +def _is_conv_node(n: Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv1d.padding, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, + ] + + +def _is_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv_transpose op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + ] + + +def _is_conv_or_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv or conv transpose op. + """ + return _is_conv_node(n) or _is_conv_transpose_node(n) + + +def _is_conv_transpose_fn(conv_fn: Callable): + return conv_fn in [F.conv_transpose1d, F.conv_transpose2d] + + +def _is_bn_node(n: Node): + return ( + _is_supported_batch_norm_for_training(n) + or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default + ) + + +def fold_bn_weights_into_conv_node( + conv_node: Node, + conv_weight_node: Node, + conv_bias_node: Optional[Node], + bn_node: Node, + m: GraphModule, +) -> None: + # conv args: input, weight, bias, stride, padding, dilation, ... 
+ conv_w = _get_tensor_constant_from_node(conv_weight_node, m) + conv_b = _get_tensor_constant_from_node(conv_bias_node, m) + transpose = _is_conv_transpose_node(conv_node) + + # eval bn args: input, weight, bias, running mean, running var, momentum, eps + # train bn args: input, weight, bias, running mean, running var, training, momentum, eps + bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr] + bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema) + bn_w = _get_tensor_constant_from_node(bn_args[1], m) + bn_b = _get_tensor_constant_from_node(bn_args[2], m) + bn_rm = _get_tensor_constant_from_node(bn_args[3], m) + bn_rv = _get_tensor_constant_from_node(bn_args[4], m) + if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default: + eps_arg_index = 6 + elif _is_supported_batch_norm_for_training(bn_node): + eps_arg_index = 7 + else: + raise ValueError("BN node target is unexpected ", bn_node.target) + bn_eps = bn_args[eps_arg_index] + + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose + ) + + # update the weight and bias for conv + conv_args = list(conv_node.args) + # filling in the default bias argument + if len(conv_args) == 2: + conv_args.append(None) + + # calling data since the fused_weight and fused_bias are nn.Parameter + weight_attr_name = conv_weight_node.target + assert isinstance(weight_attr_name, str) + _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER) + if conv_bias_node is not None: + bias_attr_name = conv_bias_node.target + _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER) + else: + bias_attr_name = weight_attr_name + "_bias" + _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER) + with m.graph.inserting_before(conv_node): + get_bias_node = m.graph.get_attr(bias_attr_name) + # NOTE: here we assume the bias of conv is not quantized! + conv_args[2] = get_bias_node + conv_node.args = tuple(conv_args) + + # native_batch_norm has 3 outputs, we expect getitem calls on the output + # and we want to replace the uses of getitem 0 with the output of conv + # + if bn_node.target == torch.ops.aten.batch_norm.default: + # With the new training ir, instead of batch_norm + getitem, + # we only have the batch_norm node. + # + # Before: + # conv -> bn -> users + # After: + # conv -> users + # bn has no users now + bn_node.replace_all_uses_with(conv_node) + else: + # Before: + # conv -> bn - (first output) -> users1 + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # After: + # conv -> (first output) -> users1 + # bn - + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # if users2 and users3 are empty then bn will be removed through dead code elimination + for user in bn_node.users: + if ( + user.op != "call_function" + or user.target != operator.getitem + or user.args[1] != 0 + ): + continue + user.replace_all_uses_with(conv_node) + + # If the BN node does not have users, erase it from the graph + # Note: we need to do this manually because the model can still be in train + # mode at this point, in which case DCE won't erase the BN node automatically + # since the node refers to a mutating op. Here we still need to call DCE first + # to get rid of the unused getitem nodes that consume the BN node. 
+ m.graph.eliminate_dead_code() + if len(bn_node.users) == 0: + m.graph.erase_node(bn_node) + + +# fuse conv bn weights, inplace modification of the graph_module and graph +def _fuse_conv_bn_(m: GraphModule) -> None: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return + for n in m.graph.nodes: + if n.op != "call_function" or n.target not in ( + torch.ops.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten.batch_norm.default, + ): + continue + bn_node = n + n = bn_node.args[0] + if not _is_conv_or_conv_transpose_node(n): + continue + conv_node = n + conv_weight_node = conv_node.args[1] + conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + fold_bn_weights_into_conv_node( + conv_node, conv_weight_node, conv_bias_node, bn_node, m + ) + + m.graph.eliminate_dead_code() + m.recompile() + + +def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]: + # TODO: move this information to fx node itself + node_name_to_scope: dict[str, tuple[str, type]] = {} + for n in model.graph.nodes: + nn_module_stack = n.meta.get("nn_module_stack", None) + current_scope = ("", type(None)) + if nn_module_stack: + bt = list(nn_module_stack.values())[-1] + current_scope = (bt[0].split(".")[-1], bt[1]) + node_name_to_scope[n.name] = current_scope + return node_name_to_scope + + +def _get_aten_graph_module_for_pattern( + pattern: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool = False, + **kwargs, +) -> GraphModule: + """ + Convert the pattern to an FX graph with decomposed aten ops. + """ + if is_cuda: + example_inputs = tuple( + [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs] + ) + + aten_pattern = torch.export.export_for_training( + pattern, # type: ignore[arg-type] + example_inputs, + kwargs, + strict=True, + ).module() + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + # ep.module() adds copy_ nodes for the mutated inputs. 
+ # For patterns, it doesn't matter + for node in aten_pattern.graph.nodes: # type: ignore[union-attr] + if ( + node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + and len(node.users) == 0 + ): + aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + return aten_pattern # type: ignore[return-value] + + +def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None: + """Remove .tensor overload for quantize/dequantize ops so that we can + use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e + """ + _MAP = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel, + torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel, + torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp, + } + for n in match_pattern.graph.nodes: + if n.op != "call_function": + continue + if n.target in _MAP: + n.target = _MAP[n.target] + + +def _is_literal(arg): + if isinstance(arg, (int, float)): + return True + if isinstance(arg, (tuple, list)): + return all(map(_is_literal, arg)) + return False + + +def _replace_literals_with_new_placeholders( + gm: torch.fx.GraphModule, + merge_dup: bool = False, + exclude_literals: Optional[list[Any]] = None, +): + """Replace the literals in the graph with placeholder nodes that's created on the fly while we + traverse the graph, so that the literal arguments in the graph can be matched and replaced + + To use this, the pattern and replacement graph should have the exact same number of literal args + and they should be used in the exact same order in the pattern and replacement graph. + + If the literal arguments are not used in the same order in pattern and replacement graph, please + use `_replace_literals_with_existing_placeholders` instead + + Args: + `gm`: input GraphModule that we'll transform + `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in + the graph, whether they should correspond to the same placeholder or not + `exclude_literals`: a list of literals that will not be replaced with placeholders + + Example: + + # 1. Original Graph + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + example_inputs = (torch.randn(1, 3, 3, 3),) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. 
Before calling replace literals we'll see the following graph: + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + pattern_gm = _replace_literals_with_new_placeholders(pattern_gm) + replacement_gm = _replace_literals_with_new_placeholders(replacement_gm) + + # 3. After replacing literals with new placeholder nodes + + def pattern(self, x, new_ph): + return x + new_ph + + def pattern(self, x, new_ph): + return x - new_ph + + """ + last_ph = None + cnt = 0 + literal_to_ph: dict[Union[float, bool, int, torch.dtype], Node] = {} + if exclude_literals is None: + exclude_literals = [] + + in_spec = gm._in_spec + args_spec = in_spec.children_specs[0] + for node in gm.graph.nodes: + if node.op == "placeholder": + last_ph = node + cnt += 1 + continue + with gm.graph.inserting_after(last_ph): + new_args = [] + for arg in node.args: + if _is_literal(arg) and arg not in exclude_literals: + if merge_dup and arg in literal_to_ph: + new_args.append(literal_to_ph[arg]) + else: + ph_node = gm.graph.placeholder("arg" + str(cnt)) + new_args.append(ph_node) + args_spec.children_specs.append(LeafSpec()) + cnt += 1 + if merge_dup: + literal_to_ph[arg] = ph_node + else: + new_args.append(arg) + new_args = tuple(new_args) + + node.args = new_args + + # Update `num_nodes`, `num_leaves`, `num_children`. + args_spec.__post_init__() + in_spec.__post_init__() + return gm + + +def _replace_literals_with_existing_placeholders( + gm: torch.fx.GraphModule, + exclude_literals: Optional[list[Any]] = None, + literal_to_ph_idx: Optional[dict[Union[float, int, bool, torch.dtype], int]] = None, +): + """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments + in the graph can be matched and replaced + + To use this, all literal args in the graph should be unique and each of them should correspond + to exactly one placeholder node + + # 1. Original Graph + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + 1.0, + 0, + -128, + 127, + ) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. Before calling replace literals we'll see the following graph: + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, -128, 127) + return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32) + + # Note that literal args appear in different order in pattern and replacement graph, so + # we can't use _replace_literals_with_new_placeholders + + literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4} + pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx) + replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx) + + # 3. 
After replacing literals with existing placeholder nodes + + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + """ + if exclude_literals is None: + exclude_literals = [] + + if literal_to_ph_idx is None: + literal_to_ph_idx = {} + + phs = [node for node in gm.graph.nodes if node.op == "placeholder"] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + new_args = [] + for arg in node.args: + if ( + _is_literal(arg) + and arg not in exclude_literals + and arg in literal_to_ph_idx + ): + ph_idx = literal_to_ph_idx[arg] + ph_node = phs[ph_idx] + new_args.append(ph_node) + else: + new_args.append(arg) + new_args = tuple(new_args) + node.args = new_args + return gm + + +# TODO: Handle this in export itself and don't wrap the model in another GraphModule +# in prepare and convert +def _disallow_eval_train(model: GraphModule): + """ + Disallow calling `model.train()` or `model.eval()` on the given GraphModule. + This is useful for exported models, where these methods don't actually behave as expected. + """ + error_message = """ + Calling train() or eval() is not supported for exported models. + Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead. + + If you cannot replace the calls to `model.train()` and `model.eval()`, you may override + the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`, + which does the above automatically for you. Note that this has limited effect on switching + behavior between train and eval modes, and should be used only for special ops such as dropout + and batchnorm. 
+ """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/qconfig.py b/phivenv/Lib/site-packages/torch/ao/quantization/qconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5b123dc84b67beef517175c675ddc9f2a10c4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/qconfig.py @@ -0,0 +1,701 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import namedtuple +from typing import Any, Optional, Union +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch.ao.quantization.fake_quantize import ( + default_dynamic_fake_quant, + default_embedding_fake_quant, + default_embedding_fake_quant_4bit, + default_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + FakeQuantize, + FakeQuantizeBase, + fused_per_channel_wt_fake_quant_range_neg_127_to_127, + fused_wt_fake_quant_range_neg_127_to_127, + FusedMovingAvgObsFakeQuantize, +) + +from .observer import ( + _PartialWrapper, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_float_qparams_observer_4bit, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_reuse_input_observer, + default_weight_observer, + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + NoopObserver, + ObserverBase, + per_channel_weight_observer_range_neg_127_to_127, + PlaceholderObserver, + ReuseInputObserver, + weight_observer_range_neg_127_to_127, +) + + +__all__ = [ + "QConfig", + # TODO: deprecated, remove + "QConfigDynamic", + "default_qconfig", + "default_debug_qconfig", + "default_per_channel_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float16_static_qconfig", + "per_channel_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + "float_qparams_weight_only_qconfig_4bit", + "default_quint8_weight_qconfig", + "default_qat_qconfig", + "default_dynamic_qat_qconfig", + "default_weight_only_qconfig", + "default_activation_only_qconfig", + "default_qat_qconfig_v2", + "default_reuse_input_qconfig", + "default_symmetric_qnnpack_qconfig", + "default_per_channel_symmetric_qnnpack_qconfig", + "default_symmetric_qnnpack_qat_qconfig", + "default_per_channel_symmetric_qnnpack_qat_qconfig", + "default_embedding_qat_qconfig", + "default_embedding_qat_qconfig_4bit", + "get_default_qconfig", + "get_default_qat_qconfig", + "get_default_qconfig_dict", + "get_default_qat_qconfig_dict", + "QConfigAny", + "qconfig_equals", +] + + +class QConfig(namedtuple("QConfig", ["activation", "weight"])): + """ + Describes how to quantize a layer or a part of the network by providing + settings (observer classes) for activations and weights respectively. + + + Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization preparation function will instantiate observers multiple times for each of the layers. 
+ + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfig( + activation=MinMaxObserver.with_args(dtype=torch.qint8), + weight=default_observer.with_args(dtype=torch.qint8), + ) + + """ + + __slots__ = () + + def __new__(cls, activation, weight): + # catch common mistakes + if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "QConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +@deprecated( + "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", + category=FutureWarning, +) +class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): + """ + Describes how to dynamically quantize a layer or a part of the network by providing + settings (observer classes) for weights. + + It's like QConfig, but for dynamic quantization. + + Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of the layers. + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): + # catch common mistakes + if isinstance(weight, nn.Module): + raise ValueError( + "QConfigDynamic received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +default_qconfig = QConfig(activation=default_observer, weight=default_weight_observer) +""" +Default qconfig configuration. +""" + +default_debug_qconfig = QConfig( + weight=default_weight_observer, activation=default_debug_observer +) +""" +Default qconfig configuration for debugging. +""" + +default_per_channel_qconfig = QConfig( + activation=default_observer, weight=default_per_channel_weight_observer +) +""" +Default qconfig configuration for per channel weight quantization. +""" + +default_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=default_weight_observer +) +""" +Default dynamic qconfig. +""" + +float16_dynamic_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with weights quantized to `torch.float16`. +""" + +float16_static_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with both activations and weights quantized to `torch.float16`. +""" + +per_channel_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, + weight=default_per_channel_weight_observer, +) +""" +Dynamic qconfig with weights quantized per channel. 
+""" + +float_qparams_weight_only_qconfig = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer +) +""" +Dynamic qconfig with weights quantized with a floating point zero_point. +""" + +float_qparams_weight_only_qconfig_4bit = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer_4bit +) + +default_qat_qconfig = QConfig( + activation=default_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for QAT. +""" + +default_dynamic_qat_qconfig = QConfig( + activation=default_dynamic_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for dynamic QAT. +""" + +default_weight_only_qconfig = QConfig( + activation=torch.nn.Identity, weight=default_weight_fake_quant +) +""" +Default qconfig for quantizing weights only. +""" + +default_activation_only_qconfig = QConfig( + activation=default_fake_quant, weight=torch.nn.Identity +) +""" +Default qconfig for quantizing activations only. +""" + +# QAT config that uses a fused observer + fake quant modules for optimized training performance. +# to modify the activation/weight observers, the default entries in fake_quantize.py can be modified. +default_qat_qconfig_v2 = QConfig( + activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant +) +""" +Fused version of `default_qat_config`, has performance benefits. +""" + +default_reuse_input_qconfig = QConfig( + activation=default_reuse_input_observer, weight=NoopObserver +) +""" +Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape +""" + + +def get_default_qconfig(backend="x86", version=0): + """ + Returns the default PTQ qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_weight_observer, + ) + elif backend == "onednn": + if not torch.cpu._is_vnni_supported(): + warnings.warn( + "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " + "on CPU without Vector Neural Network Instruction support." + ) + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_per_channel_weight_observer, + ) + elif backend == "x86": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + else: + # won't reach + qconfig = default_qconfig + else: + raise AssertionError( + "Version number: " + + str(version) + + " in get_default_qconfig is not supported. Version number must be 0" + ) + + return qconfig + + +""" +Default, symmetric PTQ qconfig for the specified backend. And a per_channel +variant of the same. + +Symmetric here applies to signed weights with zero point = 0, and additional +value restrictions. The activations are also signed 8-bit integers with this +qconfig. 
+ + * Once this change is merged [as of 3/17/22], with backend or qengine = + 'qnnpack', some quantized operators with this symmetric qconfig may use + operators from xnnpack library. + + ** Support to use xnnpack ops with `qnnpack` backed for asymmetric + qconfig (returned by get_default_qconfig()) is not available yet. + + * This qconfig uses signed activations and weights. Weights have added + restrictions such as zero point is forced to be 0, making the weights + symmetric, hence the name. And the 8-bit quantized values are + restricting to to [-127, +127], excluding -128. + + * xnnpack has a requantization scale value restriction, 0x1p-32 <= + requantization_scale < 256.0 where, `requantization_scale = (input_scale + * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value + of 256) is to prevent requantization_scale to go below xnnpack lower + threshold. +""" +default_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=weight_observer_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=per_channel_weight_observer_range_neg_127_to_127, +) + +default_embedding_qat_qconfig = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant, +) + +default_embedding_qat_qconfig_4bit = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant_4bit, +) + +default_quint8_weight_qconfig = QConfig( + activation=HistogramObserver, weight=MinMaxObserver +) + + +def get_default_qat_qconfig(backend="x86", version=1): + """ + Returns the default QAT qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + * `version`: version, for backwards compatibility. Can be `None` or `1`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + # Histogram observer is too slow for quantization aware training + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "qnnpack": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_weight_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + else: + qconfig = default_qat_qconfig + # Use the fused observe + fake_quant modules for doing QAT. 
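+    # Version 1 (the default) swaps FakeQuantize for FusedMovingAvgObsFakeQuantize,
+    # which fuses the observer and the fake-quant step into one module for faster
+    # QAT. Illustrative usage, a sketch that assumes the QConfigMapping API from
+    # qconfig_mapping.py:
+    #     qat_qconfig = get_default_qat_qconfig("x86", version=1)
+    #     qconfig_mapping = QConfigMapping().set_global(qat_qconfig)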
+ elif version == 1: + if backend == "fbgemm": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_fused_wt_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + else: + qconfig = default_qat_qconfig_v2 + else: + raise AssertionError( + "Version number: " + + str(version) + + "in get_default_qat_qconfig is not supported. Version number must be 0 or 1" + ) + + return qconfig + + +""" +Default symmetric QAT qconfig for qnnpack. And its per channel weight variant. +""" +default_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_wt_fake_quant_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127, +) + +_default_fp32_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float32), + weight=PlaceholderObserver.with_args(dtype=torch.float32), +) + +_default_quint8_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.quint8), + # operators using this qconfig doesn't have weights + weight=None, +) + + +@deprecated( + "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qconfig_dict(backend="x86", version=0): + return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() + + +@deprecated( + "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qat_qconfig_dict(backend="x86", version=1): + return torch.ao.quantization.get_default_qat_qconfig_mapping( + backend, version + ).to_dict() + + +def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> None: + """ + Verifies that this `qconfig` is valid. 
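+    Currently the only check is that ConvTranspose modules are not configured with
+    per-channel weight observers, which are not supported for them yet.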
+ """ + if qconfig is None: + return + is_conv_transpose_mod = isinstance( + mod, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d), + ) + if is_conv_transpose_mod: + if qconfig.weight is None: + # for now, we assume that any qconfig for ConvTranspose without a weight is valid + return + example_observer = qconfig.weight() + is_per_channel = isinstance( + example_observer, + ( + torch.ao.quantization.PerChannelMinMaxObserver, + torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, + ), + ) + assert not is_per_channel, ( + "Per channel weight observer is not supported yet for ConvTranspose{n}d." + ) + + +QConfigAny = Optional[QConfig] +QConfigAny.__module__ = "torch.ao.quantization.qconfig" + + +def _add_module_to_qconfig_obs_ctr( + qconfig: QConfigAny, module: Optional[nn.Module] +) -> Any: + r"""This is a helper function for use in quantization prepare that updates a qconfig so that + the constructors stored in the qconfig will create observers on the same device that + 'module' is on. This is intended to be used when the qconfigs are propagated to each + module in order to avoid potential device alignment issues. + + Args: + qconfig: QConfig with obs constructors stored in activation and weight + module: module which the qconfig is related to + + Return: + qconfig: configured so that obs constructors set to construct on the same device as module + """ + + if module is None or qconfig is None or qconfig._fields != ("activation", "weight"): + return qconfig + + def get_factory_kwargs_based_on_module_device(): + assert isinstance(module, torch.nn.Module) + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + device = next(iter(devices)) if len(devices) > 0 else None + return None if device is None else {"device": device} + + def configure_constructor_to_put_obs_on_module_device(original_constructor): + try: + # check if constructor can accept factory_kwargs + check = original_constructor.with_args(factory_kwargs=None) + check() + return original_constructor.with_callable_args( + factory_kwargs=get_factory_kwargs_based_on_module_device + ) + except AttributeError: # qconfig doesn't have activation or weight + return original_constructor + except TypeError: # the class doesn't accept factory_kwargs argument + return original_constructor + + activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation) + weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight) + + return QConfig(activation, weight) + + +_ObserverOrFakeQuantizeConstructor = Union[ + _PartialWrapper, type[ObserverBase], type[FakeQuantizeBase] +] + + +def _obs_or_fq_ctr_equals( + obs_or_fq1: _ObserverOrFakeQuantizeConstructor, + obs_or_fq2: _ObserverOrFakeQuantizeConstructor, +): + if isinstance(obs_or_fq1, _PartialWrapper) and isinstance( + obs_or_fq2, _PartialWrapper + ): + return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) + return obs_or_fq1 == obs_or_fq2 + + +def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): + """ + Return whether the two partial wrappers are equal, + """ + # functools.partial has no __eq__ operator defined so '==' defaults to 'is' + obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) + obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) + keywords_equal = True + # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail + if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: + 
keywords_equal = keywords_equal and _obs_or_fq_ctr_equals( + obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"] + ) + obs_or_fq1_keywords.pop("observer") + obs_or_fq2_keywords.pop("observer") + keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords + return ( + obs_or_fq1.p.func == obs_or_fq2.p.func + and obs_or_fq1.p.args == obs_or_fq2.p.args + and keywords_equal + ) + + +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + """ + Returns `True` if `q1` equals `q2`, and `False` otherwise. + """ + if q1 is None or q2 is None: + return q1 == q2 + else: + assert q1 is not None and q2 is not None + try: + # Qconfig weight and activation can be either a partial wrapper, + # or an observer class. Special handling is required (above) for + # comparing partial wrappers. + activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) + return activation_same and weight_same + except AttributeError: + return q1 == q2 + + +def _activation_is_memoryless(qconfig: QConfig): + """ + Return whether the observer for activations defined in the given QConfig is memoryless. + This means a MovingAverage observer with averaging constant equal to 1. + """ + + def _is_memoryless(observer): + return ( + hasattr(observer, "averaging_constant") and observer.averaging_constant == 1 + ) + + act = qconfig.activation() + if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"): + return _is_memoryless(act.activation_post_process) + else: + return _is_memoryless(act) + + +def _is_reuse_input_qconfig(qconfig: Optional[QConfig]): + return ( + qconfig is not None + and isinstance(qconfig.activation(), ReuseInputObserver) + and isinstance(qconfig.weight(), NoopObserver) + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/qconfig_mapping.py b/phivenv/Lib/site-packages/torch/ao/quantization/qconfig_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..b8151c878539cfb3f3d2ff6cb3684cd50b4306f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/qconfig_mapping.py @@ -0,0 +1,381 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from collections import OrderedDict +from typing import Any, Callable, Union + +import torch + +from .fake_quantize import default_weight_fake_quant, FixedQParamsFakeQuantize +from .observer import ( + _PartialWrapper, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + default_placeholder_observer, + default_weight_observer, +) +from .qconfig import ( + default_quint8_weight_qconfig, + default_reuse_input_qconfig, + default_symmetric_qnnpack_qat_qconfig, + default_symmetric_qnnpack_qconfig, + get_default_qat_qconfig, + get_default_qconfig, + QConfig, + QConfigAny, +) + + +__all__ = [ + "get_default_qconfig_mapping", + "get_default_qat_qconfig_mapping", + "QConfigMapping", +] + + +# TODO: replace all usages with these constants +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" + +# TODO: derive this map from the BackendConfig +_FIXED_QPARAMS_OP_TO_OBSERVER: dict[Union[Callable, str], _PartialWrapper] = { + torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, + torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, + "hardsigmoid": 
default_fixed_qparams_range_0to1_observer, + "hardsigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer, + torch.sigmoid: default_fixed_qparams_range_0to1_observer, + "sigmoid": default_fixed_qparams_range_0to1_observer, + "sigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Softmax: default_fixed_qparams_range_0to1_observer, + torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer, + torch.tanh: default_fixed_qparams_range_neg1to1_observer, + "tanh": default_fixed_qparams_range_neg1to1_observer, + "tanh_": default_fixed_qparams_range_neg1to1_observer, +} + + +def _get_default_qconfig_mapping( + is_qat: bool, backend: str, version: int +) -> QConfigMapping: + """ + Return the default QConfigMapping for the given quantization type and backend. + """ + if is_qat: + qconfig = get_default_qat_qconfig(backend, version) + else: + qconfig = get_default_qconfig(backend, version) + default_weight = default_weight_fake_quant if is_qat else default_weight_observer + + # default_per_channel_weight_observer is not currently compatible with fbgemm backend + # so we have to modify the weight observer to default_weight_observer or another + # per tensor supported observer. + # see https://github.com/pytorch/pytorch/issues/47535 + if backend in ("fbgemm", "x86"): + qconfig_transpose = QConfig( + activation=qconfig.activation, weight=default_weight + ) + else: + qconfig_transpose = qconfig + + # currently layernorm only supports float weights + # we have to add this because otherwise there will be a extra quantize-dequantize pair + qconfig_layernorm = QConfig( + activation=qconfig.activation, weight=default_placeholder_observer + ) + + qconfig_mapping = ( + QConfigMapping() + .set_global(qconfig) + .set_object_type("reshape", default_reuse_input_qconfig) + .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) + .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) + .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) + .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) + .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) + .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) + ) + # Use special observers for ops with fixed qparams + fixed_qparams_observer_to_qconfig: dict[Any, QConfigAny] = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer] + else: + if is_qat: + activation = FixedQParamsFakeQuantize.with_args(observer=observer) + else: + activation = observer + fixed_qparams_qconfig = QConfig( + activation=activation, weight=default_weight + ) + fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig) + + # TODO Currently it's required that separate ops in a fused op/module have the same qconfig. + # Need to be able to support fusion of ops with different qconfigs + + return qconfig_mapping + + +def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping: + """ + Return the default QConfigMapping for post training quantization. 
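# A short sketch of the ConvTranspose special-casing above (the backend string
# and printed attributes are only meant to illustrate the mapping contents):
# the global qconfig keeps a per-channel weight observer on "x86", while the
# ConvTranspose entries fall back to a per-tensor weight observer.
import torch
from torch.ao.quantization import get_default_qconfig_mapping

mapping = get_default_qconfig_mapping("x86")
print(mapping.global_qconfig.weight)                                   # per-channel weight observer
print(mapping.object_type_qconfigs[torch.nn.ConvTranspose2d].weight)   # per-tensor default_weight_observer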
+ + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + # TODO: add assert for backend choices + return _get_default_qconfig_mapping(False, backend, version) + + +def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping: + """ + Return the default QConfigMapping for quantization aware training. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + return _get_default_qconfig_mapping(True, backend, version) + + +def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qconfig + return _get_default_qconfig_mapping_with_default_qconfig( + False, "qnnpack", default_qconfig + ) + + +def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qat_qconfig + return _get_default_qconfig_mapping_with_default_qconfig( + True, "qnnpack", default_qconfig + ) + + +def _get_default_qconfig_mapping_with_default_qconfig( + is_qat: bool, + backend: str, + default_qconfig: QConfig, +) -> QConfigMapping: + """ + Return a QConfigMapping that uses the provided qconfig as the default QConfig. + """ + if is_qat: + qconfig_mapping = get_default_qat_qconfig_mapping(backend) + else: + qconfig_mapping = get_default_qconfig_mapping(backend) + qconfig_mapping.set_global(default_qconfig) + for pattern in qconfig_mapping.object_type_qconfigs.keys(): + if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: + qconfig_mapping.set_object_type(pattern, default_qconfig) + return qconfig_mapping + + +_QCONFIG_STYLE_ORDER: list[str] = [ + "global_qconfig", + "object_type_qconfigs", + "module_name_regex_qconfigs", + "module_name_qconfigs", + "module_name_object_type_order_qconfigs", +] + + +class QConfigMapping: + """ + Mapping from model ops to :class:`torch.ao.quantization.QConfig` s. 
+ + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfig + + ``set_object_type`` : sets the QConfig for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string + + ``set_module_name`` : sets the QConfig for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Example usage:: + + qconfig_mapping = QConfigMapping() + .set_global(global_qconfig) + .set_object_type(torch.nn.Linear, qconfig1) + .set_object_type(torch.nn.ReLU, qconfig1) + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*", qconfig2) + .set_module_name("module1", qconfig1) + .set_module_name("module2", qconfig2) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3) + + """ + + def __init__(self) -> None: + # In increasing match priority: + self.global_qconfig: QConfigAny = None + self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = ( + OrderedDict() + ) + self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_object_type_order_qconfigs: OrderedDict[ + tuple[str, Callable, int], QConfigAny + ] = OrderedDict() + + def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping: + """ + Set the global (default) QConfig. + """ + self.global_qconfig = global_qconfig + return self + + def set_object_type( + self, object_type: Union[Callable, str], qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for a given module type, function, or method name. + If the QConfig for an existing object type was already set, the new QConfig will override the old one. + """ + self.object_type_qconfigs[object_type] = qconfig + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for modules matching the given regex string. + + Regexes will be matched in the order in which they are registered through this method. + Thus, the caller should register more specific patterns first, e.g.:: + + qconfig_mapping = QConfigMapping() + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*bar.*", qconfig2) + .set_module_name_regex("foo.*", qconfig3) + + In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2, + and "foo.baz.relu" would match qconfig3. + + If the QConfig for an existing module name regex was already set, the new QConfig will override the + old one while preserving the order in which the regexes were originally registered. + """ + self.module_name_regex_qconfigs[module_name_regex] = qconfig + return self + + def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for modules matching the given module name. + If the QConfig for an existing module name was already set, the new QConfig will override the old one. 
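# An illustrative sketch of the match priority listed above (module names such
# as "backbone.conv1" are hypothetical): a module_name entry wins over a regex
# entry, which wins over an object_type entry, which wins over the global qconfig.
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig

qconfig = get_default_qconfig("x86")
qconfig_mapping = (
    QConfigMapping()
    .set_global(qconfig)
    .set_object_type(torch.nn.Conv2d, qconfig)      # every Conv2d by default
    .set_module_name_regex(r"backbone\..*", None)   # skip the whole backbone ...
    .set_module_name("backbone.conv1", qconfig)     # ... except this one module
)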
+ """ + self.module_name_qconfigs[module_name] = qconfig + return self + + def set_module_name_object_type_order( + self, module_name: str, object_type: Callable, index: int, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for modules matching a combination of the given module name, object type, + and the index at which the module appears. + + If the QConfig for an existing (module name, object type, index) was already set, the new QConfig + will override the old one. + """ + self.module_name_object_type_order_qconfigs[ + (module_name, object_type, index) + ] = qconfig + return self + + def __repr__(self) -> str: + output = self.__class__.__name__ + " (" + for style_name in _QCONFIG_STYLE_ORDER: + output += f"\n {style_name}" + qconfigs = getattr(self, style_name) + if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0: + for key, qconfig in qconfigs.items(): + output += f"\n {key}: {qconfig}" + else: + output += f"\n {qconfigs}" + return output + "\n)" + + # TODO: remove this + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``QConfigMapping`` to a dictionary with the following keys: + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are lists of tuples. + """ + return { + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() + ], + } + + # TODO: remove this + @classmethod + def from_dict(cls, qconfig_dict: dict[str, Any]) -> QConfigMapping: + """ + Create a ``QConfigMapping`` from a dictionary with the following keys (all optional): + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are expected to be lists of tuples. + """ + conf = cls() + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): + conf.set_object_type(object_type, qconfig) + for module_name_regex, qconfig in qconfig_dict.get( + _MODULE_NAME_REGEX_DICT_KEY, [] + ): + conf.set_module_name_regex(module_name_regex, qconfig) + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): + conf.set_module_name(module_name, qconfig) + for module_name, object_type, index, qconfig in qconfig_dict.get( + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, [] + ): + conf.set_module_name_object_type_order( + module_name, object_type, index, qconfig + ) + return conf diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quant_type.py b/phivenv/Lib/site-packages/torch/ao/quantization/quant_type.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7298bf03562f3f186f7d3c27e29cb2a509497a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quant_type.py @@ -0,0 +1,35 @@ +import enum + + +__all__ = [ + "QuantType", +] + + +# Quantization type (dynamic quantization, static quantization). 
+# Should match the c++ enum in quantization_type.h +class QuantType(enum.IntEnum): + DYNAMIC = 0 + STATIC = 1 + QAT = 2 + WEIGHT_ONLY = 3 + + +_quant_type_to_str = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", +} + + +# TODO: make this private +def _get_quant_type_to_str(quant_type: QuantType) -> str: + return _quant_type_to_str[quant_type] + + +def _quant_type_from_str(name: str) -> QuantType: + for quant_type, s in _quant_type_to_str.items(): + if name == s: + return quant_type + raise ValueError(f"Unknown QuantType name '{name}'") diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantization_mappings.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantization_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..99e59ea06d022c1f6963fdcad811bfce83d76447 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantization_mappings.py @@ -0,0 +1,365 @@ +import copy +from typing import Any, Callable, Optional, Union + +import torch +import torch.ao.nn as ao_nn +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr + +# Because `torch.ao.nn` uses lazy imports, we need to make +# sure we import the contents explicitly here. +import torch.ao.nn.sparse +import torch.nn.functional as F +from torch import nn +from torch.ao.quantization.fake_quantize import ( + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, +) +from torch.ao.quantization.stubs import DeQuantStub, QuantStub +from torch.ao.quantization.utils import get_combined_dict +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_QAT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS", + "DEFAULT_MODULE_TO_ACT_POST_PROCESS", + "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS", + "no_observer_set", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_static_quant_module_class", + "get_dynamic_quant_module_class", + "get_default_qat_module_mappings", + "get_embedding_qat_module_mappings", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_default_float_to_quantized_operator_mappings", + "get_quantized_operator", +] + +# Default map for swapping float module to reference quantized modules +DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.Linear: nnqr.Linear, + nn.Conv1d: nnqr.Conv1d, + nn.Conv2d: nnqr.Conv2d, + nn.Conv3d: nnqr.Conv3d, + nn.ConvTranspose1d: nnqr.ConvTranspose1d, + nn.ConvTranspose2d: nnqr.ConvTranspose2d, + nn.ConvTranspose3d: nnqr.ConvTranspose3d, + nn.Embedding: nnqr.Embedding, + nn.EmbeddingBag: 
nnqr.EmbeddingBag, + nn.GRUCell: nnqr.GRUCell, + nn.LSTMCell: nnqr.LSTMCell, + nn.RNNCell: nnqr.RNNCell, + nn.LSTM: nnqr.LSTM, +} + +# Default map for swapping float module to quantized ones +DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nn.Dropout: nnq.Dropout, + nn.Conv1d: nnq.Conv1d, + nn.Conv2d: nnq.Conv2d, + nn.Conv3d: nnq.Conv3d, + nn.ConvTranspose1d: nnq.ConvTranspose1d, + nn.ConvTranspose2d: nnq.ConvTranspose2d, + nn.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.Embedding: nnq.Embedding, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.GroupNorm: nnq.GroupNorm, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear, + nn.Linear: nnq.Linear, + nn.ReLU6: nnq.ReLU6, + nn.PReLU: nnq.PReLU, + # Wrapper Modules: + nnq.FloatFunctional: nnq.QFunctional, + # Intrinsic modules: + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, + nni.ConvReLU1d: nniq.ConvReLU1d, + nni.ConvReLU2d: nniq.ConvReLU2d, + nni.ConvReLU3d: nniq.ConvReLU3d, + nni.ConvAdd2d: nniq.ConvAdd2d, + nni.ConvAddReLU2d: nniq.ConvAddReLU2d, + nni.LinearReLU: nniq.LinearReLU, + nni.LinearLeakyReLU: nniq.LinearLeakyReLU, + nni.LinearTanh: nniq.LinearTanh, + nniqat.ConvBn1d: nnq.Conv1d, + nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBn3d: nnq.Conv3d, + nniqat.ConvBnReLU1d: nniq.ConvReLU1d, + nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + nniqat.ConvBnReLU3d: nniq.ConvReLU3d, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.ConvReLU3d: nniq.ConvReLU3d, + nniqat.LinearReLU: nniq.LinearReLU, + nniqat.LinearBn1d: nnq.Linear, + # QAT modules: + nnqat.Linear: nnq.Linear, + nnqat.Conv2d: nnq.Conv2d, + nnqat.Conv3d: nnq.Conv3d, +} + +# Default map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Conv2d: nnqat.Conv2d, + nn.Conv3d: nnqat.Conv3d, + nn.Linear: nnqat.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear, + # Intrinsic modules: + nni.ConvBn1d: nniqat.ConvBn1d, + nni.ConvBn2d: nniqat.ConvBn2d, + nni.ConvBn3d: nniqat.ConvBn3d, + nni.ConvBnReLU1d: nniqat.ConvBnReLU1d, + nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, + nni.ConvBnReLU3d: nniqat.ConvBnReLU3d, + nni.ConvReLU2d: nniqat.ConvReLU2d, + nni.ConvReLU3d: nniqat.ConvReLU3d, + nni.LinearReLU: nniqat.LinearReLU, + nni.LinearBn1d: nniqat.LinearBn1d, +} + +# Default map for swapping dynamic modules +DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.GRUCell: nnqd.GRUCell, + nn.Linear: nnqd.Linear, + nnqatd.Linear: nnqd.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear, + nn.LSTM: nnqd.LSTM, + nn.GRU: nnqd.GRU, + nn.LSTMCell: nnqd.LSTMCell, + nn.RNNCell: nnqd.RNNCell, + nni.LinearReLU: nniqd.LinearReLU, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.Embedding: nnq.Embedding, + # Don't want to enable these by default because the numerical + # accuracy is poor compared to other dynamic ops + # nn.Conv1d: nnqd.Conv1d, + # nn.Conv2d: nnqd.Conv2d, + # nn.Conv3d: nnqd.Conv3d, + # nn.ConvTranspose1d: nnqd.ConvTranspose1d, + # nn.ConvTranspose2d: nnqd.ConvTranspose2d, + # nn.ConvTranspose3d: nnqd.ConvTranspose3d, +} + +# Allowlist for propagating the qconfig +_INCLUDE_QCONFIG_PROPAGATE_LIST: set[Callable] = { + 
nn.Sequential, +} + +# Default mapping from floating point function or torch ops to quantized ops +# TODO: merge with default static mapping +DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS: dict[Union[Callable, str], Callable] = { + F.elu: torch.ops.quantized.elu, + F.hardswish: torch.ops.quantized.hardswish, + F.instance_norm: torch.ops.quantized.instance_norm, + F.layer_norm: torch.ops.quantized.layer_norm, + F.leaky_relu: torch.ops.quantized.leaky_relu, + F.dropout: torch.ops.quantized.dropout, +} + +# mapping from module to output activation post process class +DEFAULT_MODULE_TO_ACT_POST_PROCESS: dict[Callable, Callable] = { + nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Softmax: default_fixed_qparams_range_0to1_fake_quant, + nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant, +} + +# Default map for swapping float module to static sparse quantized ones +DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.Linear +} + +# Default map for swapping float module to dynamic sparse quantized ones +DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.dynamic.Linear +} + + +def no_observer_set() -> set[Any]: + r"""These modules cannot have observers inserted by default.""" + no_observers = {nn.quantizable.LSTM, nn.quantizable.MultiheadAttention} + return no_observers + + +def get_default_static_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training static quantization""" + return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + + +def get_default_static_quant_reference_module_mappings() -> dict[Callable, Any]: + """Get reference module mapping for post training static quantization""" + return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS) + + +def get_embedding_static_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping, including mapping for embedding QAT""" + mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag + mapping[nnqat.Embedding] = nnq.Embedding + return mapping + + +def get_default_static_sparse_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training static sparse quantization""" + return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS) + + +def get_static_quant_module_class( + float_module_class: Callable, + additional_static_quant_mapping: Optional[dict[Callable, Any]] = None, + is_reference: bool = False, +) -> Any: + r"""n Get the statically quantized module class corresponding to + the floating point module class + """ + if additional_static_quant_mapping is None: + additional_static_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS + if is_reference + else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, + additional_static_quant_mapping, + ) + static_quant_module_class = all_mappings.get(float_module_class, None) + assert static_quant_module_class is not None, ( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) + return copy.deepcopy(static_quant_module_class) + + +def get_dynamic_quant_module_class( + float_module_class: Callable, + additional_dynamic_quant_mapping: Optional[dict[Callable, Any]] = None, +) -> Any: + r"""n Get the dynamically quantized module class corresponding to + the floating point 
module class + """ + if additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping + ) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + assert dynamic_quant_module_class is not None, ( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) + return copy.deepcopy(dynamic_quant_module_class) + + +def get_default_qat_module_mappings() -> dict[Callable, Any]: + """Get default module mapping for quantization aware training""" + return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + + +def get_embedding_qat_module_mappings() -> dict[Callable, Any]: + """Get module mapping for quantization aware training + This is includes default values in addition to + enabling qat for embeddings. + """ + mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag + mapping[nn.Embedding] = nnqat.Embedding + return mapping + + +def get_default_dynamic_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training dynamic quantization""" + return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS + + +def get_default_dynamic_sparse_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training dynamic sparse quantization""" + return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS + + +def get_default_qconfig_propagation_list() -> set[Callable]: + """Get the default list of module types that we'll attach qconfig + attribute to in prepare + """ + QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST) + + +def get_default_compare_output_module_list() -> set[Callable]: + """Get list of module class types that we will record output + in numeric suite + """ + NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) + + +def get_default_float_to_quantized_operator_mappings() -> dict[ + Union[Callable, str], Callable +]: + return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS) + + +# TODO: merge with get_static_quant_module_class +def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: + """Get the quantized operator corresponding to the float operator""" + quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) + assert quantized_op is not None, ( + f"Operator {str(float_op)} does not have corresponding quantized op" + ) + return quantized_op + + +def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]: + r"""Get the special activation post process for `module`, this has + higher priority than the activation post process in `qconfig` + e.g. 
+ input: torch.nn.Sigmoid + output: default_affine_fixed_qparam_fake_quant + """ + return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get( + type_before_parametrizations(module), None + ) + + +def _has_special_act_post_process(module: torch.nn.Module) -> bool: + return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantize.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..a69e8f364a8abe8f681e650e05b4f41498a55fa3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantize.py @@ -0,0 +1,819 @@ +# mypy: allow-untyped-defs +import copy +import inspect +import itertools +import typing_extensions +import warnings + +import torch +import torch.ao.nn.quantized as nnq +import torch.nn as nn +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import ( + _activation_is_memoryless, + _add_module_to_qconfig_obs_ctr, + default_dynamic_qconfig, + float16_dynamic_qconfig, + float_qparams_weight_only_qconfig, + float_qparams_weight_only_qconfig_4bit, +) +from torch.ao.quantization.quantization_mappings import ( + _get_special_act_post_process, + _has_special_act_post_process, + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + get_default_static_quant_module_mappings, + get_default_static_quant_reference_module_mappings, + no_observer_set, +) +from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + DEPRECATION_WARNING, + get_qparam_dict, + has_no_children_ignoring_parametrizations, +) + + +__all__ = [ + "get_default_custom_config_dict", + "propagate_qconfig_", + "add_quant_dequant", + "prepare", + "quantize", + "quantize_dynamic", + "prepare_qat", + "quantize_qat", + "convert", + "swap_module", +] + + +# TODO remove this once BC is no longer required to avoid a SEV +is_activation_post_process = _is_activation_post_process + + +_DEFAULT_CUSTOM_CONFIG_DICT = { + "float_to_observed_custom_module_class": { + nn.LSTM: nn.quantizable.LSTM, + nn.MultiheadAttention: nn.quantizable.MultiheadAttention, + }, + "observed_to_quantized_custom_module_class": { + nn.quantizable.LSTM: nn.quantized.LSTM, + nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention, + }, +} + + +def get_default_custom_config_dict(): + r"""Defines the default custom config dict.""" + return _DEFAULT_CUSTOM_CONFIG_DICT + + +def _propagate_qconfig_helper( + module, + qconfig_dict, + qconfig_parent=None, + prefix="", + prepare_custom_config_dict=None, +): + r"""This is a helper function for `propagate_qconfig_` + + Args: + module: input module + qconfig_dict: dictionary that maps from name of submodule to quantization + configuration + qconfig_parent: quantization config of parent module, we will fallback to + this config when there is no specified config for current + module + prefix: corresponding prefix of the current module, used as key in + qconfig_dict + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + + module_qconfig = qconfig_dict.get( + type_before_parametrizations(module), qconfig_parent + ) + module_qconfig = 
qconfig_dict.get(prefix, module_qconfig) + module_qconfig = getattr(module, "qconfig", module_qconfig) + + torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module) + + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module) + module.qconfig = qconfig_with_device_check + + for name, child in module.named_children(): + module_prefix = prefix + "." + name if prefix else name + # do no not propagate qconfig to child if child is non traceable + if prepare_custom_config_dict is None or not ( + name in prepare_custom_config_dict.get("non_traceable_module_name", []) + or type(child) + in prepare_custom_config_dict.get("non_traceable_module_class", []) + ): + _propagate_qconfig_helper( + child, qconfig_dict, qconfig_with_device_check, module_prefix + ) + + +def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None): + r"""Propagate qconfig through the module hierarchy and assign `qconfig` + attribute on each leaf module + + Args: + module: input module + qconfig_dict: dictionary that maps from name or type of submodule to + quantization configuration, qconfig applies to all submodules of a + given module unless qconfig for the submodules are specified (when + the submodule already has qconfig attribute) + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + if qconfig_dict is None: + qconfig_dict = {} + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + _propagate_qconfig_helper( + module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict + ) + + +def _observer_forward_hook(self, input, output): + r"""Forward hook that calls observer on the output""" + return self.activation_post_process(output) + + +def _observer_forward_pre_hook(self, input): + r"""Forward pre hook that calls observer on the output""" + return self.activation_post_process(input[0]) + + +def _register_activation_post_process_hook(module, pre_hook=False): + assert hasattr(module, "activation_post_process"), ( + "Expect activation_post_process attribute already attached to the module" + ) + if pre_hook: + module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) + else: + module.register_forward_hook(_observer_forward_hook, prepend=True) + + +def _add_observer_( + module, + qconfig_propagation_list=None, + non_leaf_module_list=None, + device=None, + custom_module_class_mapping=None, +): + r"""Add observer for the leaf child of the module. + + This function insert observer module to all leaf child module that + has a valid qconfig attribute. 
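# A toy sketch of what observer insertion produces (the two-layer model is
# illustrative): each observed leaf module gains an `activation_post_process`
# child plus a forward hook that feeds its output to that observer.
import torch
from torch.ao import quantization as tq

float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
float_model.qconfig = tq.get_default_qconfig("x86")
prepared = tq.prepare(float_model, inplace=False)
print(prepared[0].activation_post_process)  # observer attached to the conv output
prepared(torch.randn(1, 3, 8, 8))           # calibration data updates the observers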
+ + Args: + module: input module with qconfig attributes for all the leaf modules that we want to quantize + qconfig_propagation_list: a list of quantizable modules that will have observers added to them + if they are leaf nodes + device: parent device, if any + non_leaf_module_list: list of non-leaf modules we want to add observer + + Return: + None, module is modified inplace with added observer modules and forward_hooks + """ + if qconfig_propagation_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + + if custom_module_class_mapping is None: + custom_module_class_mapping = {} + + # respect device affinity when adding observers + if device is None: + devices = _get_unique_devices_(module) + assert len(devices) <= 1, ( + f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + + def get_activation_post_process(qconfig, device, special_act_post_process=None): + activation = ( + qconfig.activation() + if special_act_post_process is None + else special_act_post_process() + ) + if device is not None: + activation.to(device) + return activation + + def needs_observation(m): + return hasattr(m, "qconfig") and m.qconfig is not None + + def insert_activation_post_process(m, special_act_post_process=None): + """Adds an activation post process module and register + a pre or post hook that calls the module + """ + # We don't insert observer/fake_quantize for DeQuantStub + if needs_observation(m) and not isinstance(m, DeQuantStub): + # observer and hook will be gone after we swap the module + m.add_module( + "activation_post_process", + get_activation_post_process( + m.qconfig, device, special_act_post_process + ), + ) + # Register observer as the first entry in the hook list + # All post forward hooks are preserved and will be executed after the observer before convert + _register_activation_post_process_hook( + m, pre_hook=_activation_is_memoryless(m.qconfig) + ) + + for name, child in module.named_children(): + # TODO remove Dropout special after codebase stable + if type_before_parametrizations(child) in [nn.Dropout]: + continue + elif issubclass( + type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional) + ): + if needs_observation(child): + assert hasattr(child, "activation_post_process"), ( + f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + ) + child.activation_post_process = get_activation_post_process( + child.qconfig, device + ) + elif isinstance(child, _FusedModule): + # activation_post_process are now added directly to nn.Sequential/_FusedModule + if needs_observation(child): + insert_activation_post_process(child) + elif ( + non_leaf_module_list is not None + and type_before_parametrizations(child) in non_leaf_module_list + ): + if needs_observation(child): + insert_activation_post_process(child) + elif _has_special_act_post_process(child): + special_act_post_process = _get_special_act_post_process(child) + insert_activation_post_process(child, special_act_post_process) + elif ( + needs_observation(child) + and type_before_parametrizations(child) in custom_module_class_mapping + ): + observed_class = custom_module_class_mapping[ + type_before_parametrizations(child) + ] + observed_child = observed_class.from_float(child) + setattr(module, name, observed_child) + # TODO: These are the modules that cannot be observed + # Once there are more, we should move them to a separate list 
+ if not issubclass(observed_class, tuple(no_observer_set())): + insert_activation_post_process(observed_child) + else: + _add_observer_( + child, + qconfig_propagation_list, + non_leaf_module_list, + device, + custom_module_class_mapping, + ) + + # Insert observers only for leaf nodes, note that this observer is for + # the output of the module, for input QuantStub will observe them + if ( + has_no_children_ignoring_parametrizations(module) + and not isinstance(module, torch.nn.Sequential) + and type_before_parametrizations(module) in qconfig_propagation_list + ): + insert_activation_post_process(module) + # This is a special case for AdaRound eager mode + # AdaRound contains weight_fake_quant to be propagated from API to convert + # leaf node check with a number of children looks naive assumption that blocks + # Adding an exception case for AdaRound + if ( + hasattr(module, "weight_fake_quant") + and not isinstance(module, torch.nn.Sequential) + and type_before_parametrizations(module) in qconfig_propagation_list + ): + insert_activation_post_process(module) + + +def _get_unique_devices_(module): + return {p.device for p in module.parameters() if p.device.type != "meta"} | { + p.device for p in module.buffers() if p.device.type != "meta" + } + + +def add_quant_dequant(module): + r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig + Note that this function will modify the children of module inplace and it + can return a new module which wraps the input module as well. + + Args: + module: input module with qconfig attributes for all the leaf modules + that we want to quantize + + Return: + Either the inplace modified module with submodules wrapped in + `QuantWrapper` based on qconfig or a new `QuantWrapper` module which + wraps the input module, the latter case only happens when the input + module is a leaf module and we want to quantize it. + """ + if ( + has_no_children_ignoring_parametrizations(module) + and hasattr(module, "qconfig") + and module.qconfig + ): + return QuantWrapper(module) + + for name, child in module.named_children(): + module._modules[name] = add_quant_dequant(child) + return module + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare( + model, + inplace=False, + allow_list=None, + observer_non_leaf_module_list=None, + prepare_custom_config_dict=None, +): + r"""Prepares a copy of the model for quantization calibration or quantization-aware training. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + The model will be attached with observer or fake quant modules, and qconfig + will be propagated. + + Args: + `model`: input model to be modified in-place + `inplace`: carry out model transformations in-place, the original module is mutated + `allow_list`: list of quantizable modules + `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer + `prepare_custom_config_dict`: customization configuration dictionary for prepare function + + .. 
code-block:: python + + # Example of prepare_custom_config_dict: + prepare_custom_config_dict = { + # user will manually define the corresponding observed + # module class which has a from_float class method that converts + # float custom module to observed custom module + "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule} + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare") + if prepare_custom_config_dict is None: + prepare_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) + + if not inplace: + model = copy.deepcopy(model) + + # TODO: remove allow_list + qconfig_propagation_list = allow_list + if allow_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + propagate_qconfig_(model, qconfig_dict=None) + + # sanity check common API misusage + if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()): + warnings.warn( + "None of the submodule got qconfig applied. Make sure you " + "passed correct configuration through `qconfig_dict` or " + "by assigning the `.qconfig` attribute directly on submodules" + ) + + _add_observer_( + model, + qconfig_propagation_list, + observer_non_leaf_module_list, + custom_module_class_mapping=custom_module_class_mapping, + ) + return model + + +def _remove_activation_post_process(module): + # TODO: maybe we should change activation_post_process to _activation_post_process + # to prevent it from being used by user + if hasattr(module, "activation_post_process") and _is_activation_post_process( + module.activation_post_process + ): + delattr(module, "activation_post_process") + + # remove activation_post_process pre and post hooks + def remove_hooks(pre_hook=False): + hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks + observer_hook = ( + _observer_forward_pre_hook if pre_hook else _observer_forward_hook + ) + handle_ids_to_remove = set() + for handle_id, hook_fn in hook_map.items(): + if hook_fn is observer_hook: + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + hook_map.pop(handle_id) + + remove_hooks(pre_hook=True) + remove_hooks(pre_hook=False) + + +# TODO: rename to something more general +def _remove_qconfig(module): + r"""Clean up the qconfig left in the module so that new qconfig can be + propagated. + + Args: + module: module to be cleaned up + """ + for child in module.children(): + _remove_qconfig(child) + + if hasattr(module, "qconfig"): + del module.qconfig + + _remove_activation_post_process(module) + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize(model, run_fn, run_args, mapping=None, inplace=False): + r"""Quantize the input float model with post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + model: input float model + run_fn: a calibration function for calibrating the prepared model + run_args: positional arguments for `run_fn` + inplace: carry out model transformations in-place, the original module is mutated + mapping: correspondence between original module types and quantized counterparts + + Return: + Quantized model. 
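# A minimal post-training static sketch of the prepare -> calibrate -> convert
# flow wrapped by this function (the toy model and stand-in calibration data
# are illustrative):
import torch
from torch.ao import quantization as tq

float_model = torch.nn.Sequential(tq.QuantStub(), torch.nn.Linear(16, 4), tq.DeQuantStub())
float_model.qconfig = tq.get_default_qconfig("x86")

def calibrate(model, batches):
    with torch.no_grad():
        for batch in batches:
            model(batch)

calibration_batches = [torch.randn(8, 16) for _ in range(4)]  # stand-in data
quantized_model = tq.quantize(float_model, calibrate, [calibration_batches])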
+ """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize") + if mapping is None: + mapping = get_default_static_quant_module_mappings() + if not inplace: + model = copy.deepcopy(model) + model.eval() + prepare(model, inplace=True) + run_fn(model, *run_args) + convert(model, mapping, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize_dynamic( + model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False +): + r"""Converts a float model to dynamic (i.e. weights-only) quantized model. + + Replaces specified modules with dynamic weight-only quantized versions and output the quantized model. + + For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization + by default is performed for layers with large weights size - i.e. Linear and RNN variants. + + Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`. + If `qconfig` is provided, the `dtype` argument is ignored. + + Args: + model: input model + qconfig_spec: Either: + + - A dictionary that maps from name or type of submodule to quantization + configuration, qconfig applies to all submodules of a given + module unless qconfig for the submodules are specified (when the + submodule already has qconfig attribute). Entries in the dictionary + need to be QConfig instances. + + - A set of types and/or submodule names to apply dynamic quantization to, + in which case the `dtype` argument is used to specify the bit-width + + inplace: carry out model transformations in-place, the original module is mutated + mapping: maps type of a submodule to a type of corresponding dynamically quantized version + with which the submodule needs to be replaced + + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") + if qconfig_spec is None: + if dtype == torch.qint8: + qconfig_spec = { + nn.Linear: default_dynamic_qconfig, + nn.LSTM: default_dynamic_qconfig, + nn.GRU: default_dynamic_qconfig, + nn.LSTMCell: default_dynamic_qconfig, + nn.RNNCell: default_dynamic_qconfig, + nn.GRUCell: default_dynamic_qconfig, + } + elif dtype == torch.float16: + qconfig_spec = { + nn.Linear: float16_dynamic_qconfig, + nn.LSTM: float16_dynamic_qconfig, + nn.GRU: float16_dynamic_qconfig, + nn.LSTMCell: float16_dynamic_qconfig, + nn.RNNCell: float16_dynamic_qconfig, + nn.GRUCell: float16_dynamic_qconfig, + } + elif dtype == torch.quint8: + qconfig_spec = { + nn.EmbeddingBag: float_qparams_weight_only_qconfig, + nn.Embedding: float_qparams_weight_only_qconfig, + } + elif dtype == torch.quint4x2: + qconfig_spec = { + nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit, + } + else: + raise ValueError( + f"Don't know how to quantize with default settings for {dtype}. 
Provide full qconfig please" + ) + elif isinstance(qconfig_spec, set): + if dtype is torch.qint8: + default_qconfig = default_dynamic_qconfig + elif dtype is torch.float16: + default_qconfig = float16_dynamic_qconfig + elif dtype is torch.quint8: + default_qconfig = float_qparams_weight_only_qconfig + elif dtype is torch.quint4x2: + default_qconfig = float_qparams_weight_only_qconfig_4bit + else: + raise RuntimeError( + "Unknown dtype specified for quantize_dynamic: ", str(dtype) + ) + qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) + + if mapping is None: + mapping = get_default_dynamic_quant_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + model.eval() + propagate_qconfig_(model, qconfig_spec) + convert(model, mapping, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat(model, mapping=None, inplace=False): + r""" + Prepares a copy of the model for quantization calibration or + quantization-aware training and converts it to quantized version. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + Args: + model: input model to be modified in-place + mapping: dictionary that maps float modules to quantized modules to be + replaced. + inplace: carry out model transformations in-place, the original module + is mutated + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") + assert model.training, "prepare_qat only works on models in training mode" + if mapping is None: + mapping = get_default_qat_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + + propagate_qconfig_(model, qconfig_dict=None) + convert(model, mapping=mapping, inplace=True, remove_qconfig=False) + prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize_qat(model, run_fn, run_args, inplace=False): + r"""Do quantization aware training and output a quantized model + + Args: + model: input model + run_fn: a function for evaluating the prepared model, can be a + function that simply runs the prepared model or a training + loop + run_args: positional arguments for `run_fn` + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") + if not inplace: + model = copy.deepcopy(model) + model.train() + prepare_qat(model, inplace=True) + run_fn(model, *run_args) + convert(model, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert( + module, + mapping=None, + inplace=False, + remove_qconfig=True, + is_reference=False, + convert_custom_config_dict=None, + use_precomputed_fake_quant=False, +): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class. And remove qconfig at the + end if remove_qconfig is set to True. + + Args: + `module`: prepared and calibrated module + `mapping`: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + `inplace`: carry out model transformations in-place, the original module + is mutated + `convert_custom_config_dict`: custom configuration dictionary for convert function + `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant + + .. 
code-block:: python + + # Example of convert_custom_config_dict: + convert_custom_config_dict = { + # user will manually define the corresponding quantized + # module class which has a from_observed class method that converts + # observed custom module to quantized custom module + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: QuantizedCustomModule + } + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.convert") + if not inplace: + module = copy.deepcopy(module) + _convert( + module, + mapping, + inplace=True, + is_reference=is_reference, + convert_custom_config_dict=convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant, + ) + if remove_qconfig: + _remove_qconfig(module) + return module + + +def _convert( + module, + mapping=None, + inplace=False, + is_reference=False, + convert_custom_config_dict=None, + use_precomputed_fake_quant=False, +): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class + + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + is_reference: a flag to enable quantized reference module + use_precomputed_fake_quant: a flag to enable use of precomputed fake quant + + """ + if mapping is None: + mapping = ( + get_default_static_quant_reference_module_mappings() + if is_reference + else get_default_static_quant_module_mappings() + ) + if convert_custom_config_dict is None: + convert_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = convert_custom_config_dict.get( + "observed_to_quantized_custom_module_class", {} + ) + + if not inplace: + module = copy.deepcopy(module) + reassign = {} + for name, mod in module.named_children(): + # both fused modules and observed custom modules are + # swapped as one unit + if ( + not isinstance(mod, _FusedModule) + and type_before_parametrizations(mod) not in custom_module_class_mapping + ): + _convert( + mod, + mapping, + True, # inplace + is_reference, + convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant, + ) + reassign[name] = swap_module( + mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + +def swap_module( + mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False +): + r"""Swaps the module if it has a quantized counterpart and it has an + `observer` attached. 
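# A compact sketch of the dynamic-quantization entry point defined above (the
# toy model is illustrative): module types listed in `qconfig_spec` are swapped
# for their dynamically quantized counterparts.
import torch
from torch.ao.quantization import quantize_dynamic

float_model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LSTM(32, 32))
dynamic_model = quantize_dynamic(
    float_model, qconfig_spec={torch.nn.Linear, torch.nn.LSTM}, dtype=torch.qint8
)
print(type(dynamic_model[0]))  # dynamically quantized Linear
print(type(dynamic_model[1]))  # dynamically quantized LSTM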
+ + Args: + mod: input module + mapping: a dictionary that maps from nn module to nnq module + + Return: + The corresponding quantized module of `mod` + """ + new_mod = mod + if hasattr(mod, "qconfig") and mod.qconfig is not None: + swapped = False + if type_before_parametrizations(mod) in custom_module_class_mapping: + new_mod = custom_module_class_mapping[ + type_before_parametrizations(mod) + ].from_observed(mod) + swapped = True + elif type_before_parametrizations(mod) in mapping: + qmod = mapping[type_before_parametrizations(mod)] + if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE: + assert mod.qconfig is not None + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + weight_qparams = get_qparam_dict(weight_post_process) + new_mod = qmod.from_float(mod, weight_qparams) + else: + sig = inspect.signature(qmod.from_float) + if "use_precomputed_fake_quant" in sig.parameters: + new_mod = qmod.from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + else: + new_mod = qmod.from_float(mod) + swapped = True + + if swapped: + # Preserve module's pre forward hooks. They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + if hook_fn is not _observer_forward_hook: + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = _get_unique_devices_(mod) + assert len(devices) <= 1 or ( + len(devices) == 2 and torch.device("meta") in devices + ), ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + return new_mod + + +def _get_observer_dict(mod, target_dict, prefix=""): + r"""Traverse the modules and save all observers into dict. + This is mainly used for quantization accuracy debug + Args: + mod: the top module we want to save all observers + prefix: the prefix for the current module + target_dict: the dictionary used to save all the observers + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." 
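+    # Illustrative note: for a hierarchy like `model.sub.linear`, the keys collected
+    # below look like "activation_post_process", "sub.activation_post_process" and
+    # "sub.linear.activation_post_process", for whichever of those modules actually
+    # carries an `activation_post_process` observer.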
+ + if hasattr(mod, "activation_post_process"): + target_dict[get_prefix(prefix) + "activation_post_process"] = ( + mod.activation_post_process + ) + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_observer_dict(child, target_dict, module_prefix) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantize_fx.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f8530b4cfa5d73eea98f1160e8a6814e97ba10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantize_fx.py @@ -0,0 +1,759 @@ +import copy +import typing_extensions +import warnings +from typing import Any, Optional, Union + +import torch +from torch.fx import GraphModule +from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY + +from .backend_config import BackendConfig, get_tensorrt_backend_config # noqa: F401 +from .fx.convert import convert +from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig +from .fx.fuse import fuse # noqa: F401 +from .fx.graph_module import ObservedGraphModule # noqa: F401 +from .fx.prepare import prepare # noqa: F401 +from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager # noqa: F401 +from .fx.utils import ( # noqa: F401 + get_custom_module_class_keys, + get_skipped_module_name_and_classes, +) +from .qconfig_mapping import QConfigMapping +from .utils import DEPRECATION_WARNING + + +def attach_preserved_attrs_to_model( + model: Union[GraphModule, torch.nn.Module], + preserved_attrs: dict[str, Any], +) -> None: + """Store preserved attributes to the model.meta so that it can be preserved during deepcopy""" + model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment] + # set the preserved attributes in the model so that user can call + # model.attr as they do before calling fx graph mode quantization + for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr] + setattr(model, attr_name, attr) + + +def _check_is_graph_module(model: torch.nn.Module) -> None: + if not isinstance(model, GraphModule): + raise ValueError( + "input model must be a GraphModule, " + + "Got type:" + + str(type(model)) + + " Please make " + + "sure to follow the tutorials." 
+ ) + + +def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None: + """Attach meta field to all nodes of the graph if it does not exist, + meta field is a field stores some meta information about the node, such + as dtype and shape information for output of the node, this only exists + if the program is captured by make_fx (used in quantize_pt2e flow), if + the program is captured by torch.fx symbolic tracing, this field may not exist, + so we add it here to avoid checking this all over the places + """ + for node in model.graph.nodes: + if not hasattr(node, "meta"): + node.meta = {} + + +def _swap_ff_with_fxff(model: torch.nn.Module) -> None: + r"""Swap FloatFunctional with FXFloatFunctional""" + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + _swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + +def _fuse_fx( + model: GraphModule, + is_qat: bool, + fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Internal helper function to fuse modules in preparation for quantization + + Args: + model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) + """ + _check_is_graph_module(model) + return fuse(model, is_qat, fuse_custom_config, backend_config) # type: ignore[operator] + + +def _prepare_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + is_qat: bool, + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + _equalization_config: Optional[Union[QConfigMapping, dict[str, Any]]] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, +) -> GraphModule: + r"""Internal helper function for prepare_fx + Args: + `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`: + see docs for :func:`~torch.ao.quantization.prepare_fx` + `is_standalone_module`: a boolean flag indicates whether we are + quantizing a standalone module or not, a standalone module + is a submodule of the parent module that is not inlined in the + forward graph of the parent module, + the way we quantize standalone module is described in: + :func:`~torch.ao.quantization._prepare_standalone_module_fx` + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(prepare_custom_config, dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. 
Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + # swap FloatFunctional with FXFloatFunctional + _swap_ff_with_fxff(model) + + skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes( + prepare_custom_config, is_standalone_module + ) + preserved_attr_names = prepare_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(model, attr) + for attr in preserved_attr_names + if hasattr(model, attr) + } + # symbolically trace the model + tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) # type: ignore[arg-type] + graph_module = GraphModule(model, tracer.trace(model)) + _attach_meta_to_node_if_not_exist(graph_module) + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + prepare_custom_config.preserved_attributes + ) + graph_module = _fuse_fx(graph_module, is_qat, fuse_custom_config, backend_config) + prepared = prepare( + graph_module, + qconfig_mapping, + is_qat, + tracer.node_name_to_scope, + example_inputs=example_inputs, + prepare_custom_config=prepare_custom_config, + _equalization_config=_equalization_config, + backend_config=backend_config, + is_standalone_module=is_standalone_module, + ) # type: ignore[operator] + + attach_preserved_attrs_to_model(prepared, preserved_attrs) + return prepared + + +def _prepare_standalone_module_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + is_qat: bool, + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""[Internal use only] Prepare a standalone module, so that it can be used when quantizing the + parent module. + standalone_module means it a submodule that is not inlined in parent module, + and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + + * model(GraphModule): prepared standalone module. It has these attributes in + model.meta: + + * `standalone_module_input_quantized_idxs(List[Int])`: a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + * `standalone_module_output_quantized_idxs(List[Int])`: a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + + """ + return _prepare_fx( + model, + qconfig_mapping, + is_qat, + example_inputs, + prepare_custom_config, + backend_config=backend_config, + is_standalone_module=True, + ) + + +def fuse_fx( + model: torch.nn.Module, + fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. + Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py + + Args: + + * `model` (torch.nn.Module): a torch.nn.Module model + * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx. 
+ See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details + Example:: + + from torch.ao.quantization import fuse_fx + + m = Model().eval() + m = fuse_fx(m) + + """ + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") + preserved_attr_names = fuse_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(model, attr) + for attr in preserved_attr_names + if hasattr(model, attr) + } + + graph_module = torch.fx.symbolic_trace(model) + _attach_meta_to_node_if_not_exist(graph_module) + graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config) + + attach_preserved_attrs_to_model(graph_module, preserved_attrs) + return graph_module + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + _equalization_config: Optional[Union[QConfigMapping, dict[str, Any]]] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r""" Prepare a model for post training quantization + + Args: + * `model` (torch.nn.Module): torch.nn.Module model + + * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is + quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping` + for more details + + * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model, + Tuple of positional args (keyword args can be passed as positional args as well) + + * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool. + See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details + + * `_equalization_config`: config for specifying how to perform equalization on the model + + * `backend_config` (BackendConfig): config that specifies how operators are quantized + in a backend, this includes how the operators are observed, + supported fusion patterns, how quantize/dequantize ops are + inserted, supported dtypes etc. 
See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A GraphModule with observer (configured by qconfig_mapping), ready for calibration + + Example:: + + import torch + from torch.ao.quantization import get_default_qconfig_mapping + from torch.ao.quantization.quantize_fx import prepare_fx + + class Submodule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + x = self.linear(x) + return x + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=MinMaxObserver.with_args(dtype=torch.qint8), + # weight=MinMaxObserver.with_args(dtype=torch.qint8)) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qconfig_mapping("fbgemm") + + # We can customize qconfig_mapping in different ways. + # e.g. set the global qconfig, which means we will use the same qconfig for + # all operators in the model, this can be overwritten by other settings + # qconfig_mapping = QConfigMapping().set_global(qconfig) + # e.g. quantize the linear submodule with a specific qconfig + # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig) + # e.g. quantize all nn.Linear modules with a specific qconfig + # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) + # for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping` + # argument + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + # `prepare_fx` inserts observers in the model based on qconfig_mapping and + # backend_config. 
If the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert observer modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # + # Example: + # in qconfig_mapping, user sets linear module to be quantized with quint8 for + # activation and qint8 for weight: + # qconfig = torch.ao.quantization.QConfig( + # observer=MinMaxObserver.with_args(dtype=torch.quint8), + # weight=MinMaxObserver.with-args(dtype=torch.qint8)) + # Note: current qconfig api does not support setting output observer, but + # we may extend this to support these more fine grained control in the + # future + # + # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) + # in backend config, linear module also supports in this configuration: + # weighted_int8_dtype_config = DTypeConfig( + # input_dtype=torch.quint8, + # output_dtype=torch.quint8, + # weight_dtype=torch.qint8, + # bias_type=torch.float) + + # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \ + # .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + # .add_dtype_config(weighted_int8_dtype_config) \ + # ... + + # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config) + # `prepare_fx` will check that the setting requested by suer in qconfig_mapping + # is supported by the backend_config and insert observers and fake quant modules + # in the model + prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs) + # Run calibration + calibrate(prepared_model, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") + return _prepare_fx( + model, + qconfig_mapping, + False, # is_qat + example_inputs, + prepare_custom_config, + _equalization_config, + backend_config, + ) + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Prepare a model for quantization aware training + + Args: + * `model` (torch.nn.Module): torch.nn.Module model + * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx` + * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx` + * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx` + * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx` + + Return: + A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for + quantization aware training + + Example:: + + import torch + from torch.ao.quantization import get_default_qat_qconfig_mapping + from torch.ao.quantization.quantize_fx import prepare_qat_fx + + + class Submodule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear(x) + return x + + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + + # initialize a floating point model + float_model = M().train() + # 
(optional, but preferred) load the weights from pretrained model + # float_model.load_weights(...) + + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)), + # weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8))) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm") + + # We can customize qconfig_mapping in different ways, please take a look at + # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways + # to configure this + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and + # backend_config, if the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert fake_quantize modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of + # how qconfig_mapping interacts with backend_config + prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs) + # Run training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") + return _prepare_fx( + model, + qconfig_mapping, + True, # is_qat + example_inputs, + prepare_custom_config, + backend_config=backend_config, + ) + + +def _convert_fx( + graph_module: GraphModule, + is_reference: bool, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`""" + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. 
Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + _check_is_graph_module(graph_module) + preserved_attr_names = convert_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(graph_module, attr) + for attr in preserved_attr_names + if hasattr(graph_module, attr) + } + + quantized = convert( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module, + _remove_qconfig_flag=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=is_decomposed, + keep_original_weights=keep_original_weights, + ) + + attach_preserved_attrs_to_model(quantized, preserved_attrs) + return quantized + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + keep_original_weights: bool = False, +) -> GraphModule: + r"""Convert a calibrated or trained model to a quantized model + + Args: + * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + + The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`, + with the same values or `None`. Additional keys can be specified with values set to `None`. + + For each entry whose value is set to None, we skip quantizing that entry in the model:: + + qconfig_mapping = QConfigMapping + .set_global(qconfig_from_prepare) + .set_object_type(torch.nn.functional.add, None) # skip quantizing torch.nn.functional.add + .set_object_type(torch.nn.functional.linear, qconfig_from_prepare) + .set_module_name("foo.bar", None) # skip quantizing module "foo.bar" + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend, this includes quantization + mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.), + observer placement for each operators and fused operators. + See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A quantized model (torch.nn.Module) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # convert_fx converts a calibrated/trained model to a quantized model for the + # target hardware, this includes converting the model first to a reference + # quantized model, and then lower the reference quantized model to a backend + # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and + # they share the same set of quantized operators, so we are using the same + # lowering procedure + # + # backend_config defines the corresponding reference quantized module for + # the weighted modules in the model, e.g. nn.Linear + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. 
backend_config = get_default_backend_config("fbgemm") + quantized_model = convert_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") + return _convert_fx( + graph_module, + is_reference=False, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + keep_original_weights=keep_original_weights, + ) + + +def convert_to_reference_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. 
backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = convert_to_reference_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + ) + + +def _convert_to_reference_decomposed_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, with + decomposed representation for quantized Tensor + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Note: this is not public API + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. 
backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) + + """ + torch._C._log_api_usage_once( + "quantization_api.quantize_fx._convert_to_reference_decomposed_fx" + ) + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=False, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=True, + ) + + +def _convert_standalone_module_fx( + graph_module: GraphModule, + is_reference: bool = False, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""[Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx` + and convert it to a quantized model + + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details + """ + return _convert_fx( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module=True, + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantize_jit.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..50ab7a9e1ea55d4ea6fd855f8fee670d80fbee18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantize_jit.py @@ -0,0 +1,421 @@ +# mypy: allow-untyped-defs + +import torch +from torch.ao.quantization.qconfig import QConfig +from torch.ao.quantization.quant_type import QuantType +from torch.jit._recursive import wrap_cpp_module + + +__all__ = [ + "script_qconfig", + "script_qconfig_dict", + "fuse_conv_bn_jit", + "prepare_jit", + "prepare_dynamic_jit", + "convert_jit", + "convert_dynamic_jit", + "quantize_jit", + "quantize_dynamic_jit", +] + + +def _check_is_script_module(model): + if not isinstance(model, torch.jit.ScriptModule): + raise ValueError("input must be a script module, got: " + str(type(model))) + + +def _check_forward_method(model): + if not model._c._has_method("forward"): + raise ValueError("input script module does not have forward method") + + +def script_qconfig(qconfig): + r"""Instantiate the activation and weight observer modules and script + them, these observer module instances will be deepcopied during + prepare_jit step. + """ + return QConfig( + activation=torch.jit.script(qconfig.activation())._c, + weight=torch.jit.script(qconfig.weight())._c, + ) + + +def script_qconfig_dict(qconfig_dict): + r"""Helper function used by `prepare_jit`. + Apply `script_qconfig` for all entries in `qconfig_dict` that is + not None. + """ + return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()} + + +def fuse_conv_bn_jit(model, inplace=False): + r"""Fuse conv - bn module + Works for eval model only. 
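+    For example (illustrative), `fused_model = fuse_conv_bn_jit(torch.jit.script(float_model.eval()))`
+    returns a module in which adjacent Conv-BN pairs have been folded.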
+ + Args: + model: TorchScript model from scripting or tracing + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit") + model_c = model._c + model_c = torch._C._jit_pass_fold_convbn(model_c) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): + _check_is_script_module(model) + _check_forward_method(model) + if not all(isinstance(x, str) for x in qconfig_dict.keys()): + raise ValueError("qconfig_dict should only contain names(str) as keys.") + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observers( + model._c, "forward", scripted_qconfig_dict, inplace, quant_type + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def _prepare_ondevice_jit( + model, + qconfig_dict, + method_name="forward", + inplace=False, + quant_type=QuantType.STATIC, +): + _check_is_script_module(model) + if not all(isinstance(x, str) for x in qconfig_dict.keys()): + raise ValueError("qconfig_dict should only contain names(str) as keys.") + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + method_graph = model._c._get_method(method_name).graph + torch._C._jit_pass_inline(method_graph) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq( + model._c, method_name, scripted_qconfig_dict, inplace, quant_type + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def prepare_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC) + + +def prepare_dynamic_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC) + + +def _prepare_ondevice_dynamic_jit( + model, qconfig_dict, method_name="forward", inplace=False +): + return _prepare_ondevice_jit( + model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC + ) + + +def _convert_jit( + model, inplace=False, debug=False, quant_type=QuantType.STATIC, preserved_attrs=None +): + _check_is_script_module(model) + model.eval() + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant( + model_c, "forward", inplace, debug, quant_type + ) + if not debug: + is_xpu = all(p.device.type == "xpu" for p in model.parameters()) + if not is_xpu: + # Moving model parameters to CPU since quantized operators + # are only supported on CPU and XPU right now + model.cpu() + if preserved_attrs is None: + preserved_attrs = [] + model_c = torch._C._jit_pass_quant_finalize( + model_c, quant_type, preserved_attrs + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def _convert_ondevice_jit( + model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC +): + _check_is_script_module(model) + assert quant_type == QuantType.DYNAMIC, ( + "This API, while should work for static quant, is only tested for dynamic quant." 
+ ) + assert not method_name.startswith("observe_"), ( + "Pass in valid method to be quantized, e.g. forward" + ) + observe_method_name = "observe_" + method_name + quantize_method_name = "quantize_" + method_name + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq( + model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC + ) + model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq( + model_c, QuantType.DYNAMIC, quantize_method_name + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def convert_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit") + return _convert_jit( + model, + inplace, + debug, + quant_type=QuantType.STATIC, + preserved_attrs=preserved_attrs, + ) + + +def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit") + return _convert_jit( + model, + inplace, + debug, + quant_type=QuantType.DYNAMIC, + preserved_attrs=preserved_attrs, + ) + + +def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False): + return _convert_ondevice_jit( + model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC + ) + + +def _quantize_ondevice_dynamic_jit_impl( + model, qconfig_dict, method_name, inplace=False +): + model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace) + model = _convert_ondevice_dynamic_jit(model, method_name, inplace) + return model + + +def _quantize_jit( + model, + qconfig_dict, + run_fn=None, + run_args=None, + inplace=False, + debug=False, + quant_type=QuantType.STATIC, +): + # Always do inplace convert because the Tensor is already + # copied in prepare_jit when inplace is False + if quant_type == QuantType.DYNAMIC: + model = prepare_dynamic_jit(model, qconfig_dict, inplace) + model = convert_dynamic_jit(model, True, debug) + else: + assert run_fn, ( + "Must provide calibration function for post training static quantization" + ) + assert run_args, ( + "Must provide calibration dataset for post training static quantization" + ) + model = prepare_jit(model, qconfig_dict, inplace) + run_fn(model, *run_args) + model = convert_jit(model, True, debug) + + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, empty key means the qconfig will be applied + to whole model unless it's overwritten by more specific configurations, the + qconfig for each module is either found in the dictionary or fallback to + the qconfig of parent module. + + Right now qconfig_dict is the only way to configure how the model is quantized, + and it is done in the granularity of module, that is, we only support one type + of qconfig for each torch.nn.Module, and the qconfig for sub module will + override the qconfig for parent module, empty string means global configuration. 
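+        For example (illustrative), `{"": qconfig}` quantizes the whole model with `qconfig`,
+        while `{"submodule_name": qconfig}` applies it only to that submodule and its
+        children, following the parent-qconfig fallback described above.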
+        `run_fn`: a calibration function for calibrating the prepared model
+        `run_args`: positional arguments for `run_fn`
+        `inplace`: carry out model transformations in-place, the original module is
+        mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization import quantize_jit
+
+    ts_model = torch.jit.script(
+        float_model.eval()
+    )  # or torch.jit.trace(float_model, input)
+    qconfig = get_default_qconfig("fbgemm")
+
+
+    def calibrate(model, data_loader):
+        model.eval()
+        with torch.no_grad():
+            for image, target in data_loader:
+                model(image)
+
+
+    quantized_model = quantize_jit(
+        ts_model, {"": qconfig}, calibrate, [data_loader_test]
+    )
+    ```
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit")
+    return _quantize_jit(
+        model,
+        qconfig_dict,
+        run_fn,
+        run_args,
+        inplace,
+        debug,
+        quant_type=QuantType.STATIC,
+    )
+
+
+def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+        qconfig for that module as value, please see detailed
+        descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `inplace`: carry out model transformations in-place, the original module is
+        mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import per_channel_dynamic_qconfig
+    from torch.ao.quantization import quantize_dynamic_jit
+
+    ts_model = torch.jit.script(
+        float_model.eval()
+    )  # or torch.jit.trace(float_model, input)
+    # dynamic quantization does not need a calibration step, a qconfig_dict is enough
+    qconfig = per_channel_dynamic_qconfig
+
+    quantized_model = quantize_dynamic_jit(ts_model, {"": qconfig})
+    ```
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
+    return _quantize_jit(
+        model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC
+    )
+
+
+def _quantize_ondevice_dynamic_jit(
+    model, qconfig_dict, method_name="forward", inplace=False
+):
+    r"""Prepares the input float TorchScript model with
+    *on-device* post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+        qconfig for that module as value, please see detailed
+        descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `method_name`: Name of the method within the model, to be prepared for quantization
+        `inplace`: carry out model transformations in-place, the original module is
+        mutated
+
+    Return:
+        TorchScript model that is ready for on device quantization.
+        This means that the returned
+        model has:
+        - Method is inlined.
+        - Model has observer modules inserted in the model.
+        - Model has packed params inserted in the model. However they are empty as in they don't
+          contain valid quantized weights.
+ - observe_ is added that observe the values to be quantized. + - reset_observers_ to reset observers. + - quantize_ is added to the model. + - This method extract scale, zero points. + - Quantizes observed weights. + - Creates packed params from it and update the attribute of the model with the new values + for the packed params. + - Reset the original fp32 weights with empty tensor using SetAttr. + - quantized_ is added to the model. + - This method uses quantized weights and quantized linear ops instead of fp32 op. + - This method should be used for inference post PTQ. + - Note that all method's signatures should be the same as method_name. + + Later on device: + - Run reset_observers_ + - Run observe_ + - Run quantize_ + - Now model can be saved and loaded later. + - Run model with quantized_ + + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + quant_ready_model = _quantize_ondevice_dynamic_jit( + ts_model, {"": qconfig}, "forward", True + ) + ``` + """ + return _quantize_ondevice_dynamic_jit_impl( + model, qconfig_dict, method_name, inplace=inplace + ) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantize_pt2e.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantize_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..e5cf36e8d6c986b20d5288427a1dfc42b2ad17e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantize_pt2e.py @@ -0,0 +1,262 @@ +import typing_extensions + +import torch +from torch._export.passes.constant_folding import constant_fold +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.quantizer import ( # noqa: F401 + DerivedQuantizationSpec, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassManager + +from .pt2e.prepare import prepare +from .pt2e.qat_utils import _fold_conv_bn_qat, _fuse_conv_bn_qat +from .pt2e.representation import reference_representation_rewrite +from .pt2e.utils import _disallow_eval_train, _fuse_conv_bn_, _get_node_name_to_scope +from .quantize_fx import _convert_to_reference_decomposed_fx +from .utils import DEPRECATION_WARNING + + +__all__ = [ + "prepare_pt2e", + "prepare_qat_pt2e", + "convert_pt2e", +] + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for post training quantization + + Args: + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. + * `quantizer`: A backend specific quantizer that conveys how user want the + model to be quantized. 
Tutorial for how to write a quantizer can be found here: + https://pytorch.org/tutorials/prototype/pt2e_quantizer.html + + Return: + A GraphModule with observer (based on quantizer annotation), ready for calibration + + Example:: + + import torch + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_pt2e(m, quantizer) + + # run calibration + # calibrate(m, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + # TODO: check qconfig_mapping to make sure conv and bn are both configured + # to be quantized before fusion + # TODO: (maybe) rewrite this with subgraph_rewriter + _fuse_conv_bn_(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + model = prepare( + model, + node_name_to_scope, + is_qat=False, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for quantization aware training + + Args: + * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + + Return: + A GraphModule with fake quant modules (based on quantizer annotation), ready for + quantization aware training + + Example:: + import torch + from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. 
quantization
+        # backend developer will write their own Quantizer and expose methods to allow
+        # users to express how they
+        # want the model to be quantized
+        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
+        m = prepare_qat_pt2e(m, quantizer)
+
+        # run quantization aware training
+        train_loop(m, train_data)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
+    original_graph_meta = model.meta
+    node_name_to_scope = _get_node_name_to_scope(model)
+    model = quantizer.transform_for_annotation(model)
+    quantizer.annotate(model)
+    quantizer.validate(model)
+    # Perform fusion after annotate to avoid quantizing ops in the new
+    # subgraph that don't need to be quantized
+    # TODO: only fuse if conv and bn are both configured to be quantized
+    _fuse_conv_bn_qat(model)
+    model = prepare(
+        model,
+        node_name_to_scope,
+        is_qat=True,
+        obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
+    )
+    model.meta.update(original_graph_meta)
+    model = _disallow_eval_train(model)
+    return model
+
+
+_QUANT_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+    torch.ops.pt2e_quant.quantize_affine,
+]
+
+
+def _quant_node_constraint(n: Node) -> bool:
+    """If there are any pure ops between get_attr and quantize op they will be const propagated
+    e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
+    (Note: dequantize op is not going to be constant propagated)
+
+    This filter is added because we don't want to constant fold the things that are not
+    related to quantization
+    """
+    return n.op == "call_function" and n.target in _QUANT_OPS
+
+
+@typing_extensions.deprecated(DEPRECATION_WARNING)
+def convert_pt2e(
+    model: GraphModule,
+    use_reference_representation: bool = False,
+    fold_quantize: bool = True,
+) -> GraphModule:
+    """Convert a calibrated/trained model to a quantized model
+
+    Args:
+      * `model` (torch.fx.GraphModule): calibrated/trained model
+      * `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not
+      * `fold_quantize` (bool): boolean flag for whether to fold the quantize op or not
+
+    Returns:
+        quantized model, either in q/dq representation or reference representation
+
+    Example::
+
+        # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
+        # `convert_pt2e` produces a quantized model that represents quantized computation with
+        # quantize dequantize ops and fp32 ops by default.
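+        # Illustrative note: in this default q/dq form a quantized linear appears in the
+        # graph as dequantize -> aten.linear (fp32) -> quantize, rather than as a single
+        # fused quantized op.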
+ # Please refer to + # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model + # for detailed explanation of output quantized model + quantized_model = convert_pt2e(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") + if not isinstance(use_reference_representation, bool): + raise ValueError( + "Unexpected argument type for `use_reference_representation`, " + f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e" + ) + original_graph_meta = model.meta + model = _convert_to_reference_decomposed_fx(model) + model = _fold_conv_bn_qat(model) + + pm = PassManager([DuplicateDQPass()]) + model = pm(model).graph_module + + pm = PassManager([PortNodeMetaForQDQ()]) + model = pm(model).graph_module + + if fold_quantize: + constant_fold(model, _quant_node_constraint) + + if use_reference_representation: + model = reference_representation_rewrite(model) + + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__init__.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1a69cbda4845d845b7d3a4fe459f9e60f4b907b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__init__.py @@ -0,0 +1,22 @@ +from .quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) + + +__all__ = [ + "EdgeOrNode", + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b23fbd20b8426bfe2a2025a7586246ab9925a65 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfcf3f8c5c502a224e1e9dffe44fb218fa39dccf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..480b0013e1cd29c67b426de71af25be0b04bd6a0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f9253e9997455e469fc8fc0c0f4e6df62710abd0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f91c47d9b3984274dc53f0a65f1455a38094a3fe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95526d7ef163fd9c13ad931c953c055b32ade125 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe2a7d3ab808d806555df2db79d619d47e7cd80 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..924b6095a2dc1a86067ed03500f124962c6999a3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42f10bb1bfbf1b6d214a5526dbd949692276bdb6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7f1b16e3cbb97a3a213168a37e344b3b24c9a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .quantizer import QuantizationAnnotation, Quantizer + + +if TYPE_CHECKING: + import torch + from torch.fx import Node + +__all__ = [ + "ComposableQuantizer", +] + + +class ComposableQuantizer(Quantizer): + """ + ComposableQuantizer allows users to combine more than one quantizer into a single quantizer. + This allows users to quantize a model with multiple quantizers. 
E.g., embedding quantization + maybe supported by one quantizer while linear layers and other ops might be supported by another + quantizer. + + ComposableQuantizer is initialized with a list of `Quantizer` instances. + The order of the composition matters since that is the order in which the quantizers will be + applies. + Example: + ``` + embedding_quantizer = EmbeddingQuantizer() + linear_quantizer = MyLinearQuantizer() + xnnpack_quantizer = ( + XNNPackQuantizer() + ) # to handle ops not quantized by previous two quantizers + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, linear_quantizer, xnnpack_quantizer] + ) + prepared_m = prepare_pt2e(model, composed_quantizer) + ``` + """ + + def __init__(self, quantizers: list[Quantizer]): + super().__init__() + self.quantizers = quantizers + self._graph_annotations: dict[Node, QuantizationAnnotation] = {} + + def _record_and_validate_annotations( + self, gm: torch.fx.GraphModule, quantizer: Quantizer + ) -> None: + for n in gm.graph.nodes: + if "quantization_annotation" in n.meta: + # check if the annotation has been changed by + # comparing QuantizationAnnotation object id + if n in self._graph_annotations and ( + id(self._graph_annotations[n]) + != id(n.meta["quantization_annotation"]) + ): + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}" + ) + else: + self._graph_annotations[n] = n.meta["quantization_annotation"] + else: + if n in self._graph_annotations: + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}" + ) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + for quantizer in self.quantizers: + quantizer.annotate(model) + self._record_and_validate_annotations(model, quantizer) + return model + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for quantizer in self.quantizers: + model = quantizer.transform_for_annotation(model) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcaf02a743fa2686dd826af58a2da5d51066a00d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy + +import torch +import torch.nn.functional as F +from torch.ao.quantization.observer import PerChannelMinMaxObserver +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + OperatorConfig, + OperatorPatternType, + QuantizationConfig, +) + + +__all__ = [ + "get_embedding_operators_config", + "EmbeddingQuantizer", +] + + +def get_embedding_operators_config() -> OperatorConfig: + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12), + ) + quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None) + ops: list[OperatorPatternType] = [[torch.nn.Embedding]] + ops.append([F.embedding]) + supported_config_and_operators = 
OperatorConfig( + config=quantization_config, operators=ops + ) + return copy.deepcopy(supported_config_and_operators) + + +class EmbeddingQuantizer(Quantizer): + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.get_supported_operators() + } + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: QuantizationConfig + ) -> list[OperatorPatternType]: + for config, ops in cls.get_supported_operators(): + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + self._annotate_embedding_ops(model.graph) + return model + + def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None: + embedding_config: OperatorConfig = get_embedding_operators_config() + for node in graph.nodes: + # Keep node parsing based annotations instead of module partitioners + # just as an example of alternate ways of annotating + if ( + node.op == "call_function" + and node.target == torch.ops.aten.embedding.default + ): + if embedding_config.config.weight is None: + raise ValueError( + "Embedding config must have a valid weight quantization spec." + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + node.args[0]: embedding_config.config.weight, + } + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return [get_embedding_operators_config()] diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/quantizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..93aeb77b50b14ce2d3fe59ae52e0f8addb503921 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/quantizer.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Callable, Optional, Union + +import torch +from torch import Tensor +from torch.ao.quantization import ObserverOrFakeQuantize +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.fx import Node + + +__all__ = [ + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "EdgeOrNode", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] + + +class QuantizationSpecBase(ABC): # noqa: B024 + """Base class for different types of quantization specs that allows users to + specify how to quantize a Tensor (input/output of a Node) in the model + """ + + +@dataclass(eq=True, frozen=True) +class QuantizationSpec(QuantizationSpecBase): + """Quantization spec for common operators that allows user to specify how to + quantize a Tensor, this includes dtype, quant_min, quant_max etc. + """ + + dtype: torch.dtype + # observer or fake_quantize constructor such as + # MinMaxObserver, PerChannelHistogramObserver etc. 
+ # or we can attach some custom args to them + # e.g. MinMaxObserver.with_args(eps=eps) + observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + is_dynamic: bool = False + + def __post_init__(self): + # TODO: add init for quant_min/quant_max + # quant_min must be less than quant_max + if ( + self.quant_min is not None + and self.quant_max is not None + and self.quant_min > self.quant_max + ): + raise ValueError( + f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." + ) + + # ch_axis must be less than the number of channels + # but no way to check here. Just check that it is not < 0. + if self.ch_axis is not None and self.ch_axis < 0: + raise ValueError("Ch_axis is < 0.") + + +@dataclass(eq=True, frozen=True) +class FixedQParamsQuantizationSpec(QuantizationSpecBase): + dtype: torch.dtype + scale: float + zero_point: int + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + is_dynamic: bool = False + + +""" +The way we refer to other points of quantization in the graph will be either +an input edge or an output value +input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node] +output value is an fx Node +""" +EdgeOrNode = Union[tuple[Node, Node], Node] +EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer" + + +@dataclass(eq=True, frozen=True) +class SharedQuantizationSpec(QuantizationSpecBase): + """ + Quantization spec for the Tensors whose quantization parameters are shared with other Tensors + """ + + # the edge or node to share observer or fake quant instances with + edge_or_node: EdgeOrNode + + +@dataclass(eq=True, frozen=True) +class DerivedQuantizationSpec(QuantizationSpecBase): + """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors""" + + derived_from: list[EdgeOrNode] + derive_qparams_fn: Callable[[list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor]] + dtype: torch.dtype + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + is_dynamic: bool = False + + +@dataclass +class QuantizationAnnotation: + """How are input arguemnt or output should be quantized, + expressed as QuantizationSpec, this corresponds to how a Tensor in the + operator Graph is observed (PTQ) or fake quantized (QAT) + """ + + # a map from torch.fx.Node to a type of QuantizationSpecBase + input_qspec_map: dict[Node, Optional[QuantizationSpecBase]] = field( + default_factory=dict + ) + + # How the output of this node is quantized, expressed as QuantizationSpec + # TODO: change the value to QuantizationSpec in a separate PR + output_qspec: Optional[QuantizationSpecBase] = None + + # For a Node: node1 and edge: (node1, node2), since they are observing the same + # Tensor, we may want to implicitly share observers, this flag allows people to + # turn off this behavior for the output of the node + allow_implicit_sharing: bool = True + + # whether the node is annotated or not + _annotated: bool = False + + +class Quantizer(ABC): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Allows for user defined transforms to run before annotating the graph. + This allows quantizer to allow quantizing part of the model that are otherwise not quantizable. 
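As a concrete illustration of the spec dataclasses above and the `Quantizer` contract defined here, the sketch below builds a per-tensor activation spec and a per-channel weight spec, and wires the activation spec into a toy quantizer that annotates only `aten.add.Tensor` nodes. The `AddOnlyQuantizer` name and the chosen dtypes/observers are illustrative assumptions, not library code; the weight spec is shown only for contrast and would be used for weighted ops such as linear/conv.

    import torch
    from torch.ao.quantization.observer import (
        HistogramObserver,
        PerChannelMinMaxObserver,
    )
    from torch.ao.quantization.quantizer.quantizer import (
        QuantizationAnnotation,
        QuantizationSpec,
        Quantizer,
    )

    # Per-tensor uint8 spec for activations, observed with a HistogramObserver.
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )

    # Per-channel symmetric int8 spec for weights (channel axis 0).
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
    )

    class AddOnlyQuantizer(Quantizer):
        """Toy quantizer: annotate every aten.add.Tensor with the activation spec."""

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            for node in model.graph.nodes:
                if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                    node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map={n: act_qspec for n in node.all_input_nodes},
                        output_qspec=act_qspec,
                        _annotated=True,
                    )
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            # A real backend quantizer would check that the annotated graph is supported.
            pass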
+ For example quantizer can + a) decompose a compound operator like scaled dot product attention, + into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa + or b) transform scalars to tensors to allow quantizing scalars. + + Note: this is an optional method + """ + return model + + # annotate nodes in the graph with observer or fake quant constructors + # to convey the desired way of quantization + @abstractmethod + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + pass + + # validate the annotated graph is supported by the backend + @abstractmethod + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + def prepare_obs_or_fq_callback( + self, + model: torch.fx.GraphModule, + edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize], + ) -> None: + """A callback that will be called after the observers or fake quants are created + for each sharing group, but before they are inserted into the graph. The + callback can be used to make final quantization adjustments, such as enforcing + specific scale and zero point on model input or output. + + Args: + * `model`: the graph module being prepared. + * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and + node to the corresponding observer or fake quant object. Note that multiple + edges and/or nodes can map to the same observer / fake quant instance if + they were annotated with SharedQuantizationSpec. This dictionary can be + modified by the callback. + """ + return diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebc4c765905ee55e9fdb7310578872f98402251 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/utils.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs + +from torch.ao.quantization.pt2e.utils import _is_sym_size_node +from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation +from torch.fx import Node + + +def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + quantization_annotation.input_qspec_map[input_node] = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _annotate_output_qspec(node: Node, qspec): + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + quantization_annotation.output_qspec = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]): + """ + This utility is used to handle cases when dynamic_shape=True tracing leads + to symint nodes in the pattern of linear module. In those cases, we need to + distinguish between the nodes that are in input for just extracting value of + some dimensions (and symint nodes) vs. the one that is activation. + For example: + graph(x, y, weight): + size_0 = torch.ops.aten.sym_size([x], [0]) + size_1 = torch.ops.aten.sym_size([y], [1]) + view_size = size_0 * size_1 + size_3 = torch.ops.aten.sym_size([x], [2]) + view_out = torch.ops.aten.view(x, [view_size, size_3]) + return mm(view_out, weight) + In the example above y node is not an actual input.
It exist only to extract size_1 + """ + if _is_sym_size_node(node): + return True + + return all( + ((user not in partition_nodes) or _is_sym_size_node(user)) + for user in node.users + ) + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cafb5d844ebb9ab38c31c354f986eb8683f81302 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -0,0 +1,1575 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import operator +import warnings +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + get_bias_qspec, + get_input_act_qspec, + get_output_act_qspec, + get_weight_qspec, + QuantizationConfig, +) +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions, + SourcePartition, +) + + +FilterFn: TypeAlias = Callable[[list[Node]], bool] + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "X86InductorQuantizer", + "get_default_x86_inductor_quantization_config", + "get_x86_inductor_linear_dynamic_fp16_config", +] + + +@dataclass +class _X86InductorQuantizationAnnotation(QuantizationAnnotation): + # _is_output_of_quantized_pattern: + # * Node as output node of a fusion pattern. + # * The fusion pattern supports int8 data type. 
+ # * The fusion pattern has inputs annotated to insert observer. + # * The quantization_config is not `None`. + _is_output_of_quantized_pattern: bool = False + + +# Operators that: +# 1. Operators are optimized to run with int8 when int8 input provided. +# 2. Operators do not support int8 input and produce fp32 output. +int8_in_int8_out_ops: set = { + torch.ops.aten.max_pool2d.default, + torch.ops.aten.cat.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.flatten.using_ints, +} + +# Operators that support the int8 data type for quantization config propagation. +# A superset of int8_in_int8_out_ops incorporating additional operators. +propagation_quantizable_ops = int8_in_int8_out_ops + +# Operators support the int8 data type +# and recipe is configured by default in X86InductorQuantizer. +default_quantizable_ops = propagation_quantizable_ops | { + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, +} + +# A superset of default_quantizable_ops includes operators support the int8 data type +# but not enabled by default recipe of X86InductorQuantizer. +quantizable_ops = default_quantizable_ops | { + torch.ops.aten.matmul.default, +} + +QUANT_ANNOTATION_KEY = "quantization_annotation" + + +def _skip_annotate(nodes: list[Node], filter_fn: Optional[FilterFn] = None) -> bool: + """Determine whether to skip annotation for a list of nodes.""" + + # 1) Skip annotate if any node is already annotated + if _is_any_annotated(nodes): + return True + + # 2) Proceed annotate if a) a filter function is provided + # and b) the given nodes list passes the filter function check. + if filter_fn and filter_fn(nodes): + return False + + return True + + +def _create_module_name_filter(module_name: str) -> FilterFn: + """Create a filter function for a given module name. + + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> module_name_filter = _create_module_name_filter_inner("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". + """ + + filter_fn = _get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: list[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) + return all_nodes_from_module_name + + return check_all_nodes_from_module + + +def _create_operator_type_filter( + operator_type: Callable, +) -> FilterFn: + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. 
+ """ + + def operator_type_filter(nodes: list[Node]): + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) + return num_nodes_with_operator_type == 1 + + return operator_type_filter + + +def _global_config_filter(nodes: list[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Several nodes within a single pattern are default quantizable operations." + ) + return num_nodes_in_default_quantizable_ops == 1 + + +def _map_module_function_to_aten_operator_type(): + module_function_to_aten_operator: dict[Callable, torch._ops.OpOverloadPacket] = {} + map_list = ( + ([torch.nn.Conv2d, F.conv1d], torch.ops.aten.conv1d.default), + ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), + ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), + ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default), + ( + [ + torch.cat, + ], + torch.ops.aten.cat.default, + ), + ([torch.nn.AvgPool2d, F.avg_pool2d], torch.ops.aten.avg_pool2d.default), + ( + [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], + torch.ops.aten.adaptive_avg_pool2d.default, + ), + ( + [ + torch.flatten, + ], + torch.ops.aten.flatten.using_ints, + ), + ( + [ + torch.matmul, + ], + torch.ops.aten.matmul.default, + ), + ) + for map_item in map_list: + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] + return module_function_to_aten_operator + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if QUANT_ANNOTATION_KEY not in node.meta: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation() + node.meta[QUANT_ANNOTATION_KEY]._annotated = True + + +def _is_node_annotated(_node): + """ + return True if the node is annotated, otherwise return False + """ + return ( + QUANT_ANNOTATION_KEY in _node.meta + and _node.meta[QUANT_ANNOTATION_KEY]._annotated + ) + + +def _is_any_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False. + """ + return any(_is_node_annotated(node) for node in nodes) + + +def _is_all_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if all of the node is annotated, otherwise return False. + """ + return all(_is_node_annotated(node) for node in nodes) + + +def _is_quantized_op_pt2e(node: torch.fx.Node): + """ + Used for pt2e flow to check if the node is a quantized node: + Case1: the node has been annotated as output node of a fusion pattern. + Case2: the node has been annotated as single quantized node. 
+ """ + if not _is_any_annotated([node]): + # The node has not been annotated, directly return False + return False + quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) + assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) + return quantization_annotation._is_output_of_quantized_pattern + + +@functools.lru_cache +def get_default_x86_inductor_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, + reduce_range: bool = False, +): + """ + reduce_range is False by default. Set it to True on earlier CPUs without VNNI to avoid accuracy issue. + """ + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=127 if reduce_range else 255, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver + ) + + if is_qat: + # Only support per channel quant for now + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +@functools.lru_cache +def get_x86_inductor_linear_dynamic_fp16_config(): + """ + For linear_dynamic_fp16. The name may be confusing. + The op's behavior is fp32_input * (fp16_weight -> to_fp32) -> fp32_output. 
+ """ + weight_quantization_spec = QuantizationSpec( + dtype=torch.float16, + observer_or_fake_quant_ctr=PlaceholderObserver, + ) + quantization_config = QuantizationConfig( + None, # input_quantization_spec + None, # output_quantization_spec + weight_quantization_spec, + None, # bias_quantization_spec + ) + return quantization_config + + +def _annotate_nodes_not_quantize(nodes: Union[Node, list[Node]]) -> None: + """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" + if not isinstance(nodes, list): + nodes = [nodes] + for node in nodes: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + + +def _config_checker(method: Callable) -> Callable: + @functools.wraps(method) + def wrapper( + quantizer: "X86InductorQuantizer", + name: Any, + quantization_config: Optional["QuantizationConfig"], + ) -> "X86InductorQuantizer": + if quantizer._need_skip_config(quantization_config): + warnings.warn( + f"Skip the quantization config for {name}.", + ) + return quantizer + return method(quantizer, name, quantization_config) + + return wrapper + + +@dataclass +class _CurrentQuantizationMode: + r"""Configuration defining the current quantization mode for the quantizer. + + All possible current quantization modes are listed below: + ---------------------------------------------------------------------------------------------------------- + | dynamic_state + qat_state |--------------------------------------------------------------------------------------------- + | None | True | False + ---------------------------------------------------------------------------------------------------------- + None | quantizer does not receive a non-None `quantization_config` | \ | \ + False | quantizer will not do QAT | dynamic | static + True | quantizer will do QAT | QAT + dynamic | QAT + static + """ + + qat_state: Optional[bool] + dynamic_state: Optional[bool] + + +class X86InductorQuantizer(Quantizer): + module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() + + def __init__(self) -> None: + super().__init__() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_qconfig: dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_name_qconfig: dict[str, Optional[QuantizationConfig]] = {} + + def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: + """Retrieves the current quantization mode based on all configurations.""" + qat_state = None + dynamic_state = None + + # As we use `_need_skip_config` to skip all invalid configurations, + # we can safely assume that the all existing non-None configurations + # have the same quantization mode. + for qconfig in ( + list(self.module_name_qconfig.values()) + + list(self.operator_type_qconfig.values()) + + [self.global_config] + ): + if qconfig is not None: + # Query the `is_qat` state + if qat_state is None: + qat_state = qconfig.is_qat + else: + assert qat_state == qconfig.is_qat, ( + f"All non-None quantization configs should have the same `is_qat`," + f"but got {qat_state} and {qconfig.is_qat}." 
+ ) + # Query the `is_dynamic` state + input_activation_spec = qconfig.input_activation + if input_activation_spec is not None: + if dynamic_state is None: + dynamic_state = input_activation_spec.is_dynamic + else: + assert dynamic_state == input_activation_spec.is_dynamic, ( + f"All non-None `input_activation_spec` should have the same `is_dynamic`," + f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." + ) + return _CurrentQuantizationMode( + qat_state=qat_state, dynamic_state=dynamic_state + ) + + def _need_skip_config( + self, quantization_config: Optional[QuantizationConfig] + ) -> bool: + """Check if the provided quantization config is valid for X86InductorQuantizer. + + Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. + To avoid such a mix, we compare the incoming configuration with current configuration status. + Refer the `_CurrentQuantizationMode` definition for all possible modes. + """ + if quantization_config is None: + return False + + need_skip = False + current_mode = self._get_current_quantization_mode() + if ( + current_mode.qat_state is not None + and current_mode.qat_state != quantization_config.is_qat + ): + warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") + need_skip = True + if current_mode.dynamic_state is not None: + input_activation_spec = quantization_config.input_activation + if ( + input_activation_spec is not None + and current_mode.dynamic_state != input_activation_spec.is_dynamic + ): + warnings.warn( + "Mixed dynamic and static quantization config is not supported." + ) + need_skip = True + return need_skip + + def set_global(self, quantization_config: QuantizationConfig): + if self._need_skip_config(quantization_config): + warnings.warn("Skip the global quantization config.") + return self + self.global_config = quantization_config + return self + + def get_global_quantization_config(self): + if not isinstance(self.global_config, QuantizationConfig): + warnings.warn( + "The global_config for X86InductorQuantizer is currently invalid. \ + Please ensure that you use set_global to establish the global quantization configuration." + ) + return self.global_config + + @_config_checker + def set_function_type_qconfig( + self, + function_type: Callable, + quantization_config: Optional[QuantizationConfig], + ) -> "X86InductorQuantizer": + if function_type in X86InductorQuantizer.module_function_to_aten_operator_type: + self._set_aten_operator_qconfig( + X86InductorQuantizer.module_function_to_aten_operator_type[ + function_type + ], + quantization_config, + ) + else: + warnings.warn( + f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer." + ) + return self + + @_config_checker + def set_module_type_qconfig( + self, + module_type: torch.nn.Module, + quantization_config: Optional[QuantizationConfig], + ) -> "X86InductorQuantizer": + if module_type in X86InductorQuantizer.module_function_to_aten_operator_type: + self._set_aten_operator_qconfig( + X86InductorQuantizer.module_function_to_aten_operator_type[module_type], + quantization_config, + ) + else: + warnings.warn( + f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer." 
+ ) + return self + + @_config_checker + def set_module_name_qconfig( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + + The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. + """ + self.module_name_qconfig[module_name] = quantization_config + return self + + def _set_aten_operator_qconfig( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: Optional[QuantizationConfig], + ) -> "X86InductorQuantizer": + if operator_type in quantizable_ops: + self.operator_type_qconfig[operator_type] = quantization_config + else: + warnings.warn( + f"operator: Unable to quantize {operator} by X86InductorQuantizer." + ) + return self + + def _annotate_conv_node_helper( + self, + conv_node: torch.fx.Node, + annotate_output: bool, + quantization_config: Optional[QuantizationConfig], + ) -> None: + """Helper function to annotate the conv node""" + if quantization_config is None: + _annotate_nodes_not_quantize(conv_node) + return + input_qspec_map = {} + input_node = conv_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + weight_node = conv_node.args[1] + assert isinstance(weight_node, Node) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + bias_node = None if len(conv_node.args) == 2 else conv_node.args[2] + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + if annotate_output: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + + def _annotate_linear_node_helper( + self, + linear_node: torch.fx.Node, + annotate_output: bool, + quantization_config: Optional[QuantizationConfig], + ) -> None: + """Helper function to annotate the linear node""" + if quantization_config is None: + _annotate_nodes_not_quantize(linear_node) + return + input_qspec_map = {} + assert linear_node.target in (torch.ops.aten.linear.default,) + has_bias = len(linear_node.args) == 3 + input_index = 0 + weight_index = 1 + bias_index = 2 + + input_node = linear_node.args[input_index] + assert isinstance(input_node, Node) + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + + weight_node = linear_node.args[weight_index] + assert isinstance(weight_node, Node) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + + bias_node = linear_node.args[bias_index] if has_bias else None + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + + if annotate_output: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + + def _get_output_nodes_of_partitions( + self, + partition_list: list[SourcePartition], + ) -> list[torch.fx.Node]: + """Helper 
function to get the output node list from partition list""" + output_node_list = [] + for partition in partition_list: + if len(partition.output_nodes) > 1: + raise ValueError("Input partition has more than one output node") + output_node = partition.output_nodes[0] + assert isinstance(output_node, Node) + output_node_list.append(output_node) + if len(output_node_list) != len(partition_list): + raise ValueError( + "length of output_node_list should equal to length of partition_list" + ) + return output_node_list + + def _get_input_idx_for_binary_node( + self, + conv_gemm_node: torch.fx.Node, + binary_node: torch.fx.Node, + ): + """Helper function to check conv_gemm and extra input node index + for binary node fused with conv_gemm. + """ + conv_gemm_node_idx = None + extra_input_node_idx = None + if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[0] == conv_gemm_node + ): + conv_gemm_node_idx = 0 + extra_input_node_idx = 1 + elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[1] == conv_gemm_node + ): + conv_gemm_node_idx = 1 + extra_input_node_idx = 0 + extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] + assert isinstance(extra_input_node, Node) + return conv_gemm_node_idx, extra_input_node_idx + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Annotate the given model with quantization configurations. + + Annotation contracts: + 1. Annotate each node according to the user's qconfig in the following order: + `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. + 2. Avoid re-annotating nodes already annotated in prior stages. For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. + 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. + """ + for module_name, quantization_config in self.module_name_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_module_name_filter(module_name) + ) + + for operator_type, quantization_config in self.operator_type_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_operator_type_filter(operator_type) + ) + + if self.global_config: + self._annotate_with_config( + model, + self.global_config, + _global_config_filter, + ) + + # Once we've annotated the model with quantization configurations, we also need to annotate + # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, + # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. + # Refer to + # https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 # noqa: B950 + + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) + + return model + + def _annotate_with_config( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: FilterFn, + ) -> None: + """Annotate the model with the given quantization configuration. 
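The priority contract described in `annotate()` above (per-module-name config first, then per-operator-type config, then the global config) is driven entirely through the public setters. A small usage sketch follows; "blocks.head" is a hypothetical submodule name used only for illustration.

    import torch
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
        X86InductorQuantizer,
        get_default_x86_inductor_quantization_config,
    )

    config = get_default_x86_inductor_quantization_config()
    quantizer = X86InductorQuantizer()

    # Highest priority: per-submodule config. Passing None marks the submodule's
    # nodes as not-to-be-quantized rather than leaving them unannotated.
    quantizer.set_module_name_qconfig("blocks.head", None)

    # Next: per-operator-type / per-module-type config.
    quantizer.set_module_type_qconfig(torch.nn.Linear, config)
    quantizer.set_function_type_qconfig(torch.nn.functional.conv2d, config)

    # Lowest priority: global fallback for the remaining default-quantizable ops.
    quantizer.set_global(config)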
+ + High-level description of quantization recipe for X86 Inductor Backend: + Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. + Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model + from start to the end. If a pattern supports computation with int8 data type and inputs connected to + quantized patterns, annotate its inputs as quantized pattern. + """ + + # Step1: Recipe of fusion patterns like conv/linear. + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) + + # Step2: Recipe to propagate annotation for patterns beside conv/linear. + # Go through all the nodes from start to end. + # Recipe refer to + # https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 # noqa: B950 + + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + # Annotate QAT Specific patterns + self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) + + def _annotate_qat_conv2d_bn_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + ( + conv_partition, + bn_partition, + binary_partition, + unary_partition, + ) = fused_partition + + ( + conv_node, + bn_output_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition, unary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. 
+ continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate( + [unary_node, binary_node, bn_output_node, conv_node], filter_fn + ): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize([binary_node, unary_node]) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_binary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition, binary_partition = fused_partition + ( + conv_node, + bn_output_node, + binary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. + continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(binary_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, bn_partition, unary_partition = fused_partition + ( + conv_node, + bn_output_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, unary_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + if quantization_config is not None: + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(unary_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition = fused_partition + conv_node, bn_output_node = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + if quantization_config is not None: + bn_output_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(bn_output_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + if (quantization_config is None) or (quantization_config.is_qat): + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d_unary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) + + def _annotate_linear_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + self._annotate_linear_binary_unary(model, quantization_config, filter_fn) + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn) + + def _annotate_matmul( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in model.graph.nodes: + if node.target != torch.ops.aten.matmul.default: + continue + if _skip_annotate([node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + continue + + input_qspec_map = {} + matmul_node = node + for input_node in matmul_node.args: + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + # Conv2d + add + unary op + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition, unary_partition = fused_partition + conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition, unary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) + continue + + 
self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + # Conv2d + add + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition = fused_partition + conv_node, binary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + assert isinstance(conv_node, Node) + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _skip_annotate([binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.ReLU6], + [torch.nn.Conv2d, torch.nn.SiLU], + [torch.nn.Conv1d, torch.nn.ReLU], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, unary_partition = fused_partition + conv_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, unary_partition] + ) + if conv_node.op != "call_function" or conv_node.target not in ( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + ): + continue + if _skip_annotate([unary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, unary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + 
_is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] + ) + conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values())) + for conv_partition in conv_partitions: + if len(conv_partition.output_nodes) > 1: + raise ValueError("conv partition has more than one output node") + conv_node = conv_partition.output_nodes[0] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + raise ValueError(f"{conv_node} is not an aten conv2d operator") + # skip annotation if it is already annotated + if _skip_annotate([conv_node], filter_fn): + continue + self._annotate_conv_node_helper(conv_node, True, quantization_config) + + def _annotate_maxpool2d( + self, + node: Node, + quantization_config: Optional[QuantizationConfig], + ) -> None: + if node.target is not torch.ops.aten.max_pool2d.default: + return + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + + maxpool_node = node + if _is_any_annotated( + [ + maxpool_node, + ] + ): + return + + input_node = maxpool_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_cat( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + cat_node = node + input_nodes = cat_node.args[0] + assert isinstance(input_nodes, Sequence) + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(cat_node, Node) + input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config) + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, cat_node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + # There has the case of cat same nodes: torch.cat([input0, input0], 1) + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_propagation_quantizable_pattern_entry( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in gm.graph.nodes: + self._annotate_propagation_quantizable_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_propagation_quantizable_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: + # Propagate annotation to quantizable patterns. 
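These per-pattern annotators are normally driven indirectly through the PT2E prepare/convert flow rather than called by hand. A minimal sketch of that flow is shown below, assuming a recent PyTorch build where torch.export.export_for_training and the quantize_pt2e helpers are available; the toy model and shapes are made up for the example.

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.MaxPool2d(2)
).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Export to an FX graph, then let prepare_pt2e invoke the quantizer's
# annotate() entry point, which in turn calls the annotators above.
exported = torch.export.export_for_training(model, example_inputs).module()
quantizer = xiq.X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)          # one calibration pass for the observers
quantized = convert_pt2e(prepared)  # lower observers to q/dq ops
```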
+ if ( + (node.target in propagation_quantizable_ops) + and (not _is_any_annotated([node])) + and (node.op == "call_function") + ): + + def is_all_inputs_connected_to_quantized_op(input_nodes): + # Ensure all the inputs connect to fusion pattern or quantized node + for input_node in input_nodes: + if not _is_quantized_op_pt2e(input_node): + return False + return True + + if _skip_annotate([node], filter_fn): + return + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + + if node.target is torch.ops.aten.max_pool2d.default: + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + if quantization_config is not None: + warnings.warn( + f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}." + ) + return + + self._annotate_maxpool2d(node, quantization_config) + return + elif node.target is torch.ops.aten.cat.default: + input_nodes_to_check = node.all_input_nodes + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + return + self._annotate_cat(node, quantization_config) + elif ( + node.target is torch.ops.aten.flatten.using_ints + and len(node.users) > 0 + and not any( + user.target in quantizable_ops for user in node.users.keys() + ) + ): + # Recipe of flatten: check if any users of flatten node are quantizable ops or not + return + else: + input_node = node.all_input_nodes[0] + if not is_all_inputs_connected_to_quantized_op( + [ + input_node, + ] + ): + return + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + return + + def _annotate_output_share_observer_as_input( + self, input_node: Node, source_node: Node + ): + source_node_quantization_annotation = ( + source_node.meta[QUANT_ANNOTATION_KEY] + if QUANT_ANNOTATION_KEY in source_node.meta + else None + ) + if ( + source_node_quantization_annotation + and source_node_quantization_annotation._is_output_of_quantized_pattern + ): + edge_or_node = (input_node, source_node) + source_node_quantization_annotation.output_qspec = SharedQuantizationSpec( + edge_or_node + ) + return + + def _annotate_output_for_int8_in_int8_out_pattern_entry( + self, + model: torch.fx.GraphModule, + ): + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node) + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: + r""" + Check and insert observer at output of node in int8_in_int8_out_ops if needed. 
+ Recipe refers to + https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 + """ # noqa: B950 + edge_or_node: tuple[Node, Node] + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): + if node.target == torch.ops.aten.max_pool2d.default: + maxpool_node = node + if not _is_all_annotated( + [ + maxpool_node, + ] + ): + return + + # Get the quantization_annotation from getitem_node + maxpool_node_quantization_annotation = ( + maxpool_node.meta[QUANT_ANNOTATION_KEY] + if QUANT_ANNOTATION_KEY in maxpool_node.meta + else None + ) + if ( + maxpool_node_quantization_annotation + and maxpool_node_quantization_annotation._is_output_of_quantized_pattern + ): + # Annotate the output_qspec of getitem_node + input_act = maxpool_node.args[0] + assert isinstance(input_act, Node) + assert isinstance(maxpool_node, Node) + edge_or_node = (input_act, maxpool_node) + maxpool_node_quantization_annotation.output_qspec = ( + SharedQuantizationSpec(edge_or_node) + ) + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return + + def _annotate_linear( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + linear_partitions = get_source_partitions( + gm.graph, [torch.nn.Linear, torch.nn.functional.linear] + ) + linear_partitions = list( + itertools.chain.from_iterable(linear_partitions.values()) + ) + for partition in linear_partitions: + if len(partition.output_nodes) > 1: + raise ValueError( + "Linear partition cannot have more than one output node" + ) + linear_node = partition.output_nodes[0] + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + raise ValueError(f"{linear_node} is not an aten linear operator") + # skip annotation if it is already annotated + if _skip_annotate([linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, True, quantization_config) + + def _annotate_linear_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + postop_list = [ + torch.nn.ReLU, + torch.nn.LeakyReLU, + torch.nn.Tanh, + torch.nn.GELU, + ] + fused_partitions: list[tuple] = [] + for postop in postop_list: + fused_partitions = fused_partitions + find_sequential_partitions( + gm, [torch.nn.Linear, postop] + ) + for fused_partition in fused_partitions: + linear_partition, unary_partition = fused_partition + linear_node, unary_node = self._get_output_nodes_of_partitions( + [linear_partition, unary_partition] + ) + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + continue + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([linear_node, unary_node]) + continue + + self._annotate_linear_node_helper(linear_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_linear_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + # linear + binary_op + (optional) unary op + binary_op_list = [operator.add] + 
unary_op_list = [torch.nn.ReLU, None] + combinations = itertools.product(binary_op_list, unary_op_list) + for binary_op, unary_op in combinations: + has_unary = unary_op is not None + seq_partition = [torch.nn.Linear, binary_op] + if has_unary: + seq_partition.append(unary_op) + fused_partitions = find_sequential_partitions(gm, seq_partition) + for fused_partition in fused_partitions: + unary_partition, unary_node = None, None + if has_unary: + ( + linear_partition, + binary_partition, + unary_partition, + ) = fused_partition + ( + linear_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [linear_partition, binary_partition, unary_partition] + ) + else: + linear_partition, binary_partition = fused_partition + linear_node, binary_node = self._get_output_nodes_of_partitions( + [linear_partition, binary_partition] + ) + if len(linear_node.users) != 1: + # Linear Node should only has 1 user node + continue + ( + linear_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(linear_node, binary_node) + if (linear_node_idx is None) or (extra_input_node_idx is None): + continue + if linear_node != binary_node.args[linear_node_idx]: + raise ValueError( + f"{linear_node} doesn't match input of binary node" + ) + assert isinstance(linear_node, Node) + if ( + linear_node.op != "call_function" + or linear_node.target != torch.ops.aten.linear.default + ): + # No linear node found to be fused with add + continue + node_list = ( + [binary_node, linear_node] + if unary_node is None + else [unary_node, binary_node, linear_node] + ) + if _skip_annotate(node_list, filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node_list) + continue + + self._annotate_linear_node_helper( + linear_node, False, quantization_config + ) + # We don't insert q-dq before the binary input node due to accuracy issues + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map={}, + _annotated=True, + _is_output_of_quantized_pattern=(not has_unary), + ) + ) + if unary_node is not None: + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0b6af22e8779f5e57360659298bd331c8594e8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -0,0 +1,450 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy +import functools +import typing_extensions +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch._dynamo as torchdynamo +import torch.nn.functional as F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _convert_scalars_to_attrs, 
+ OP_TO_ANNOTATOR, + OperatorConfig, + OperatorPatternType, + propagate_annotation, + QuantizationConfig, +) +from torch.fx._compatibility import compatibility + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + from torch.fx import Node + + +__all__ = [ + "XNNPACKQuantizer", + "get_symmetric_quantization_config", +] + + +def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph: + gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs) + gm.graph.eliminate_dead_code() + return gm.graph + + +def _get_linear_patterns(input_size: list[int]): + in_channels = input_size[-1] + out_channels = 8 # hard coding but this should not matter + weight = torch.ones((out_channels, in_channels)) + bias = torch.ones((out_channels,)) + act = torch.ones(input_size) + + def linear_op(act, weight, bias=None): + return F.linear(act, weight, bias) + + pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias)) + pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight)) + return [pattern_w_bias, pattern_wo_bias] + + +def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPatternType]]: + supported_operators: dict[str, list[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "adaptive_avg_pool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } + return copy.deepcopy(supported_operators) + + +def _get_supported_symmetric_config_and_operators() -> list[OperatorConfig]: + supported_config_and_operators: list[OperatorConfig] = [] + for quantization_config in [ + get_symmetric_quantization_config(), + get_symmetric_quantization_config(is_qat=True), + get_symmetric_quantization_config(is_per_channel=True), + get_symmetric_quantization_config(is_per_channel=True, is_qat=True), + ]: + ops = _supported_symmetric_quantized_operators() + supported_config_and_operators.extend( + OperatorConfig(quantization_config, pattern_list) + for pattern_list in ops.values() + ) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_symmetric_quantization_config( + is_per_channel: bool = False, + is_qat: bool = False, + is_dynamic: bool = False, + act_qmin: int = -128, + act_qmax: int = 127, + weight_qmin: int = -127, + weight_qmax: int = 127, +): + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + weight_qscheme = ( + torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric + ) + weight_observer_or_fake_quant_ctr: 
_ObserverOrFakeQuantizeConstructor = ( + MinMaxObserver + ) + if is_qat: + # TODO: qat + per channel? + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize + elif is_per_channel: + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if weight_qscheme == torch.per_tensor_symmetric: + extra_args["observer"] = MovingAverageMinMaxObserver + else: + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=weight_qscheme, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, + None, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +def _get_supported_config_and_operators() -> list[OperatorConfig]: + return _get_supported_symmetric_config_and_operators() + + +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + tp_str = tp.__module__ + "." + tp.__qualname__ + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [] + for _, t in nn_module_stack.values(): + # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) + # return type. Handle both cases. + if isinstance(t, type): + t = t.__module__ + "." + t.__qualname__ + types.append(t) + return tp_str in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: list[Callable], module_name_list: list[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + +@compatibility(is_backward_compatible=False) +@typing_extensions.deprecated( + "XNNPACKQuantizer is deprecated! Please use xnnpack quantizer in " + "ExecuTorch (https://github.com/pytorch/executorch/tree/main/backends/xnnpack/quantizer) instead." +) +class XNNPACKQuantizer(Quantizer): + """ + !!! DEPRECATED !!! + XNNPACKQuantizer is a marked as deprected. It will be removed in the future. + It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer. + Please use the new quantizer instead. 
+ """ + + supported_config_and_operators = _get_supported_config_and_operators() + STATIC_QAT_ONLY_OPS = [ + "conv_bn_relu", + "conv_bn", + "conv_transpose_bn_relu", + "conv_transpose_bn", + ] + + # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops + STATIC_OPS = [ + "linear_relu", + "linear", + "conv_relu", + "conv", + "conv_transpose_relu", + "adaptive_avg_pool2d", + # TODO: move this to BoltNNQuantizer? + "gru_io_only", + "add_relu", + "add", + "mul_relu", + "mul", + "cat", + ] + + DYNAMIC_OPS = [ + "linear", + ] + + def __init__(self) -> None: + super().__init__() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_config: dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: dict[str, Optional[QuantizationConfig]] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.supported_config_and_operators + } + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> list[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer: + self.global_config = quantization_config + return self + + def set_operator_type( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig, + ) -> XNNPACKQuantizer: + self.operator_type_config[operator_type] = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: QuantizationConfig + ): + """Set quantization_config for a submodule with type: `module_type`, for example: + quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator + patterns in the submodule with this module type with the given `quantization_config` + """ + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + assert quantization_config is not None, ( + " quantization_config == None is not supported yet" + ) + self.module_name_config[module_name] = quantization_config + return self + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Transforms scalar values to tensor attributes""" + return _convert_scalars_to_attrs(model) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global 
spec for now""" + # hacked for handling dynamic linear quant. will fix later. + if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + propagate_annotation(model) + return model + + def _annotate_all_static_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for op in self.STATIC_QAT_ONLY_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + for op in self.STATIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_all_dynamic_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + for op in self.DYNAMIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_static_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_dynamic_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3aa16935ca0a2a7ca70bea3d281d5bf28806f6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -0,0 +1,1127 @@ +# mypy: allow-untyped-defs +import itertools +import typing +from dataclasses import dataclass +from typing import Callable, NamedTuple, Optional + 
+import torch +import torch.nn.functional as F +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +__all__ = [ + "OperatorConfig", + "OperatorPatternType", + "QuantizationConfig", + "get_input_act_qspec", + "get_output_act_qspec", + "get_weight_qspec", + "get_bias_qspec", + "OP_TO_ANNOTATOR", + "propagate_annotation", +] + + +# In the absence of better name, just winging it with QuantizationConfig +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec] + # TODO: remove, since we can use observer_or_fake_quant_ctr to express this + is_qat: bool = False + + +# Use Annotated because list[Callable].__module__ is read-only. +OperatorPatternType = typing.Annotated[list[Callable], None] +OperatorPatternType.__module__ = ( + "torch.ao.quantization.quantizer.xnnpack_quantizer_utils" +) + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[list[list[Node]]], +] +OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {} + + +def register_annotator(op: str) -> Callable[[AnnotatorType], None]: + def decorator(annotator: AnnotatorType) -> None: + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +class OperatorConfig(NamedTuple): + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. [nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. 
+ config: QuantizationConfig + operators: list[OperatorPatternType] + + +def _is_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.input_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.input_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.output_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.output_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.weight is None: + return None + quantization_spec: QuantizationSpec = quantization_config.weight + if quantization_spec.qscheme not in [ + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + None, + ]: + raise ValueError( + f"Unsupported quantization_spec {quantization_spec} for weight" + ) + return quantization_spec + + +def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.bias is None: + return None + quantization_spec: QuantizationSpec = quantization_config.bias + assert quantization_spec.dtype == torch.float, ( + "Only float dtype for bias is supported for bias right now" + ) + return quantization_spec + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + _annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + _annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if 
bias_node: + _annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + _annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + if len(linear_node.users) > 1: + # if linear node has multiple users, then it can't be fused with relu + continue + + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv") +def _annotate_conv( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ]: + continue + conv_node = n + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [conv_node, conv_node.args[1]] + + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + 
input_qspec_map=input_qspec_map, + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +def _do_annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + is_conv_transpose: bool = False, +): + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = n + maybe_conv_node = n.args[0] + + is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node + if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node): + continue + conv_node = maybe_conv_node + + if len(conv_node.users) > 1: + # relu shouldn't be fuseable to conv if there are other users + # of convolution + continue + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [relu_node, conv_node, conv_node.args[1]] + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv_relu") +def _annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=False + ) + + +@register_annotator("conv_transpose_relu") +def _annotate_conv_transpose_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=True + ) + + +@register_annotator("conv_bn") +def _annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. 
+ """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) + + +@register_annotator("conv_bn_relu") +def _annotate_conv_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) + + +@register_annotator("conv_transpose_bn") +def _annotate_conv_transpose_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv_transpose + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True + ) + + +@register_annotator("conv_transpose_bn_relu") +def _annotate_conv_transpose_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv_transpose + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True + ) + + +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]], + has_relu: bool, + is_conv_transpose: bool = False, +) -> list[list[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. + + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". 
+ """ + + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _WrapperModule(_conv_bn) + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + if is_conv_transpose: + combinations = [ + (F.conv_transpose1d, _conv1d_bn_example_inputs), + (F.conv_transpose2d, _conv2d_bn_example_inputs), + ] + else: + combinations = [ + (F.conv1d, _conv1d_bn_example_inputs), # type: ignore[list-item] + (F.conv2d, _conv2d_bn_example_inputs), # type: ignore[list-item] + ] + + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( # type: ignore[assignment] + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc] + pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type] + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) # type: ignore[has-type] + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. 
This was the case in models with skip + # connections like resnet18 + + # Validate conv args + if conv_node.args[0] is not input_node: + raise ValueError("Conv arg did not contain input node ", input_node) + if conv_node.args[1] is not weight_node: + raise ValueError("Conv arg did not contain weight node ", weight_node) + if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node: + raise ValueError("Conv arg did not contain bias node ", bias_node) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [conv_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + if bias_node is not None: + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("gru_io_only") +def _annotate_gru_io_only( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn) + gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values())) + annotated_partitions = [] + for gru_partition in gru_partitions: + annotated_partitions.append(gru_partition.nodes) + output_nodes = gru_partition.output_nodes + input_nodes = gru_partition.input_nodes + # skip annotation if it is already annotated + if _is_annotated(input_nodes + output_nodes): + continue + # inside each GRU partition, we should be able to annotate each linear + # subgraph + input_act = input_nodes[0] + input_act_user = next(iter(input_act.users.keys())) + assert isinstance(input_act, Node) + assert isinstance(input_act_user, Node) + input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + hidden_state = input_nodes[1] + hidden_state_user = next(iter(hidden_state.users.keys())) + assert isinstance(hidden_state, Node) + assert isinstance(hidden_state_user, Node) + hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + hidden_state: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + assert len(output_nodes) == 2, "expecting GRU to have two outputs" + for output in output_nodes: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + nodes_to_mark_annotated = list(gru_partition.nodes) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + return annotated_partitions + + +@register_annotator("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: 
Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """Always annotate adaptive_avg_pool2d op""" + module_partitions = get_source_partitions( + gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn + ) + partitions = list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for partition in partitions: + pool_node = partition.output_nodes[0] + if ( + pool_node.op != "call_function" + or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default + ): + raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") + + if _is_annotated([pool_node]): + continue + + annotated_partitions.append(partition.nodes) + input_act = pool_node.args[0] + assert isinstance(input_act, Node) + + # only annotate input output sharing operator + # when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + input_act_qspec = get_input_act_qspec(quantization_config) + else: + input_act_qspec = SharedQuantizationSpec(input_act) + + # output sharing with input + output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) + pool_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: input_act_qspec, + }, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule): + """Check if input is a large scalar value. So that we can skip quantization for the node + since histc op (in HistogramObserver) only works for values up to certain upper bound + """ + if node.op == "get_attr": + qualified_name = str(node.target) + module_path, _, name = qualified_name.rpartition(".") + submod = gm.get_submodule(module_path) + tensor = getattr(submod, name) + # torch.histc works until this upper bound + HISTC_UPPER_BOUND = 3.4028235e15 + return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + return False + + +def _is_input_non_float_tensor(node: Node): + """Check if the input is not a float tensor, so that we can skip quantization for the node + since observers only works with float Tensors + """ + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): + return True + return node.meta["val"].dtype != torch.float32 + + +@register_annotator("add_relu") +def _annotate_add_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_add = node.args[0] + if ( + not isinstance(maybe_add, Node) + or maybe_add.op != "call_function" + or maybe_add.target + not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ] + ): + continue + + add_node = maybe_add + + if len(add_node.users) > 1: + # add can't be fused with ReLU if the result of add is being used + # else where in the graph + continue + + partition = [relu_node, add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = 
get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("add") +def _annotate_add( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + continue + add_node = node + partition = [add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act1) + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul_relu") +def _annotate_mul_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_mul = node.args[0] + if ( + not isinstance(maybe_mul, Node) + or maybe_mul.op != "call_function" + or maybe_mul.target + not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ] + ): + continue + + mul_node = maybe_mul + if len(mul_node.users) > 1: + # mul can't be fused with ReLU if the result of mul is being used + # else where in the graph + continue + + partition = [relu_node, mul_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if 
isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul") +def _annotate_mul( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ]: + continue + + mul_node = node + partition = [mul_node] + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act0) + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +# TODO: remove Optional in return type, fix annotated_partitions logic +@register_annotator("cat") +def _annotate_cat( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) + cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) + annotated_partitions = [] + for cat_partition in cat_partitions: + cat_node = cat_partition.output_nodes[0] + if _is_annotated([cat_node]): + continue + + if cat_node.target != torch.ops.aten.cat.default: + # TODO: change this to AnnotationException + raise Exception( # noqa: TRY002 + f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}" + " please check if you are calling the correct capture API" + ) + + annotated_partitions.append(cat_partition.nodes) + + input_act_qspec = get_input_act_qspec(quantization_config) + inputs = cat_node.args[0] + + input_qspec_map = {} + input_act0 = inputs[0] # type: ignore[index] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qspec + + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: 
ignore[arg-type] + for input_act in inputs[1:]: # type: ignore[index, union-attr] + if input_act not in input_qspec_map: + input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index] + + output_act_qspec = shared_with_input0_qspec + + cat_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_share_obs_or_fq_op(op: Callable) -> bool: + return op in [ + torch.ops.aten.relu.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + # TODO: remove? + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.flatten.using_ints, + ] + + +def propagate_annotation(model: torch.fx.GraphModule) -> None: + for n in model.graph.nodes: + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): + continue + + prev_node = n.args[0] + if not isinstance(prev_node, Node): + continue + + quantization_annotation = prev_node.meta.get("quantization_annotation", None) + if not quantization_annotation: + continue + + output_qspec = quantization_annotation.output_qspec + if not output_qspec: + continue + + # make sure current node is not annotated + if ( + "quantization_annotation" in n.meta + and n.meta["quantization_annotation"]._annotated + ): + continue + + shared_qspec = SharedQuantizationSpec(prev_node) + # propagate the previous output_qspec to the current node + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + prev_node: shared_qspec, + }, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# TODO: make the list of ops customizable +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ]: + continue + args = list(n.args) + new_args = [] + for i in range(len(args)): + if isinstance(args[i], torch.fx.Node): + new_args.append(args[i]) + continue + prefix = "_tensor_constant_" + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + tensor_constant_name = get_new_attr_name(model) + float_tensor = torch.tensor(float(args[i])) + model.register_buffer(tensor_constant_name, float_tensor) + fake_mode = n.meta["val"].fake_mode + with model.graph.inserting_before(n): + get_attr_node = model.graph.create_node( + "get_attr", tensor_constant_name, (), {} + ) + get_attr_node.meta["val"] = fake_mode.from_tensor( + float_tensor, static_shapes=True + ) + new_args.append(get_attr_node) + n.args = tuple(new_args) + model.recompile() + return model diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f79bef0f21eb1fb17803a319142753fe5b98959e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -0,0 +1,120 @@ +# mypy: allow-untyped-defs +import functools +from typing import Any, Optional, TYPE_CHECKING + +import torch +from 
torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec +from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + _is_any_annotated, + FilterFn, + int8_in_int8_out_ops, + X86InductorQuantizer, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig +from torch.fx import Node + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "XPUInductorQuantizer", + "get_default_xpu_inductor_quantization_config", +] + + +@functools.lru_cache +def get_default_xpu_inductor_quantization_config(): + extra_args: dict[str, Any] = {"eps": 2**-12} + act_observer_or_fake_quant_ctr = HistogramObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + PerChannelMinMaxObserver + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + False, + ) + return quantization_config + + +class XPUInductorQuantizer(X86InductorQuantizer): + """ + XPUInductorQuantizer is a class designed to facilitate + quantization capability at Intel GPU backend. The class + highly reuses the existing implementation of + X86InductorQuantizer as both are intended to take advantage + of the optimized kernels in oneDNN library. + """ + + def __init__(self) -> None: + super().__init__() + + """ + Following annotate_xx overrides the impls in base class, as + no XPU implementation for these operators currently. We would + gradually enable the XPU implementation and remove following + overrides. We keep the annotate methods but make the function + body empty, aiming to let `_generate_qdq_quantized_model` + generate qdq around op and graph execute on fp32 dtype for + unspported operators. + """ + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + pass + + def _annotate_maxpool2d( + self, + node: Node, + quantization_config: Optional[QuantizationConfig], + ) -> None: + """ + Here we skip the annotate logic for maxpool at XPU backend + as the quantized::max_pool2d is only implemented for CPU. 
+ """ + return + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): + if node.target == torch.ops.aten.max_pool2d.default: + return + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/stubs.py b/phivenv/Lib/site-packages/torch/ao/quantization/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..8f50a98ac9c1356fa00e6ab0941417e79a3f2b31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/stubs.py @@ -0,0 +1,74 @@ +from typing import Any, Optional + +import torch +from torch import nn +from torch.ao.quantization import QConfig + + +__all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"] + + +class QuantStub(nn.Module): + r"""Quantize stub module, before calibration, this is same as an observer, + it will be swapped as `nnq.Quantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig: Optional[QConfig] = None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class DeQuantStub(nn.Module): + r"""Dequantize stub module, before calibration, this is same as identity, + this will be swapped as `nnq.DeQuantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig: Optional[Any] = None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class QuantWrapper(nn.Module): + r"""A wrapper class that wraps the input module, adds QuantStub and + DeQuantStub and surround the call to module with call to quant and dequant + modules. + + This is used by the `quantization` utility functions to add the quant and + dequant modules, before `convert` function `QuantStub` will just be observer, + it observes the input tensor, after `convert`, `QuantStub` + will be swapped to `nnq.Quantize` which does actual quantization. Similarly + for `DeQuantStub`. 
+ """ + + quant: QuantStub + dequant: DeQuantStub + module: nn.Module + + def __init__(self, module: nn.Module): + super().__init__() + qconfig = getattr(module, "qconfig", None) + self.add_module("quant", QuantStub(qconfig)) + self.add_module("dequant", DeQuantStub(qconfig)) + self.add_module("module", module) + self.train(module.training) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + X = self.quant(X) + X = self.module(X) + return self.dequant(X) diff --git a/phivenv/Lib/site-packages/torch/ao/quantization/utils.py b/phivenv/Lib/site-packages/torch/ao/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..02df826c4be6bdf3554a8dee7c0f96581075f85e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/ao/quantization/utils.py @@ -0,0 +1,838 @@ +# mypy: allow-untyped-defs +""" +Utils shared by different modes of quantization (eager/graph) +""" + +import functools +import warnings +from collections import OrderedDict +from inspect import getfullargspec, signature +from typing import Any, Callable, Optional, Union + +import torch +from torch.ao.quantization.quant_type import QuantType +from torch.fx import Node +from torch.nn.utils.parametrize import is_parametrized + + +NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] +NodePattern.__module__ = "torch.ao.quantization.utils" + +# This is the Quantizer class instance from torch/quantization/fx/quantize.py. +# Define separately to prevent circular imports. +# TODO(future PR): improve this. +# make this public once fixed (can't be public as is because setting the module directly +# doesn't work) +QuantizerCls = Any + +# Type for fusion patterns, it can be more complicated than the following actually, +# see pattern.md for docs +# TODO: not sure if typing supports recursive data types +Pattern = Union[ + Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any +] +Pattern.__module__ = "torch.ao.quantization.utils" + + +# TODO: maybe rename this to MatchInputNode +class MatchAllNode: + """A node pattern that matches all nodes, used in defining + fusion patterns in FX Graph Mode Quantization + """ + + +module_type_list = { + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.Identity, + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, +} +func_list = { + torch.nn.functional.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.silu, + torch.nn.functional.mish, + torch.nn.functional.dropout, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.transpose, + torch.repeat_interleave, + torch.sigmoid, + torch.squeeze, + torch.stack, + torch.sum, + torch.tanh, + torch.unsqueeze, + torch.cat, +} +method_list = { + torch.mean, + "relu", + "relu_", + "contiguous", + "detach", + "detach_", + "hardsigmoid", + "hardsigmoid_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + 
"resize_", + "shape", + "sigmoid", + "sigmoid_", + "size", + "squeeze", + "squeeze_", + "tanh", + "tanh_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", +} + + +# TODO: not used now, remove +def check_node(node, modules): + # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def get_combined_dict(default_dict, additional_dict): + """ + Combines two dictionaries. + + This function takes two dictionaries as input and returns a new dictionary + that contains all the key-value pairs from both input dictionaries. + If there are any duplicate keys in the `additional_dict`, the values + from the `additional_dict` will overwrite those in the `default_dict`. + Args: + default_dict (dict): The main dictionary that will be used as the base + additional_dict (dict): The dictionary used to update `default_dict` + + Returns: + dict: The resulting dictionary + Example: + >>> x = dict(a=1, b=1) + >>> y = dict(b=2, c=3) + >>> get_combined_dict(x, y) + {'a': 1, 'b': 2, 'c': 3} + """ + d = default_dict.copy() + d.update(additional_dict) + return d + + +def is_per_tensor(qscheme): + return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric + + +def is_per_channel(qscheme): + return qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + ] + + +def getattr_from_fqn(obj: Any, fqn: str) -> Any: + """ + Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. 
+ """ + return functools.reduce(getattr, fqn.split("."), obj) + + +def to_underlying_dtype(qdtype): + DTYPE_MAPPING = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.uint8: torch.uint8, + torch.int8: torch.int8, + torch.uint16: torch.uint16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.float8_e5m2: torch.float8_e5m2, + torch.float8_e4m3fn: torch.float8_e4m3fn, + } + assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype) + return DTYPE_MAPPING[qdtype] + + +def get_qparam_dict(observer_or_fake_quant): + from torch.ao.quantization.observer import PlaceholderObserver + + qscheme = getattr(observer_or_fake_quant, "qscheme", None) + dtype = observer_or_fake_quant.dtype + qparams = {"qscheme": qscheme, "dtype": dtype} + + if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver): + return {"qscheme": None, "dtype": dtype} + + if is_per_tensor(qscheme): + qscheme = torch.per_tensor_affine + elif is_per_channel(qscheme): + # change symmetric to affine since we do not have symmetric + # quantized Tensor + if qscheme == torch.per_channel_symmetric: + qscheme = torch.per_channel_affine + qparams["axis"] = observer_or_fake_quant.ch_axis + else: + raise RuntimeError(f"Unrecognized qscheme: {qscheme}") + # update qscheme, since we don't have symmetric quant qscheme + # in quantized Tensor + qparams["qscheme"] = qscheme + + scale, zero_point = observer_or_fake_quant.calculate_qparams() + qparams["scale"] = scale + qparams["zero_point"] = zero_point + + if hasattr(observer_or_fake_quant, "quant_min"): + qparams["quant_min"] = observer_or_fake_quant.quant_min + if hasattr(observer_or_fake_quant, "quant_max"): + qparams["quant_max"] = observer_or_fake_quant.quant_max + + return qparams + + +def get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig +): + """Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + class_mapping = custom_module_class_mapping.get(quant_type, {}) + assert type(custom_module) in class_mapping, ( + "did not find corresponding observed " + f"module class for {type(custom_module)} in mapping: {class_mapping}" + ) + return class_mapping[type(custom_module)] + + +def activation_dtype(qconfig): + assert qconfig is not None + activation = qconfig.activation() + return activation.dtype + + +def weight_dtype(qconfig): + assert qconfig is not None + weight = qconfig.weight() + return weight.dtype + + +def activation_is_statically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not activation_is_dynamically_quantized(qconfig)) + + +def activation_is_dynamically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + 
dynamically quantized or not, this includes dynamically quantizing to + quint8, qint8 and float16 + """ + _activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return activation_is_dynamic + + +def activation_is_int8_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int8 or not, this includes quantizing to quint8, qint8 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.uint8, + torch.int8, + ] + + +def activation_is_int32_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int32 or not + """ + return activation_dtype(qconfig) in [torch.qint32, torch.int32] + + +def weight_is_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be + quantized or not + """ + return weight_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.float16, + torch.quint4x2, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + + +def weight_is_statically_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be statically + quantized or not + """ + return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] + + +def op_is_int8_dynamically_quantized(qconfig) -> bool: + """Given a qconfig, returns True if this op is using int8 dynamic + quantization + """ + activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return ( + activation_dtype in [torch.quint8, torch.uint8] + and + # for now, the lines below assume fbgemm or qnnpack + weight_dtype in [torch.qint8, torch.int8] + and activation_is_dynamic + ) + + +def get_qconfig_dtypes(qconfig): + r"""returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_is_dynamic) + """ + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + act_is_dynamic = getattr(activation, "is_dynamic", False) + return (activation.dtype, weight.dtype, act_is_dynamic) + + +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + if weight.dtype in static_dtypes: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype in static_dtypes: + return QuantType.STATIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype == torch.float16: + return QuantType.STATIC + + raise Exception( # noqa: TRY002 + f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype})," + f"weight({weight.dtype})" + ) + + +def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: + """Checks if the given minimum and maximum values are valid, meaning that + they exist and the min value is less than the max value. + """ + if min_val.numel() == 0 or max_val.numel() == 0: + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + return False + + if min_val.dim() == 0 or max_val.dim() == 0: + if min_val == float("inf") and max_val == float("-inf"): + warnings.warn( + "must run observer before calling calculate_qparams. 
" + + "Returning default values." + ) + + return False + + assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" + else: + assert torch.all(min_val <= max_val), ( + f"min {min_val} should be less than max {max_val}" + ) + + return True + + +def calculate_qmin_qmax( + quant_min: int, + quant_max: int, + has_customized_qrange: bool, + dtype: torch.dtype, + reduce_range: bool, +) -> tuple[int, int]: + r"""Calculates actual qmin and qmax based on the quantization range, + observer datatype and if range is reduced. + """ + # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted. + if has_customized_qrange: + # This initialization here is to be resolve TorchScript compilation issues and allow + # using of refinement to decouple initial_qmin and initial_qmax from quantization range. + # The actual values of initial_qmin and initial_qmax will be reset below. + if dtype in [torch.qint32, torch.int32]: + initial_quant_min, initial_quant_max = 0, 2**32 - 1 + else: + initial_quant_min, initial_quant_max = 0, 255 + # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the + # attribute from Optional valid integers for use, based on TorchScript's requirements. + custom_quant_min, custom_quant_max = quant_min, quant_max + if custom_quant_min is not None and custom_quant_max is not None: + initial_quant_min, initial_quant_max = ( + custom_quant_min, + custom_quant_max, + ) + + qrange_len = initial_quant_max - initial_quant_min + 1 + if dtype in [torch.qint8, torch.int8]: + assert 0 < qrange_len <= 256, ( + "quantization range should be positive and not exceed the maximum bit range (=256)." + ) + elif dtype in [torch.qint32, torch.int32]: + assert 0 < qrange_len <= 2**32, ( + "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + ) + if reduce_range: + quant_min, quant_max = quant_min // 2, quant_max // 2 + else: + # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used. 
+ if dtype in [torch.qint8, torch.int8]: + if reduce_range: + quant_min, quant_max = -64, 63 + else: + quant_min, quant_max = -128, 127 + elif dtype in [torch.quint8, torch.uint8]: + if reduce_range: + quant_min, quant_max = 0, 127 + else: + quant_min, quant_max = 0, 255 + elif dtype in [torch.qint32, torch.int32]: + quant_min, quant_max = -1 * (2**31), (2**31) - 1 + elif dtype in [torch.uint16]: + quant_min, quant_max = 0, 2**16 - 1 + elif dtype in [torch.int16]: + quant_min, quant_max = -(2**15), 2**15 - 1 + else: + quant_min, quant_max = 0, 15 + return quant_min, quant_max + + +def _parent_name(target): + """ + Turn 'foo.bar' into ['foo', 'bar'] + """ + r = target.rsplit(".", 1) + if len(r) == 1: + return "", r[0] + else: + return r[0], r[1] + + +def has_no_children_ignoring_parametrizations(module): + """ + Checks if module._modules is empty or + if module is a parametrization, checks that module._modules only has + the 'parametrizations' module + """ + if len(module._modules) == 0: + return True + elif is_parametrized(module): + return len(module._modules) == 1 and "parametrizations" in module._modules + else: + return False + + +def _get_path_of_module( + root: torch.nn.Module, submodule: torch.nn.Module +) -> Optional[str]: + """Get the path (fully qualified name) of a submodule + + Example:: + + >> class M(torch.nn.Module): + def __init__(self) -> None: + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + return self.linear(x) + + >> m = M() + >> l = m.linear + >> _get_path_of_module(m, l) + "linear" + """ + for n, p in root.named_modules(): + if submodule is p: + return n + return None + + +def _get_signature_locals(f: Callable, loc: dict[str, Any]) -> dict[str, Any]: + """Get local keyword arguments + + Example:: + + >> def f(self, a, b=9): + pass + >> loc = {"a": 6, "c": 7} + >> _get_signature_locals(f, loc) + {"a": 6} + """ + return {k: v for k, v in loc.items() if k in signature(f).parameters} + + +def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]": + """Get all default keyword arguments from function signature + + Example:: + + >> def f(self, a, b=9): + pass + >> _get_default_kwargs(f) + {"b": 9} + """ + kwargs = {} + for name, param in signature(f).parameters.items(): + if param.default is not param.empty: + kwargs[name] = param.default + elif param.kind is param.VAR_POSITIONAL: + kwargs[name] = () + elif param.kind is param.VAR_KEYWORD: + kwargs[name] = {} + return OrderedDict(kwargs) + + +def _normalize_kwargs(func: Callable, loc: dict[str, Any]) -> "OrderedDict[str, Any]": + """Given a function and local function arguments, normalize the keyword + arguments by filling in default arguments from function signature + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> loc = {"key2": 6} + >> _normalize_kwargs(f, loc) + {"key1": 3, "key2": 6} + """ + default_kwargs = _get_default_kwargs(func) + local_kwargs = _get_signature_locals(func, loc) + normalized_kwargs = default_kwargs.copy() + for attr, val in local_kwargs.items(): + if attr in normalized_kwargs: + # override the default keyword arguments + normalized_kwargs[attr] = val + return normalized_kwargs + + +def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. 
+ + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + assert quant_min <= 0 <= quant_max, ( + "Used-specified quantization range must include 0." + ) + assert quant_min < quant_max, ( + "qmin must be strictly less than qmax for user-specified quantization range." + ) + + +# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme +# as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer +# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikey to change +# (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168) +def determine_qparams( + min_val: torch.Tensor, + max_val: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + eps: torch.Tensor, + has_customized_qrange: bool, + qscheme: torch.qscheme = torch.per_tensor_affine, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + eps = eps.to(device) + + if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, eps) + if dtype in [torch.uint8, torch.quint8]: + if has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. 
+ zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale.to(torch.double), zero_point.to(torch.int64) + + +def _get_num_pos_args(f: Callable) -> int: + """Get number of positional args for a function + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> _get_num_pos_args(f) + 3 + """ + return len(getfullargspec(f).args) + + +def get_fqn_to_example_inputs( + model: torch.nn.Module, example_inputs: tuple[Any, ...] +) -> dict[str, tuple[Any, ...]]: + """Given a model and its example inputs, return a dictionary from + fully qualified name of submodules to example_inputs for that submodule, + e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), + "sub.linear1": (tensor4,), ...} + + Used to make quantizing submodules easier now that FX Graph Mode Quantization requires + example inputs. + + Also works for keyword arguments with default values, we would flatten keyword + arguments as positional arguments and fill in the missing keyword args with default + values, e.g. if we have a forward function: + def forward(self, x, key1=3, key2=3): + ... + + and we call it with self.submodule(x, key2=6) + we'll get example_inputs: (x, 3, 6) + + user can also override `key1` with positional arguments as well: + for self.submodule(x, 5, key2=6) + we'll get: (x, 5, 6) + + variable positional arguments and variable positional keyword arguments in forward + function are not supported currently, so please make sure no submodules is using + them. 
+ """ + root = model + fqn_to_example_inputs = {} + + def _patched_module_call(self, *args, **kwargs): + submodule_example_inputs = list(args).copy() + normalized_kwargs = _normalize_kwargs(self.forward, kwargs) + # minus 1 to skipping counting `self` + num_args = _get_num_pos_args(self.forward) - 1 + num_to_pop = num_args - len(submodule_example_inputs) + while num_to_pop and normalized_kwargs: + normalized_kwargs.popitem(last=False) + num_to_pop -= 1 + submodule_example_inputs.extend(normalized_kwargs.values()) + submodule_example_inputs_tuple = tuple(submodule_example_inputs) + fqn = _get_path_of_module(root, self) + if fqn is not None: + fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple + return orig_module_call(self, *args, **kwargs) + + orig_module_call = torch.nn.Module.__call__ + torch.nn.Module.__call__ = _patched_module_call # type: ignore[method-assign] + try: + model(*example_inputs) + finally: + # restore the module call even if there is an exception + torch.nn.Module.__call__ = orig_module_call # type: ignore[method-assign] + return fqn_to_example_inputs + + +def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + """ + As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564 + """ + if {torch.device("cpu"), torch.device("meta")} == devices: + warnings.warn( + "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'." + ) + devices = {torch.device("cpu")} + "" + assert len(devices) <= 1, ( + "prepare only works with cpu or single-device CUDA modules, " + f"but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device + + +DEPRECATION_WARNING = ( + "torch.ao.quantization is deprecated and will be removed in 2.10. \n" + "For migrations of users: \n" + "1. Eager mode quantization (torch.ao.quantization.quantize, " + "torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode " + "quantize_ API instead \n" + "2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx," + "torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization " + "API instead (prepare_pt2e, convert_pt2e) \n" + "3. 
pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) \n" + "see https://github.com/pytorch/ao/issues/2259 for more details" +) + + +__all__ = [ + "NodePattern", + "Pattern", + "MatchAllNode", + "check_node", + "get_combined_dict", + "is_per_tensor", + "is_per_channel", + "getattr_from_fqn", + "get_qparam_dict", + "get_swapped_custom_module_class", + "activation_dtype", + "weight_dtype", + "activation_is_statically_quantized", + "activation_is_dynamically_quantized", + "activation_is_int8_quantized", + "activation_is_int32_quantized", + "weight_is_quantized", + "weight_is_statically_quantized", + "op_is_int8_dynamically_quantized", + "get_qconfig_dtypes", + "get_quant_type", + "check_min_max_valid", + "calculate_qmin_qmax", + "has_no_children_ignoring_parametrizations", + "get_fqn_to_example_inputs", + "to_underlying_dtype", + "determine_qparams", + "validate_qmin_qmax", + "DEPRECATION_WARNING", +] diff --git a/phivenv/Lib/site-packages/torch/autograd/__init__.py b/phivenv/Lib/site-packages/torch/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0feb3d54fc488e6faca30bf48df8f71a8278e7f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/__init__.py @@ -0,0 +1,606 @@ +# mypy: allow-untyped-defs +""" +``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. + +It requires minimal changes to the existing code - you only need to declare :class:`Tensor` s +for which gradients should be computed with the ``requires_grad=True`` keyword. +As of now, we only support autograd for floating point :class:`Tensor` types ( +half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble). +""" + +import warnings +from collections.abc import Sequence +from typing import cast, Optional, Union + +import torch +from torch import _vmap_internals +from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like +from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge + +from . 
import forward_ad, functional, graph +from .anomaly_mode import detect_anomaly, set_detect_anomaly +from .function import Function, NestedIOFunction +from .grad_mode import ( + _force_original_view_tracking, + _unsafe_preserve_version_counter, + enable_grad, + inference_mode, + no_grad, + set_grad_enabled, + set_multithreading_enabled, +) +from .gradcheck import gradcheck, gradgradcheck +from .graph import _engine_run_backward +from .variable import Variable + + +__all__ = [ + "Variable", + "Function", + "backward", + "grad_mode", + "NestedIOFunction", + "detect_anomaly", + "enable_grad", + "grad", + "gradcheck", + "gradgradcheck", + "inference_mode", + "no_grad", + "set_detect_anomaly", + "set_grad_enabled", + "set_multithreading_enabled", + "variable", +] + +_OptionalTensor = Optional[torch.Tensor] +_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor] + + +def _calculate_shape( + output: Union[torch.Tensor, graph.GradientEdge], + grad: torch.Tensor, + is_grads_batched: bool, +) -> tuple[_ShapeorNestedShape, _ShapeorNestedShape]: + # is_same_size ensures that both tensors are either nested or non nested + # circular import + from torch.nested._internal.nested_tensor import NestedTensor + + if isinstance(output, graph.GradientEdge): + # We have already checked that we are not a C++ NestedTensor + if is_grads_batched: + raise RuntimeError("Batched grads are not supported with GradientEdge") + out_metadata = output.node._input_metadata[output.output_nr] + return torch.Size(out_metadata.shape), grad.shape + + if output.is_nested and not isinstance(output, NestedTensor): + if is_grads_batched: + raise RuntimeError("Batched grads are not supported with Nested Tensor.") + out_shape = output._nested_tensor_size() + grad_shape = grad._nested_tensor_size() + + return out_shape, grad_shape + + reg_out_shape = output.shape + reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:] + return reg_out_shape, reg_grad_shape + + +def _make_grads( + outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], + grads: Sequence[_OptionalTensor], + is_grads_batched: bool, +) -> tuple[_OptionalTensor, ...]: + new_grads: list[_OptionalTensor] = [] + for out, grad in zip(outputs, grads): + out = cast(Union[torch.Tensor, graph.GradientEdge], out) + out_size = None + out_device = None + + if isinstance(out, graph.GradientEdge): + out_metadata = out.node._input_metadata[out.output_nr] + out_size = torch.Size(out_metadata.shape) + out_dtype = out_metadata.dtype + out_device = out_metadata.device + out_is_nested = out_metadata.is_nested_tensor + if out_metadata.is_cpp_nested_tensor: + raise RuntimeError( + "C++ NestedTensor are not supported with GradientEdge" + ) + out_is_cpp_nested = False + else: + # circular import + from torch.nested._internal.nested_tensor import NestedTensor + + assert isinstance(out, torch.Tensor) + out_dtype = out.dtype + out_is_nested = out.is_nested + out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor) + if not out_is_cpp_nested: + out_size = out.shape + + if isinstance(grad, torch.Tensor): + from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq + + first_grad = grad if not is_grads_batched else grad[0] + + # TODO: We can remove this conditional once we uniformly use + # singleton int to represent jagged dimension, so that size() call + # on nested tensor works. 
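            # Illustrative expectation (a sketch, not an extra check): for an
            # output of shape (3, 4) with is_grads_batched=True, only
            # grad[0].shape (i.e. grad.shape[1:]) has to match (3, 4); the
            # leading dimension is the batch of "vectors".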
+ if out_is_cpp_nested: + assert isinstance(out, torch.Tensor) + shape_matches = torch.is_same_size(out, first_grad) + else: + # We need to do a regular size check, without going through + # the operator, to be able to handle unbacked symints + # (expect_true ensures we can deal with unbacked) + assert out_size is not None + shape_matches = expect_true(sym_eq(out_size, first_grad.size())) + + if not shape_matches: + out = cast(Union[torch.Tensor, graph.GradientEdge], out) # type: ignore[redundant-cast] + out_shape, grad_shape = _calculate_shape( + out, first_grad, is_grads_batched + ) + if is_grads_batched: + raise RuntimeError( + "If `is_grads_batched=True`, we interpret the first " + "dimension of each grad_output as the batch dimension. " + "The sizes of the remaining dimensions are expected to match " + "the shape of corresponding output, but a mismatch " + "was detected: grad_output[" + + str(grads.index(grad)) + + "] has a shape of " + + str(grad_shape) + + " and output[" + + str(outputs.index(out)) + + "] has a shape of " + + str(out_shape) + + ". " + "If you only want some tensors in `grad_output` to be considered " + "batched, consider using vmap." + ) + else: + raise RuntimeError( + "Mismatch in shape: grad_output[" + + str(grads.index(grad)) + + "] has a shape of " + + str(grad_shape) + + " and output[" + + str(outputs.index(out)) + + "] has a shape of " + + str(out_shape) + + "." + ) + if out_dtype.is_complex != grad.dtype.is_complex: + raise RuntimeError( + "For complex Tensors, both grad_output and output" + " are required to have the same dtype." + " Mismatch in dtype: grad_output[" + + str(grads.index(grad)) + + "] has a dtype of " + + str(grad.dtype) + + " and output[" + + str(outputs.index(out)) + + "] has a dtype of " + + str(out_dtype) + + "." + ) + new_grads.append(grad) + elif grad is None: + if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined] + if isinstance(out, graph.GradientEdge): + assert out_size is not None + out_numel_is_1 = all(o == 1 for o in out_size) + else: + assert isinstance(out, torch.Tensor) + out_numel_is_1 = out.numel() == 1 + if not out_numel_is_1: + raise RuntimeError( + "grad can be implicitly created only for scalar outputs" + ) + if not out_dtype.is_floating_point: + msg = ( + "grad can be implicitly created only for real scalar outputs" + f" but got {out_dtype}" + ) + raise RuntimeError(msg) + if isinstance(out, graph.GradientEdge): + assert out_size is not None + assert out_device is not None + new_grads.append( + torch.ones( + out_size, + dtype=out_dtype, + device=out_device, + ) + ) + else: + assert isinstance(out, torch.Tensor) + new_grads.append( + torch.ones_like(out, memory_format=torch.preserve_format) + ) + else: + new_grads.append(None) + else: + raise TypeError( + "gradients can be either Tensors or None, but got " + + type(grad).__name__ + ) + return tuple(new_grads) + + +def _tensor_or_tensors_to_tuple( + tensors: Optional[_TensorOrTensors], length: int +) -> tuple[_OptionalTensor, ...]: + if tensors is None: + return (None,) * length + if isinstance(tensors, torch.Tensor): + return (tensors,) + return tuple(tensors) + + +def backward( + tensors: _TensorOrTensorsOrGradEdge, + grad_tensors: Optional[_TensorOrTensors] = None, + retain_graph: Optional[bool] = None, + create_graph: bool = False, + grad_variables: Optional[_TensorOrTensors] = None, + inputs: Optional[_TensorOrTensorsOrGradEdge] = None, +) -> None: + r"""Compute the sum of gradients of given tensors with respect to graph leaves. 
+ + The graph is differentiated using the chain rule. If any of ``tensors`` + are non-scalar (i.e. their data has more than one element) and require + gradient, then the Jacobian-vector product would be computed, in this + case the function additionally requires specifying ``grad_tensors``. + It should be a sequence of matching length, that contains the "vector" + in the Jacobian-vector product, usually the gradient of the differentiated + function w.r.t. corresponding tensors (``None`` is an acceptable value for + all tensors that don't need gradient tensors). + + This function accumulates gradients in the leaves - you might need to zero + ``.grad`` attributes or set them to ``None`` before calling it. + See :ref:`Default gradient layouts` + for details on the memory layout of accumulated gradients. + + .. note:: + Using this method with ``create_graph=True`` will create a reference cycle + between the parameter and its gradient which can cause a memory leak. + We recommend using ``autograd.grad`` when creating the graph to avoid this. + If you have to use this function, make sure to reset the ``.grad`` fields of your + parameters to ``None`` after use to break the cycle and avoid the leak. + + .. note:: + + If you run any forward ops, create ``grad_tensors``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + .. note:: + + When ``inputs`` are provided and a given input is not a leaf, + the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + It is an implementation detail on which the user should not rely. + See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + + Args: + tensors (Sequence[Tensor] or Tensor or Sequence[GradientEdge] or GradientEdge): Tensors of which + the derivative will be computed. + grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in + the Jacobian-vector product, usually gradients w.r.t. each element of + corresponding tensors. None values can be specified for scalar Tensors or + ones that don't require grad. If a None value would be acceptable for all + grad_tensors, then this argument is optional. + retain_graph (bool, optional): If ``False``, the graph used to compute the grad + will be freed. Note that in nearly all cases setting this option to ``True`` + is not needed and often can be worked around in a much more efficient + way. Defaults to the value of ``create_graph``. + create_graph (bool, optional): If ``True``, graph of the derivative will + be constructed, allowing to compute higher order derivative products. + Defaults to ``False``. + inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient + be will accumulated into ``.grad``. All other Tensors will be ignored. If + not provided, the gradient is accumulated into all the leaf Tensors that + were used to compute the :attr:`tensors`. + """ + if torch._C._are_functorch_transforms_active(): + raise RuntimeError( + "backward() called inside a functorch transform. This is not " + "supported, please use functorch.grad or functorch.vjp instead " + "or call backward() outside of functorch transforms." + ) + + if grad_variables is not None: + warnings.warn( + "`grad_variables` is deprecated. 
Use `grad_tensors` instead.", + FutureWarning, + stacklevel=2, + ) + if grad_tensors is None: + grad_tensors = grad_variables + else: + raise RuntimeError( + "`grad_tensors` and `grad_variables` (deprecated) " + "arguments both passed to `backward()`. Please only " + "use `grad_tensors`." + ) + + inputs_tuple: tuple[Union[torch.Tensor, graph.GradientEdge], ...] + if inputs is None: + inputs_tuple = () + elif isinstance(inputs, (torch.Tensor, graph.GradientEdge)): + inputs_tuple = (inputs,) + else: + inputs_tuple = tuple(inputs) + if len(inputs_tuple) == 0: + raise RuntimeError("`inputs` argument to `backward()` cannot be empty.") + + if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge): + tensors = cast( + Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,) + ) + else: + tensors = tuple(tensors) + + grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors)) + grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False) + if retain_graph is None: + retain_graph = create_graph + + # The reason we repeat the same comment below is that + # some Python versions print out the first line of a multi-line function + # calls in the traceback and some print out the last line + _engine_run_backward( + tensors, + grad_tensors_, + retain_graph, + create_graph, + inputs_tuple, + allow_unreachable=True, + accumulate_grad=True, + ) + + +def grad( + outputs: _TensorOrTensorsOrGradEdge, + inputs: _TensorOrTensorsOrGradEdge, + grad_outputs: Optional[_TensorOrTensors] = None, + retain_graph: Optional[bool] = None, + create_graph: bool = False, + only_inputs: bool = True, + allow_unused: Optional[bool] = None, + is_grads_batched: bool = False, + materialize_grads: bool = False, +) -> tuple[torch.Tensor, ...]: + r"""Compute and return the sum of gradients of outputs with respect to the inputs. + + ``grad_outputs`` should be a sequence of length matching ``output`` + containing the "vector" in vector-Jacobian product, usually the pre-computed + gradients w.r.t. each of the outputs. If an output doesn't require_grad, + then the gradient can be ``None``). + + .. note:: + + If you run any forward ops, create ``grad_outputs``, and/or call ``grad`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + .. note:: + + ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``). + To accumulate gradient for other parts of the graph, please use + ``torch.autograd.backward``. + + Args: + outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function. + inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be + returned (and not accumulated into ``.grad``). + grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product. + Usually gradients w.r.t. each output. None values can be specified for scalar + Tensors or ones that don't require grad. If a None value would be acceptable + for all grad_tensors, then this argument is optional. Default: None. + retain_graph (bool, optional): If ``False``, the graph used to compute the grad + will be freed. Note that in nearly all cases setting this option to ``True`` + is not needed and often can be worked around in a much more efficient + way. Defaults to the value of ``create_graph``. + create_graph (bool, optional): If ``True``, graph of the derivative will + be constructed, allowing to compute higher order derivative products. + Default: ``False``. 
+ allow_unused (Optional[bool], optional): If ``False``, specifying inputs + that were not used when computing outputs (and therefore their grad is + always zero) is an error. Defaults to the value of ``materialize_grads``. + is_grads_batched (bool, optional): If ``True``, the first dimension of each + tensor in ``grad_outputs`` will be interpreted as the batch dimension. + Instead of computing a single vector-Jacobian product, we compute a + batch of vector-Jacobian products for each "vector" in the batch. + We use the vmap prototype feature as the backend to vectorize calls + to the autograd engine so that this computation can be performed in a + single call. This should lead to performance improvements when compared + to manually looping and performing backward multiple times. Note that + due to this feature being experimental, there may be performance + cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)`` + to show any performance warnings and file an issue on github if warnings exist + for your use case. Defaults to ``False``. + materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs + to zero instead of None. This is useful when computing higher-order derivatives. + If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error + will be raised. Defaults to ``False``. + + """ + if materialize_grads and allow_unused is False: + raise ValueError( + "Expected allow_unused to be True or not passed when materialize_grads=True, " + "but got: allow_unused=False." + ) + if allow_unused is None: + allow_unused = materialize_grads + if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge): + outputs = cast( + Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,) + ) + else: + outputs = tuple(outputs) + if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge): + inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,)) + else: + inputs = tuple(inputs) + t_outputs = tuple(i for i in outputs if is_tensor_like(i)) + t_inputs = tuple(i for i in inputs if is_tensor_like(i)) + overridable_args = t_outputs + t_inputs + if has_torch_function(overridable_args): + return handle_torch_function( + grad, + overridable_args, + outputs, + inputs, + grad_outputs=grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + only_inputs=only_inputs, + allow_unused=allow_unused, + is_grads_batched=is_grads_batched, + materialize_grads=materialize_grads, + ) + + if not only_inputs: + warnings.warn( + "only_inputs argument is deprecated and is ignored now " + "(defaults to True). 
To accumulate gradient for other " + "parts of the graph, please use torch.autograd.backward.", + FutureWarning, + stacklevel=2, + ) + + grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs)) + grad_outputs_ = _make_grads( + outputs, grad_outputs_, is_grads_batched=is_grads_batched + ) + + if retain_graph is None: + retain_graph = create_graph + + # The reason we repeat the same comment several times below is because + # some Python versions print out the first line of multi-line function + # calls in the traceback and some print out the last line + if is_grads_batched: + + def vjp(gO): + return _engine_run_backward( + outputs, + gO, + retain_graph, + create_graph, + inputs, + allow_unused, + accumulate_grad=False, + ) + + result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)( + grad_outputs_ + ) + else: + result = _engine_run_backward( + outputs, + grad_outputs_, + retain_graph, + create_graph, + inputs, + allow_unused, + accumulate_grad=False, + ) + if materialize_grads: + if any( + result[i] is None and not is_tensor_like(inputs[i]) + for i in range(len(inputs)) + ): + raise RuntimeError( + "materialize_grads cannot be used when the given input is a GradientEdge" + ) + result = tuple( + output + if output is not None + else torch.zeros_like(input, requires_grad=True) + for (output, input) in zip(result, inputs) + ) + return result + + +# This function applies in case of gradient checkpointing for memory +# optimization. Currently, gradient checkpointing is supported only if the +# execution engine is invoked through torch.autograd.backward() and its +# inputs argument is not passed. It is not supported for torch.autograd.grad(). +# This is because if inputs are specified, the gradient won't be calculated for +# anything else e.g. model parameters like weights, bias etc. +# +# This function returns whether the checkpointing is valid i.e. torch.autograd.backward +# or not i.e. torch.autograd.grad. The implementation works by maintaining a thread +# local variable in torch/csrc/autograd/engine.cpp which looks at the NodeTask +# in the stack and before a NodeTask is executed in evaluate_function, it +# checks for whether reentrant backwards is imperative or not. +# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context +def _is_checkpoint_valid(): + return Variable._execution_engine.is_checkpoint_valid() + + +def variable(*args, **kwargs): # noqa: D103 + raise RuntimeError( + "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead" + ) + + +# Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing +# f"{fn.__module__}.{fn.__name__}(...). This yields torch.autograd.variable.Variable(...) in the +# output of an FX graph. Unfortunately the module name torch.autograd.variable is shadowed by the +# deprecated function - variable(...). 
+variable.Variable = Variable # type: ignore[attr-defined] + +if not torch._C._autograd_init(): + raise RuntimeError("autograd initialization failed") + +# Import all native method/classes +from torch._C._autograd import ( + _add_metadata_json, + _disable_profiler, + _disable_profiler_legacy, + _enable_profiler, + _enable_profiler_legacy, + _enable_record_function, + _get_sequence_nr, + _kineto_step, + _KinetoEvent, + _pop_saved_tensors_default_hooks, + _prepare_profiler, + _profiler_enabled, + _ProfilerResult, + _push_saved_tensors_default_hooks, + _record_function_with_args_enter, + _record_function_with_args_exit, + _set_empty_test_observer, + _supported_activities, + _toggle_collection_dynamic, + DeviceType, + kineto_available, + ProfilerEvent, + SavedTensor, +) +from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState + +from . import profiler + + +def _register_py_tensor_class_for_device(device, cls): + if not isinstance(cls, type): + raise RuntimeError("cls isn't a typeinfo object") + torch._C._register_py_class_for_device(device, cls) + + +is_multithreading_enabled = torch._C._is_multithreading_enabled +torch._C._add_docstr( + is_multithreading_enabled, "Returns True if multithreading is currently enabled." +) + +is_view_replay_enabled = torch._C._is_view_replay_enabled +torch._C._add_docstr( + is_view_replay_enabled, "Returns True if view-replay is currently enabled." +) diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5de670d43ed59a7c433a13b6e8e82747c4718dba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c468ba49eac6e375de9bd7a4c394f22f36bc7d82 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/forward_ad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/forward_ad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..495c93e6fc50bfd71a41129f9dcba8cc1766764a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/forward_ad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/function.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d46b743a66fff4233bb830310ecb785da1e1a1b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/function.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd86047f0180816f75c46d019f563daa02ffd761 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/grad_mode.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/autograd/__pycache__/grad_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eb8f352e280df6712645fdae96059ce1d40208b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/grad_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/gradcheck.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/gradcheck.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4cc2679ca4cbcb694c57bd6491236d9a4e2930 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/gradcheck.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2874ec0efac6b8fb3ffec0e82b7b00bab2bd2823 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f76029b3c157419e64fc7e72eb178828c8751c34 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ecb2e2ddfcd17390851d3c7c2d5a206fb8d2093 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler_util.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..702abda6be9bb96dbd0d31448f1dbeef6229f694 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/profiler_util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/__pycache__/variable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/__pycache__/variable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..760ab05a210974cc198ff1c81fd1a6d15d7e3e1e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/__pycache__/variable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/_functions/__init__.py b/phivenv/Lib/site-packages/torch/autograd/_functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc92d7c7fe74ad79a100e0233150f90becde55be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/_functions/__init__.py @@ -0,0 +1 @@ +from .tensor import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..292e9dff1b91d0846363b9d5e425981ad6b47637 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb656dc521b9a59a52638727561cd16199b56892 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2564b777af2ae024d68be7281136b7a1689647fd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/autograd/_functions/tensor.py b/phivenv/Lib/site-packages/torch/autograd/_functions/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..f973096758fbf89a3192e7221c35ad925f4b4513 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/_functions/tensor.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing_extensions import deprecated + +import torch +import torch._utils +from torch.autograd.function import Function + + +class Type(Function): + @staticmethod + @deprecated( + "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, " + "please use `torch.tensor.to(dtype=dtype)` instead.", + category=FutureWarning, + ) + def forward(ctx, i, dest_type): + ctx.input_type = type(i) + ctx.input_device = -1 if not i.is_cuda else i.get_device() + return i.type(dest_type) + + @staticmethod + def backward(ctx, grad_output): + if ctx.input_device == -1: + return grad_output.type(ctx.input_type), None + else: + with torch.accelerator.device_index(ctx.input_device): + return grad_output.type(ctx.input_type), None + + +# TODO: deprecate this +class Resize(Function): + @staticmethod + def forward(ctx, tensor, sizes): + ctx.sizes = sizes + ctx.numel = reduce(operator.mul, sizes, 1) + if tensor.numel() != ctx.numel: + raise RuntimeError( + ( + "requested resize to {} ({} elements in total), " + "but the given tensor has a size of {} ({} elements). " + "autograd's resize can only change the shape of a given " + "tensor, while preserving the number of elements. 
" + ).format( + "x".join(map(str, sizes)), + ctx.numel, + "x".join(map(str, tensor.size())), + tensor.numel(), + ) + ) + ctx.input_sizes = tensor.size() + if tensor.is_quantized: + tensor.copy_(tensor) + return tensor.contiguous().view(*sizes) + if tensor.is_contiguous(): + result = tensor.new(tensor).contiguous().view(*sizes) + return result + else: + return tensor.contiguous().view(*sizes) + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.numel() == ctx.numel + return grad_output.contiguous().view(ctx.input_sizes), None diff --git a/phivenv/Lib/site-packages/torch/autograd/_functions/utils.py b/phivenv/Lib/site-packages/torch/autograd/_functions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a47107fcb364ba4afe10d1912f478bfd28a4c702 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/_functions/utils.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce + + +def maybe_view(tensor, size, check_same_size=True): + if check_same_size and tensor.size() == size: + return tensor + return tensor.contiguous().view(size) + + +def maybe_unexpand(tensor, old_size, check_same_size=True): + if check_same_size and tensor.size() == old_size: + return tensor + num_unsqueezed = tensor.dim() - len(old_size) + expanded_dims = [ + dim + for dim, (expanded, original) in enumerate( + zip(tensor.size()[num_unsqueezed:], old_size) + ) + if expanded != original + ] + + for _ in range(num_unsqueezed): + tensor = tensor.sum(0, keepdim=False) + for dim in expanded_dims: + tensor = tensor.sum(dim, keepdim=True) + return tensor + + +# Check whether the op enable broadcasting, and whether it is supported by ONNX. +# If dims1 and dims2 are different, then broadcast is True. +# We always assume the combination of dims1 and dims2 is broadcastable. +# The following types of broadcasting are supported in ONNX: +# 1) Only one element in dims2, such as dims2 = [1, 1] +# 2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4] +# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm +def check_onnx_broadcast(dims1, dims2): + broadcast = False + supported = True + len1 = len(dims1) + len2 = len(dims2) + + numel2 = reduce(operator.mul, dims2) + if len1 < len2: + broadcast = True + if numel2 != 1: + supported = False + elif len1 > len2: + broadcast = True + if numel2 != 1 and dims1[len1 - len2 :] != dims2: + supported = False + else: + if dims1 != dims2: + broadcast = True + if numel2 != 1: + supported = False + + if not supported: + raise ValueError( + f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}" + ) + return broadcast diff --git a/phivenv/Lib/site-packages/torch/autograd/anomaly_mode.py b/phivenv/Lib/site-packages/torch/autograd/anomaly_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..0e2ef192266d0bc910d9758c8d00b61a568af0b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/anomaly_mode.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +r"""Autograd anomaly mode.""" +import warnings + +import torch + + +__all__ = ["detect_anomaly", "set_detect_anomaly"] + + +class detect_anomaly: + r"""Context-manager that enable anomaly detection for the autograd engine. + + This does two things: + + - Running the forward pass with detection enabled will allow the backward + pass to print the traceback of the forward operation that created the failing + backward function. 
+ - If ``check_nan`` is ``True``, any backward computation that generate "nan" + value will raise an error. Default ``True``. + + .. warning:: + This mode should be enabled only for debugging as the different tests + will slow down your program execution. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY) + >>> import torch + >>> from torch import autograd + >>> class MyFunc(autograd.Function): + ... @staticmethod + ... def forward(ctx, inp): + ... return inp.clone() + ... @staticmethod + ... def backward(ctx, gO): + ... # Error during the backward pass + ... raise RuntimeError("Some error in backward") + ... return gO.clone() + >>> def run_fn(a): + ... out = MyFunc.apply(a) + ... return out.sum() + >>> inp = torch.rand(10, 10, requires_grad=True) + >>> out = run_fn(inp) + >>> out.backward() + Traceback (most recent call last): + File "", line 1, in + File "/your/pytorch/install/torch/_tensor.py", line 93, in backward + torch.autograd.backward(self, gradient, retain_graph, create_graph) + File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward + allow_unreachable=True) # allow_unreachable flag + File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply + return self._forward_cls.backward(self, *args) + File "", line 8, in backward + RuntimeError: Some error in backward + >>> with autograd.detect_anomaly(): + ... inp = torch.rand(10, 10, requires_grad=True) + ... out = run_fn(inp) + ... out.backward() + Traceback of forward call that caused the error: + File "tmp.py", line 53, in + out = run_fn(inp) + File "tmp.py", line 44, in run_fn + out = MyFunc.apply(a) + Traceback (most recent call last): + File "", line 4, in + File "/your/pytorch/install/torch/_tensor.py", line 93, in backward + torch.autograd.backward(self, gradient, retain_graph, create_graph) + File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward + allow_unreachable=True) # allow_unreachable flag + File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply + return self._forward_cls.backward(self, *args) + File "", line 8, in backward + RuntimeError: Some error in backward + + """ + + def __init__(self, check_nan=True) -> None: # noqa: D107 + self.prev = torch.is_anomaly_enabled() + self.check_nan = check_nan + self.prev_check_nan = torch.is_anomaly_check_nan_enabled() + warnings.warn( + "Anomaly Detection has been enabled. " + "This mode will increase the runtime " + "and should only be enabled for debugging.", + stacklevel=2, + ) + + def __enter__(self) -> None: # noqa: D105 + torch.set_anomaly_enabled(True, self.check_nan) + + def __exit__(self, *args: object) -> None: # noqa: D105 + torch.set_anomaly_enabled(self.prev, self.prev_check_nan) + + +class set_detect_anomaly: + r"""Context-manager that sets the anomaly detection for the autograd engine on or off. + + ``set_detect_anomaly`` will enable or disable the autograd anomaly detection + based on its argument :attr:`mode`. + It can be used as a context-manager or as a function. + + See ``detect_anomaly`` above for details of the anomaly detection behaviour. + + Args: + mode (bool): Flag whether to enable anomaly detection (``True``), + or disable (``False``). 
+ check_nan (bool): Flag whether to raise an error when the backward + generate "nan" + + """ + + def __init__(self, mode: bool, check_nan: bool = True) -> None: # noqa: D107 + self.prev = torch.is_anomaly_enabled() + self.prev_check_nan = torch.is_anomaly_check_nan_enabled() + torch.set_anomaly_enabled(mode, check_nan) + + def __enter__(self) -> None: # noqa: D105 + pass + + def __exit__(self, *args: object) -> None: # noqa: D105 + torch.set_anomaly_enabled(self.prev, self.prev_check_nan) diff --git a/phivenv/Lib/site-packages/torch/autograd/forward_ad.py b/phivenv/Lib/site-packages/torch/autograd/forward_ad.py new file mode 100644 index 0000000000000000000000000000000000000000..36ef0da38864bb15c561cc67aebd77018515a53a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/forward_ad.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +import os +from typing import Any, NamedTuple, Optional + +import torch + +from .grad_mode import _DecoratorContextManager + + +__all__ = [ + "UnpackedDualTensor", + "enter_dual_level", + "exit_dual_level", + "make_dual", + "unpack_dual", + "dual_level", +] + +# Global variable used to make the python API simpler to use +_current_level = -1 + + +def enter_dual_level(): + r"""Enter a new forward grad level. + + This level can be used to make and unpack dual Tensors to compute + forward gradients. + + This function also updates the current level that is used by default + by the other functions in this API. + """ + global _current_level + new_level = torch._C._enter_dual_level() + if new_level != _current_level + 1: + raise RuntimeError( + "Entering a new forward AD level but the current level " + "is not valid. Make sure you did not modified it directly." + ) + _current_level = new_level + return new_level + + +def exit_dual_level(*, level=None): + r"""Exit a forward grad level. + + This function deletes all the gradients associated with this + level. Only deleting the latest entered level is allowed. + + This function also updates the current level that is used by default + by the other functions in this API. + """ + global _current_level + if level is None: + level = _current_level + if level != _current_level: + raise RuntimeError( + "Trying to exit a forward AD level that was not the last one " + "that was created. This is not supported." + ) + torch._C._exit_dual_level(level=level) + _current_level = level - 1 + + +def _maybe_load_decompositions(): + if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__: + from torch._decomp import decompositions_for_jvp # noqa: F401 + + +def make_dual(tensor, tangent, *, level=None): + r"""Associate a tensor value with its tangent to create a "dual tensor" for forward AD gradient computation. + + The result is a new tensor aliased to :attr:`tensor` with :attr:`tangent` embedded + as an attribute as-is if it has the same storage layout or copied otherwise. + The tangent attribute can be recovered with :func:`unpack_dual`. + + This function is backward differentiable. + + Given a function `f` whose jacobian is `J`, it allows one to compute the Jacobian-vector product (`jvp`) + between `J` and a given vector `v` as follows. + + Example:: + + >>> # xdoctest: +SKIP("Undefined variables") + >>> with dual_level(): + ... inp = make_dual(x, v) + ... out = f(inp) + ... y, jvp = unpack_dual(out) + + Please see the `forward-mode AD tutorial `__ + for detailed steps on how to use this API. 
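+
+    A minimal sketch, assuming ``x`` and ``v`` are floating point tensors of the
+    same shape (names chosen here for illustration)::
+
+        >>> x = torch.randn(3)
+        >>> v = torch.ones(3)
+        >>> with dual_level():
+        ...     out = torch.sin(make_dual(x, v))
+        ...     primal, tangent = unpack_dual(out)
+        >>> torch.allclose(tangent, torch.cos(x) * v)
+        True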
+ + """ + # See NOTE: [forward-mode AD decompositions mechanism] + # + # Import from torch._decomp import decompositions_for_jvp to register + # decompositions for jvp to the jit registry + # + # FIXME: We specify that __debug__ must be True because + # if python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the + # following error: + # + # Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of + # type Tuple[Tensor, Tensor]: + # File ".../torch/_decomp/__init__.py", line 1585 + # else: + # buffer = z + # return min - torch.log1p(z), buffer + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + _maybe_load_decompositions() + + if level is None: + level = _current_level + + if level < 0: + raise RuntimeError( + "Trying to create a dual Tensor for forward AD but no level " + "exists, make sure to enter_dual_level() first." + ) + if not (tensor.is_floating_point() or tensor.is_complex()): + raise ValueError( + f"Expected primal to be floating point or complex, but got: {tensor.dtype}" + ) + if not (tangent.is_floating_point() or tangent.is_complex()): + raise ValueError( + f"Expected tangent to be floating point or complex, but got: {tangent.dtype}" + ) + + return torch._VF._make_dual(tensor, tangent, level=level) + + +class UnpackedDualTensor(NamedTuple): + r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor. + + See :func:`unpack_dual` for more details. + """ + + primal: torch.Tensor + tangent: Optional[torch.Tensor] + + +def unpack_dual(tensor, *, level=None): + r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient. + + The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of + :attr:`tensor`'s primal and ``tangent`` is :attr:`tensor`'s tangent as-is. + Neither of these tensors can be dual tensor of level :attr:`level`. + + This function is backward differentiable. + + Example:: + + >>> # xdoctest: +SKIP("Undefined variables") + >>> with dual_level(): + ... inp = make_dual(x, x_t) + ... out = f(inp) + ... y, jvp = unpack_dual(out) + ... jvp = unpack_dual(out).tangent + + Please see the `forward-mode AD tutorial `__ + for detailed steps on how to use this API. + """ + if level is None: + level = _current_level + + if level < 0: + return UnpackedDualTensor(tensor, None) + + primal, dual = torch._VF._unpack_dual(tensor, level=level) + + return UnpackedDualTensor(primal, dual) + + +class dual_level(_DecoratorContextManager): + r"""Context-manager for forward AD, where all forward AD computation must occur within the ``dual_level`` context. + + .. Note:: + + The ``dual_level`` context appropriately enters and exit the dual level to + controls the current forward AD level, which is used by default by the other + functions in this API. + + We currently don't plan to support nested ``dual_level`` contexts, however, so + only a single forward AD level is supported. To compute higher-order + forward grads, one can use :func:`torch.func.jvp`. + + Example:: + + >>> # xdoctest: +SKIP("Undefined variables") + >>> x = torch.tensor([1]) + >>> x_t = torch.tensor([1]) + >>> with dual_level(): + ... inp = make_dual(x, x_t) + ... # Do computations with inp + ... out = your_fn(inp) + ... _, grad = unpack_dual(out) + >>> grad is None + False + >>> # After exiting the level, the grad is deleted + >>> _, grad_after = unpack_dual(out) + >>> grad is None + True + + Please see the `forward-mode AD tutorial `__ + for detailed steps on how to use this API. 
+ """ + + def __enter__(self): + return enter_dual_level() + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + exit_dual_level() + + +# Private helper functions +_is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled + + +# Private helper function to enable or disable fwd grad. +# If you're a user and want to use this, please file an issue to discuss the use case. +class _set_fwd_grad_enabled(_DecoratorContextManager): + def __init__(self, mode: bool) -> None: + self.prev = _is_fwd_grad_enabled() + torch._C._set_fwd_grad_enabled(mode) + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + torch._C._set_fwd_grad_enabled(self.prev) diff --git a/phivenv/Lib/site-packages/torch/autograd/function.py b/phivenv/Lib/site-packages/torch/autograd/function.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0d0df79377194f5a83e436cb80b61e1f7fc3cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/function.py @@ -0,0 +1,844 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import itertools +import warnings +from collections import OrderedDict +from typing import Any, Optional +from typing_extensions import deprecated + +import torch +import torch._C as _C +import torch._functorch as _functorch +import torch.utils.hooks as hooks +from torch._C import _functions +from torch._functorch.autograd_function import custom_function_call + + +__all__ = [ + "FunctionCtx", + "BackwardCFunction", + "FunctionMeta", + "Function", + "once_differentiable", + "InplaceFunction", + "NestedIOFunction", +] + +# Unique id provider for each class inheriting from Function +# This is incremented in FunctionMeta during class definition +AUTOGRAD_FUNCTION_COUNTER = itertools.count() + + +# Formerly known as: _ContextMethodMixin +class FunctionCtx: + def save_for_backward(self, *tensors: torch.Tensor): + r"""Save given tensors for a future call to :func:`~Function.backward`. + + ``save_for_backward`` should be called at most once, in either the + :func:`setup_context` or :func:`forward` methods, and only with tensors. + + All tensors intended to be used in the backward pass should be saved + with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent + incorrect gradients and memory leaks, and enable the application of saved + tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`. + See :ref:`extending-autograd` for more details. + + Note that if intermediary tensors, tensors that are neither inputs + nor outputs of :func:`forward`, are saved for backward, your custom Function + may not support double backward. + Custom Functions that do not support double backward should decorate their + :func:`backward` method with ``@once_differentiable`` so that performing + double backward raises an error. If you'd like to support double backward, + you can either recompute intermediaries based on the inputs during backward + or return the intermediaries as the outputs of the custom Function. See the + `double backward tutorial `_ + for more details. + + In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors` + attribute. Before returning them to the user, a check is made to ensure + they weren't used in any in-place operation that modified their content. + + Arguments can also be ``None``. This is a no-op. + + See :ref:`extending-autograd` for more details on how to use this method. 
+ + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> class Func(Function): + >>> @staticmethod + >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): + >>> w = x * z + >>> out = x * y + y * z + w * y + >>> ctx.save_for_backward(x, y, w, out) + >>> ctx.z = z # z is not a tensor + >>> return out + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, grad_out): + >>> x, y, w, out = ctx.saved_tensors + >>> z = ctx.z + >>> gx = grad_out * (y + y * z) + >>> gy = grad_out * (x + z + w) + >>> gz = None + >>> return gx, gy, gz + >>> + >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) + >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) + >>> c = 4 + >>> d = Func.apply(a, b, c) + + """ + self.to_save = tensors + + def save_for_forward(self, *tensors: torch.Tensor): + r"""Save given tensors for a future call to :func:`~Function.jvp`. + + ``save_for_forward`` should be called at most once, in either the + :func:`setup_context` or :func:`forward` methods, and all arguments + should be tensors. + + In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors` + attribute. + + Arguments can also be ``None``. This is a no-op. + + See :ref:`extending-autograd` for more details on how to use this method. + + Example:: + + >>> # xdoctest: +SKIP + >>> class Func(torch.autograd.Function): + >>> @staticmethod + >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): + >>> ctx.save_for_backward(x, y) + >>> ctx.save_for_forward(x, y) + >>> ctx.z = z + >>> return x * y * z + >>> + >>> @staticmethod + >>> def jvp(ctx, x_t, y_t, _): + >>> x, y = ctx.saved_tensors + >>> z = ctx.z + >>> return z * (y * x_t + x * y_t) + >>> + >>> @staticmethod + >>> def vjp(ctx, grad_out): + >>> x, y = ctx.saved_tensors + >>> z = ctx.z + >>> return z * grad_out * y, z * grad_out * x, None + >>> + >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) + >>> t = torch.tensor(1., dtype=torch.double) + >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) + >>> c = 4 + >>> + >>> with fwAD.dual_level(): + >>> a_dual = fwAD.make_dual(a, t) + >>> d = Func.apply(a_dual, b, c) + + """ + for tensor in tensors: + assert isinstance(tensor, torch.Tensor) or tensor is None, ( + "save_for_forward expects all arguments to be tensors; you should " + "save non-tensors as attributes on ctx." + ) + + self.saved_for_forward = tensors + + def mark_dirty(self, *args: torch.Tensor): + r"""Mark given tensors as modified in an in-place operation. + + This should be called at most once, in either the :func:`setup_context` + or :func:`forward` methods, and all arguments should be inputs. + + Every tensor that's been modified in-place in a call to :func:`forward` + should be given to this function, to ensure correctness of our checks. + It doesn't matter whether the function is called before or after + modification. + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> class Inplace(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> x_npy = x.numpy() # x_npy shares storage with x + >>> x_npy += 1 + >>> ctx.mark_dirty(x) + >>> return x + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, grad_output): + >>> return grad_output + >>> + >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone() + >>> b = a * a + >>> Inplace.apply(a) # This would lead to wrong gradients! 
+ >>> # but the engine would not know unless we mark_dirty + >>> # xdoctest: +SKIP + >>> b.backward() # RuntimeError: one of the variables needed for gradient + >>> # computation has been modified by an inplace operation + + """ + self.dirty_tensors = args + + @deprecated( + "`mark_shared_storage` is deprecated. " + "Tensors with shared storages are automatically tracked. " + "Note that calls to `set_()` are not tracked", + category=FutureWarning, + ) + def mark_shared_storage(self, *pairs): + pass + + def mark_non_differentiable(self, *args: torch.Tensor): + r"""Mark outputs as non-differentiable. + + This should be called at most once, in either the :func:`setup_context` + or :func:`forward` methods, and all arguments should be tensor outputs. + + This will mark outputs as not requiring gradients, increasing the + efficiency of backward computation. You still need to accept a gradient + for each output in :meth:`~Function.backward`, but it's always going to + be a zero tensor with the same shape as the shape of a corresponding + output. + + This is used e.g. for indices returned from a sort. See example:: + >>> class Func(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> sorted, idx = x.sort() + >>> ctx.mark_non_differentiable(idx) + >>> ctx.save_for_backward(x, idx) + >>> return sorted, idx + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, g1, g2): # still need to accept g2 + >>> x, idx = ctx.saved_tensors + >>> grad_input = torch.zeros_like(x) + >>> grad_input.index_add_(0, idx, g1) + >>> return grad_input + + """ + self.non_differentiable = args + + def set_materialize_grads(self, value: bool): + r"""Set whether to materialize grad tensors. Default is ``True``. + + This should be called only from either the :func:`setup_context` or + :func:`forward` methods. + + If ``True``, undefined grad tensors will be expanded to tensors full of zeros + prior to calling the :func:`backward` and :func:`jvp` methods. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> class SimpleFunc(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> return x.clone(), x.clone() + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, g1, g2): + >>> return g1 + g2 # No check for None necessary + >>> + >>> # We modify SimpleFunc to handle non-materialized grad outputs + >>> class Func(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> ctx.set_materialize_grads(False) + >>> ctx.save_for_backward(x) + >>> return x.clone(), x.clone() + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, g1, g2): + >>> x, = ctx.saved_tensors + >>> grad_input = torch.zeros_like(x) + >>> if g1 is not None: # We must check for None now + >>> grad_input += g1 + >>> if g2 is not None: + >>> grad_input += g2 + >>> return grad_input + >>> + >>> a = torch.tensor(1., requires_grad=True) + >>> b, _ = Func.apply(a) # induces g2 to be undefined + + """ + self.materialize_grads = value + + +# DO NOT USE: This is only defined to be able to load old serialized models +_ContextMethodMixin = FunctionCtx + + +class _HookMixin: + @staticmethod + def _register_hook(backward_hooks, hook): + if backward_hooks is None: + backward_hooks = OrderedDict() + handle = hooks.RemovableHandle(backward_hooks) + backward_hooks[handle.id] = hook + return backward_hooks, handle + + +class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin): + r""" + This class is used for internal autograd work. Do not use. 
+ """ + + def apply(self, *args): + r""" + Apply method used when executing this Node during the backward + """ + # _forward_cls is defined by derived class + # The user should define either backward or vjp but never both. + backward_fn = self._forward_cls.backward # type: ignore[attr-defined] + vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined] + if backward_fn is not Function.backward and vjp_fn is not Function.vjp: + raise RuntimeError( + "Implementing both 'backward' and 'vjp' for a custom " + "Function is not allowed. You should only implement one " + "of them." + ) + user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn + return user_fn(self, *args) + + def apply_jvp(self, *args): + r""" + Apply method used when executing forward mode AD during the forward + """ + # _forward_cls is defined by derived class + return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined] + + def _compiled_autograd_key(self): + return self._forward_cls._compiled_autograd_key(self) # type: ignore[attr-defined] + + +class FunctionMeta(type): + """Function metaclass. + + This metaclass sets up the following properties: + _backward_cls: The Function class corresponding to the differentiated + version of this function (which is generated on the fly by this + metaclass). + """ + + def __init__(cls, name, bases, attrs): + backward_fn = type( + name + "Backward", (BackwardCFunction,), {"_forward_cls": cls} + ) + backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined] + cls._backward_cls = backward_fn + + super().__init__(name, bases, attrs) + + +class _SingleLevelFunction( + _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta +): + @staticmethod + def forward(*args: Any, **kwargs: Any) -> Any: + r"""Define the forward of the custom autograd Function. + + This function is to be overridden by all subclasses. + There are two ways to define forward: + + Usage 1 (Combined forward and ctx):: + + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: + pass + + - It must accept a context ctx as the first argument, followed by any + number of arguments (tensors or other types). + - See :ref:`combining-forward-context` for more details + + Usage 2 (Separate forward and ctx):: + + @staticmethod + def forward(*args: Any, **kwargs: Any) -> Any: + pass + + @staticmethod + def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: + pass + + - The forward no longer accepts a ctx argument. + - Instead, you must also override the :meth:`torch.autograd.Function.setup_context` + staticmethod to handle setting up the ``ctx`` object. + ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs + to the forward. + - See :ref:`extending-autograd` for more details + + The context can be used to store arbitrary data that can be then + retrieved during the backward pass. Tensors should not be stored + directly on `ctx` (though this is not currently enforced for + backward compatibility). Instead, tensors should be saved either with + :func:`ctx.save_for_backward` if they are intended to be used in + ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward` + if they are intended to be used for in ``jvp``. + """ + raise NotImplementedError( + "You must implement the forward function for custom autograd.Function." + ) + + @staticmethod + def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> Any: + r"""There are two ways to define the forward pass of an autograd.Function. 
+ + Either: + + 1. Override forward with the signature ``forward(ctx, *args, **kwargs)``. + ``setup_context`` is not overridden. Setting up the ctx for backward + happens inside the ``forward``. + 2. Override forward with the signature ``forward(*args, **kwargs)`` and + override ``setup_context``. Setting up the ctx for backward happens + inside ``setup_context`` (as opposed to inside the ``forward``) + + See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details. + """ + raise NotImplementedError("setup_context is not implemented.") + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + r"""Define a formula for differentiating the operation with backward mode automatic differentiation. + + This function is to be overridden by all subclasses. + (Defining this function is equivalent to defining the ``vjp`` function.) + + It must accept a context :attr:`ctx` as the first argument, followed by + as many outputs as the :func:`forward` returned (None will be passed in + for non tensor outputs of the forward function), + and it should return as many tensors, as there were inputs to + :func:`forward`. Each argument is the gradient w.r.t the given output, + and each returned value should be the gradient w.r.t. the + corresponding input. If an input is not a Tensor or is a Tensor not + requiring grads, you can just pass None as a gradient for that input. + + The context can be used to retrieve tensors saved during the forward + pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple + of booleans representing whether each input needs gradient. E.g., + :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the + first input to :func:`forward` needs gradient computed w.r.t. the + output. + """ + raise NotImplementedError( + "You must implement either the backward or vjp method for " + "your custom autograd.Function to use it with backward " + "mode AD." + ) + + # vjp and backward are alias of each other + vjp = backward + + @staticmethod + def jvp(ctx: Any, *grad_inputs: Any) -> Any: + r"""Define a formula for differentiating the operation with forward mode automatic differentiation. + + This function is to be overridden by all subclasses. + It must accept a context :attr:`ctx` as the first argument, followed by + as many inputs as the :func:`forward` got (None will be passed in + for non tensor inputs of the forward function), + and it should return as many tensors as there were outputs to + :func:`forward`. Each argument is the gradient w.r.t the given input, + and each returned value should be the gradient w.r.t. the + corresponding output. If an output is not a Tensor or the function is not + differentiable with respect to that output, you can just pass None as a + gradient for that input. + + You can use the :attr:`ctx` object to pass any value from the forward to this + functions. + """ + raise NotImplementedError( + "You must implement the jvp function for custom " + "autograd.Function to use it with forward mode AD." + ) + + +class Function(_SingleLevelFunction): + r"""Base class to create custom `autograd.Function`. + + To create a custom `autograd.Function`, subclass this class and implement + the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom + op in the forward pass, call the class method ``apply``. Do not call + :meth:`forward` directly. 
+ + To ensure correctness and best performance, make sure you are calling the + correct methods on ``ctx`` and validating your backward function using + :func:`torch.autograd.gradcheck`. + + See :ref:`extending-autograd` for more details on how to use this class. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> class Exp(Function): + >>> @staticmethod + >>> def forward(ctx, i): + >>> result = i.exp() + >>> ctx.save_for_backward(result) + >>> return result + >>> + >>> @staticmethod + >>> def backward(ctx, grad_output): + >>> result, = ctx.saved_tensors + >>> return grad_output * result + >>> + >>> # Use it by calling the apply method: + >>> # xdoctest: +SKIP + >>> output = Exp.apply(input) + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + f"{self.__class__} should not be instantiated. Methods on autograd functions " + "are all static, so you should invoke them on the class itself. " + "Instantiating an autograd function will raise an " + "error in a future version of PyTorch.", + DeprecationWarning, + stacklevel=2, + ) + + def __call__(self, *args, **kwargs): + raise RuntimeError( + "Legacy autograd function with non-static forward method is deprecated. " + "Please use new-style autograd function with static forward method. " + "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)" + ) + + """ + Bool that specifies if PyTorch should attempt to autogenerate + :func:`torch.vmap` support for this autograd.Function. You may set this to + True only if this autograd.Function's forward, backward, and jvp (if they + exist) are written using PyTorch operations; otherwise, please override + :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`. + + Please see :ref:`func-autograd-function` for more details. + """ + generate_vmap_rule = False + + @staticmethod + def vmap(info, in_dims, *args): + r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`. + + For a :func:`torch.autograd.Function` to support + :func:`torch.vmap`, you must either override this static method, or set + ``generate_vmap_rule`` to ``True`` (you may not do both). + + If you choose to override this staticmethod: it must accept + + - an ``info`` object as the first argument. ``info.batch_size`` + specifies the size of the dimension being vmapped over, + while ``info.randomness`` is the randomness option passed to + :func:`torch.vmap`. + - an ``in_dims`` tuple as the second argument. + For each arg in ``args``, ``in_dims`` has a corresponding + ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if + the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + - ``*args``, which is the same as the args to :meth:`~Function.forward`. + + The return of the vmap staticmethod is a tuple of ``(output, out_dims)``. + Similar to ``in_dims``, ``out_dims`` should be of the same structure as + ``output`` and contain one ``out_dim`` per output that specifies if the + output has the vmapped dimension and what index it is in. + + Please see :ref:`func-autograd-function` for more details. + """ + raise NotImplementedError( + "To use autograd.Function with vmap, you must either override the " + "vmap staticmethod or set generate_vmap_rule=True." 
+ ) + + @classmethod + def apply(cls, *args, **kwargs): + def bind_default_args(func, *args, **kwargs): + signature = inspect.signature(func) + bound_args = signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + return bound_args.args + + is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context) + if is_setup_ctx_defined: + args = bind_default_args(cls.forward, *args, **kwargs) + + if not torch._C._are_functorch_transforms_active(): + # See NOTE: [functorch vjp and autograd interaction] + args = _functorch.utils.unwrap_dead_wrappers(args) + return super().apply(*args, **kwargs) # type: ignore[misc] + + if not is_setup_ctx_defined: + raise RuntimeError( + "In order to use an autograd.Function with functorch transforms " + "(vmap, grad, jvp, jacrev, ...), it must override the setup_context " + "staticmethod. For more details, please see " + "https://pytorch.org/docs/main/notes/extending.func.html" + ) + + return custom_function_call(cls, *args, **kwargs) + + @staticmethod + def _compiled_autograd_key(ctx): + return (ctx._autograd_function_id,) + + +def _is_setup_context_defined(fn): + return fn != _SingleLevelFunction.setup_context + + +def once_differentiable(fn): + @functools.wraps(fn) + def wrapper(ctx, *args): + with torch.no_grad(): + outputs = fn(ctx, *args) + + if not torch.is_grad_enabled(): + return outputs + + # If any of the inputs have requires_grad=True, we force the outputs + # to have requires_grad=True but point to a grad_fn which throws an + # error message during (double) back-propagation. + # XXX: this is only an approximation of requires_grad - there's no way + # to figure out if fn didn't use ctx.saved_tensors and as a result + # some Tensors might require grad, even if no args do. + # Unfortunately, this leads to unexpected error messages ("no nodes + # require computing gradients"), but I don't have a better idea. + # These functions would raise an error in backward anyway. + requires_grad = any( + isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args + ) + if not requires_grad: + return outputs + + if not isinstance(outputs, tuple): + outputs = (outputs,) + + err_fn = _functions.DelayedError( + b"trying to differentiate twice a function that was marked " + b"with @once_differentiable", + len(outputs), + ) + + # Create aliases of each output that has requires_grad=True. We need + # at least one of the inputs to err_fn to require grad so that the + # output will have a grad_fn. + def fake_requires_grad(var): + if var is not None: + var = var.detach() + var.requires_grad = True + return var + + return err_fn(*[fake_requires_grad(v) for v in outputs]) + + return wrapper + + +class InplaceFunction(Function): + r""" + This class is here only for backward compatibility reasons. + Use :class:`Function` instead of this for any new use case. + """ + + def __init__(self, inplace=False): + super().__init__() + self.inplace = inplace + + +def _nested_map(condition, fn, condition_msg=None): + def _map(obj): + if condition(obj): + return fn(obj) + elif obj is None: + return None + elif isinstance(obj, (list, tuple)): + mapped = (_map(x) for x in obj) + if hasattr(obj, "_fields"): + # obj is namedtuple + return type(obj)(*mapped) + return type(obj)(mapped) + elif isinstance(obj, dict): + return {x: _map(obj[x]) for x in obj} + else: + raise ValueError( + "Auto nesting doesn't know how to process " + "an input object of type " + + torch.typename(obj) + + ( + ". 
Accepted types: " + condition_msg + ", or lists/tuples of them" + if condition_msg + else "" + ) + ) + + return _map + + +def _jit_unwrap_structured(obj): + if hasattr(obj, "_jit_unwrap"): + return obj._jit_unwrap() + return obj + + +def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None): + def _iter(obj): + if conversion is not None: + obj = conversion(obj) + if condition(obj): + yield obj + elif obj is None: + return + elif isinstance(obj, (list, tuple)): + for o in obj: + yield from _iter(o) + elif isinstance(obj, dict): + # We only accept primitive key types, so we needn't inspect them + for o in obj.values(): + yield from _iter(o) + elif allow_unknown: + yield obj + else: + raise ValueError( + "Auto nesting doesn't know how to process " + "an input object of type " + + torch.typename(obj) + + ( + ". Accepted types: " + condition_msg + ", or lists/tuples of them" + if condition_msg + else "" + ) + ) + + return _iter + + +def _unflatten(input, proto): + # unflatten a list or tuple input into a nested list/tuple structure + # specified by proto + def unflatten_helper(input, proto): + res: list[Optional[torch.Tensor]] = [] + if hasattr(proto, "_jit_wrap"): + return proto._jit_wrap(input) + if not isinstance(proto, (list, tuple)): + return input[0], input[1:] + for e in proto: + if e is None: + res.append(e) + else: + res_e, input = unflatten_helper(input, e) + res.append(res_e) + return type(proto)(res), input + + return unflatten_helper(input, proto)[0] + + +_iter_jit_values = _iter_filter( + lambda o: o is None or isinstance(o, torch._C.Value), + condition_msg="jit's Values or None", +) +_iter_tensors = _iter_filter( + lambda x: isinstance(x, torch.Tensor), + condition_msg="Tensors", + conversion=_jit_unwrap_structured, +) +_iter_tensors_permissive = _iter_filter( + lambda x: isinstance(x, torch.Tensor), + allow_unknown=True, + condition_msg="Tensors (permissive)", +) +_iter_None_tensors = _iter_filter( + lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None" +) +_map_tensor_data = _nested_map( + lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors" +) + + +class NestedIOFunction(Function): + r""" + This class is here only for backward compatibility reasons. + Use :class:`Function` instead of this for any new use case. + """ + # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the + # superclass (Function) but are instance methods here, which mypy reports as incompatible. + + def _do_forward(self, *input): + self._nested_input = input + flat_input = tuple(_iter_tensors(input)) + flat_output = super()._do_forward(*flat_input) # type: ignore[misc] + nested_tensors = _unflatten(flat_output, self._nested_output) + return nested_tensors + + def _do_backward(self, gradients, retain_variables): + self.retain_variables = retain_variables + result = super()._do_backward(gradients, retain_variables) # type: ignore[misc] + if not retain_variables: + del self._nested_output + del self._to_save_nested + return result + + def backward(self, *gradients: Any) -> Any: # type: ignore[override] + r""" + Shared backward utility. + """ + nested_gradients = _unflatten(gradients, self._nested_output) + result = self.backward_extended(*nested_gradients) # type: ignore[func-returns-value] + return tuple(_iter_None_tensors(result)) + + __call__ = _do_forward + + def forward(self, *args: Any) -> Any: # type: ignore[override] + r""" + Shared forward utility. 
+ """ + nested_tensors = _map_tensor_data(self._nested_input) + result = self.forward_extended(*nested_tensors) # type: ignore[func-returns-value] + del self._nested_input + self._nested_output = result + return tuple(_iter_tensors(result)) + + def save_for_backward(self, *args: Any) -> None: + r""" + See :meth:`Function.save_for_backward`. + """ + self.to_save = tuple(_iter_tensors(args)) + self._to_save_nested = args + + @property + def saved_tensors(self): # type: ignore[override] + r""" + See :meth:`Function.saved_tensors`. + """ + flat_tensors = super().saved_tensors # type: ignore[misc] + return _unflatten(flat_tensors, self._to_save_nested) + + def mark_dirty(self, *args: Any, **kwargs: Any) -> None: + r""" + See :meth:`Function.mark_dirty`. + """ + self.dirty_tensors = tuple(_iter_tensors((args, kwargs))) + + def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: + r""" + See :meth:`Function.mark_non_differentiable`. + """ + self.non_differentiable = tuple(_iter_tensors((args, kwargs))) + + def forward_extended(self, *input: Any) -> None: + r""" + User defined forward. + """ + raise NotImplementedError + + def backward_extended(self, *grad_output: Any) -> None: + r""" + User defined backward. + """ + raise NotImplementedError diff --git a/phivenv/Lib/site-packages/torch/autograd/functional.py b/phivenv/Lib/site-packages/torch/autograd/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..490ddf96b31815cc10ef8c8b9dbd4823e54298bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/functional.py @@ -0,0 +1,1194 @@ +# mypy: allow-untyped-defs + +import torch +from torch._vmap_internals import _vmap + +from . import forward_ad as fwAD + + +__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"] + +# Utility functions + + +def _as_tuple_nocheck(x): + if isinstance(x, tuple): + return x + elif isinstance(x, list): + return tuple(x) + else: + return (x,) + + +def _as_tuple(inp, arg_name=None, fn_name=None): + # Ensures that inp is a tuple of Tensors + # Returns whether or not the original inp was a tuple and the tupled version of the input + if arg_name is None and fn_name is None: + return _as_tuple_nocheck(inp) + + is_inp_tuple = True + if not isinstance(inp, tuple): + inp = (inp,) + is_inp_tuple = False + + for i, el in enumerate(inp): + if not isinstance(el, torch.Tensor): + if is_inp_tuple: + raise TypeError( + f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" + f" value at index {i} has type {type(el)}." + ) + else: + raise TypeError( + f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the" + f" given {arg_name} has type {type(el)}." + ) + + return is_inp_tuple, inp + + +def _tuple_postprocess(res, to_unpack): + # Unpacks a potentially nested tuple of Tensors + # to_unpack should be a single boolean or a tuple of two booleans. 
+ # It is used to: + # - invert _as_tuple when res should match the inp given to _as_tuple + # - optionally remove nesting of two tuples created by multiple calls to _as_tuple + if isinstance(to_unpack, tuple): + assert len(to_unpack) == 2 + if not to_unpack[1]: + res = tuple(el[0] for el in res) + if not to_unpack[0]: + res = res[0] + else: + if not to_unpack: + res = res[0] + return res + + +def _grad_preprocess(inputs, create_graph, need_graph): + # Preprocess the inputs to make sure they require gradient + # inputs is a tuple of Tensors to preprocess + # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs + # need_graph specifies if we internally want gradients to flow back to the Tensors in res + # Note that we *always* create a new Tensor object to be able to see the difference between + # inputs given as arguments and the same Tensors automatically captured by the user function. + # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576 + res = [] + for inp in inputs: + if create_graph and inp.requires_grad: + # Create at least a new Tensor object in a differentiable way + if not inp.is_sparse: + # Use .view_as() to get a shallow copy + res.append(inp.view_as(inp)) + else: + # We cannot use view for sparse Tensors so we clone + res.append(inp.clone()) + else: + res.append(inp.detach().requires_grad_(need_graph)) + return tuple(res) + + +def _grad_postprocess(inputs, create_graph): + # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not + # request it. + if isinstance(inputs[0], torch.Tensor): + if not create_graph: + return tuple(inp.detach() for inp in inputs) + else: + return inputs + else: + return tuple(_grad_postprocess(inp, create_graph) for inp in inputs) + + +def _validate_v(v, other, is_other_tuple): + # This assumes that other is the correct shape, and v should match + # Both are assumed to be tuples of Tensors + if len(other) != len(v): + if is_other_tuple: + raise RuntimeError( + f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}." + ) + else: + raise RuntimeError("The given v should contain a single Tensor.") + + for idx, (el_v, el_other) in enumerate(zip(v, other)): + if el_v.size() != el_other.size(): + prepend = "" + if is_other_tuple: + prepend = f"Entry {idx} in " + raise RuntimeError( + f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}." + ) + + +def _check_requires_grad(inputs, input_type, strict): + # Used to make all the necessary checks to raise nice errors in strict mode. + if not strict: + return + + if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]: + raise RuntimeError("Invalid input_type to _check_requires_grad") + for i, inp in enumerate(inputs): + if inp is None: + # This can only be reached for grad_inputs. + raise RuntimeError( + f"The output of the user-provided function is independent of input {i}." + " This is not allowed in strict mode." + ) + if not inp.requires_grad: + if input_type == "hessian": + raise RuntimeError( + f"The hessian of the user-provided function with respect to input {i}" + " is independent of the input. This is not allowed in strict mode." + " You should ensure that your function is thrice differentiable and that" + " the hessian depends on the inputs." 
+ ) + elif input_type == "jacobian": + raise RuntimeError( + "While computing the hessian, found that the jacobian of the user-provided" + f" function with respect to input {i} is independent of the input. This is not" + " allowed in strict mode. You should ensure that your function is twice" + " differentiable and that the jacobian depends on the inputs (this would be" + " violated by a linear function for example)." + ) + elif input_type == "grad_inputs": + raise RuntimeError( + f"The gradient with respect to input {i} is independent of the inputs of the" + " user-provided function. This is not allowed in strict mode." + ) + else: + raise RuntimeError( + f"Output {i} of the user-provided function does not require gradients." + " The outputs must be computed in a differentiable manner from the input" + " when running in strict mode." + ) + + +def _autograd_grad( + outputs, + inputs, + grad_outputs=None, + create_graph=False, + retain_graph=None, + is_grads_batched=False, +): + # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them. + # This has the extra constraint that inputs has to be a tuple + assert isinstance(outputs, tuple) + if grad_outputs is None: + grad_outputs = (None,) * len(outputs) + assert isinstance(grad_outputs, tuple) + assert len(outputs) == len(grad_outputs) + + new_outputs: tuple[torch.Tensor, ...] = () + new_grad_outputs: tuple[torch.Tensor, ...] = () + for out, grad_out in zip(outputs, grad_outputs): + if out is not None and out.requires_grad: + new_outputs += (out,) + new_grad_outputs += (grad_out,) + + if len(new_outputs) == 0: + # No differentiable output, we don't need to call the autograd engine + return (None,) * len(inputs) + else: + return torch.autograd.grad( + new_outputs, + inputs, + new_grad_outputs, + allow_unused=True, + create_graph=create_graph, + retain_graph=retain_graph, + is_grads_batched=is_grads_batched, + ) + + +def _fill_in_zeros(grads, refs, strict, create_graph, stage): + # Used to detect None in the grads and depending on the flags, either replace them + # with Tensors full of 0s of the appropriate size based on the refs or raise an error. + # strict and create graph allow us to detect when it is appropriate to raise an error + # stage gives us information of which backward call we consider to give good error message + if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: + raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros") + + res: tuple[torch.Tensor, ...] = () + for i, grads_i in enumerate(grads): + if grads_i is None: + if strict: + if stage == "back": + raise RuntimeError( + "The output of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode." + ) + elif stage == "back_trick": + raise RuntimeError( + f"The gradient with respect to the input is independent of entry {i}" + " in the grad_outputs when using the double backward trick to compute" + " forward mode gradients. This is not allowed in strict mode." + ) + elif stage == "double_back": + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode." + ) + else: + raise RuntimeError( + "The hessian of the user-provided function is independent of " + f"entry {i} in the grad_jacobian. This is not allowed in strict " + "mode as it prevents from using the double backward trick to " + "replace forward mode AD." 
+ ) + + grads_i = torch.zeros_like(refs[i]) + else: + if strict and create_graph and not grads_i.requires_grad: + if "double" not in stage: + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode when create_graph=True." + ) + else: + raise RuntimeError( + "The hessian of the user-provided function is independent of " + f"input {i}. This is not allowed in strict mode when create_graph=True." + ) + + res += (grads_i,) + + return res + + +# Public API + + +def vjp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the vector + Jacobian product is computed. Must be the same size as the output + of ``func``. This argument is optional when the output of ``func`` + contains a single element and (if it is not provided) will be set + as a Tensor containing a single ``1``. + create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + vjp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + vjp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the inputs. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def exp_reducer(x): + ... return x.exp().sum(dim=1) + >>> inputs = torch.rand(4, 4) + >>> v = torch.ones(4) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> vjp(exp_reducer, inputs, v) + (tensor([5.7817, 7.2458, 5.7830, 6.7782]), + tensor([[1.4458, 1.3962, 1.3042, 1.6354], + [2.1288, 1.0652, 1.5483, 2.5035], + [2.2046, 1.1292, 1.1432, 1.3059], + [1.3225, 1.6652, 1.7753, 2.0152]])) + + >>> vjp(exp_reducer, inputs, v, create_graph=True) + (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=), + tensor([[1.4458, 1.3962, 1.3042, 1.6354], + [2.1288, 1.0652, 1.5483, 2.5035], + [2.2046, 1.1292, 1.1432, 1.3059], + [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=)) + + >>> def adder(x, y): + ... 
return 2 * x + 3 * y + >>> inputs = (torch.rand(2), torch.rand(2)) + >>> v = torch.ones(2) + >>> vjp(adder, inputs, v) + (tensor([2.4225, 2.3340]), + (tensor([2., 2.]), tensor([3., 3.]))) + """ + with torch.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "vjp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + + if v is not None: + _, v = _as_tuple(v, "v", "vjp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, outputs, is_outputs_tuple) + else: + if len(outputs) != 1 or outputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the " + "user-provided function returns " + "a single Tensor with a single element." + ) + + enable_grad = True if create_graph else torch.is_grad_enabled() + with torch.set_grad_enabled(enable_grad): + grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph) + vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back") + + # Cleanup objects and return them to the user + outputs = _grad_postprocess(outputs, create_graph) + vjp = _grad_postprocess(vjp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + vjp, is_inputs_tuple + ) + + +def jvp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the Jacobian + vector product is computed. Must be the same size as the input of + ``func``. This argument is optional when the input to ``func`` + contains a single element and (if it is not provided) will be set + as a Tensor containing a single ``1``. + create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + jvp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + jvp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the output. + + Note: + ``autograd.functional.jvp`` computes the jvp by using the backward of + the backward (sometimes called the double backwards trick). This is not + the most performant way of computing the jvp. Please consider using + :func:`torch.func.jvp` or the + :ref:`low-level forward-mode AD API ` instead. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def exp_reducer(x): + ... 
return x.exp().sum(dim=1) + >>> inputs = torch.rand(4, 4) + >>> v = torch.ones(4, 4) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> jvp(exp_reducer, inputs, v) + (tensor([6.3090, 4.6742, 7.9114, 8.2106]), + tensor([6.3090, 4.6742, 7.9114, 8.2106])) + + >>> jvp(exp_reducer, inputs, v, create_graph=True) + (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=), + tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=)) + + >>> def adder(x, y): + ... return 2 * x + 3 * y + >>> inputs = (torch.rand(2), torch.rand(2)) + >>> v = (torch.ones(2), torch.ones(2)) + >>> jvp(adder, inputs, v) + (tensor([2.2399, 2.5005]), + tensor([5., 5.])) + + """ + with torch.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + if v is not None: + _, v = _as_tuple(v, "v", "jvp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, inputs, is_inputs_tuple) + else: + if len(inputs) != 1 or inputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the input to " + "the user-provided function is a single Tensor " + "with a single element." + ) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "jvp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + # The backward is linear so the value of grad_outputs is not important as + # it won't appear in the double backward graph. We only need to ensure that + # it does not contain inf or nan. + grad_outputs = tuple( + torch.zeros_like(out, requires_grad=True) for out in outputs + ) + + grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True) + _check_requires_grad(grad_inputs, "grad_inputs", strict=strict) + + if create_graph: + with torch.enable_grad(): + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) + jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") + else: + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) + jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") + + # Cleanup objects and return them to the user + outputs = _grad_postprocess(outputs, create_graph) + jvp = _grad_postprocess(jvp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + jvp, is_outputs_tuple + ) + + +def _construct_standard_basis_for( + tensors: tuple[torch.Tensor, ...], tensor_numels: tuple[int, ...] +) -> tuple[torch.Tensor, ...]: + # This function: + # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. + # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. + # - Each chunk corresponds to one tensor. The chunk has the same dtype and + # device as the tensor + # + # For example, with tensor_numels = [1, 2, 1], this function returns: + # ( tensor([[1], tensor([[0, 0], tensor([[0], + # [0], [1, 0], [0], + # [0], [0, 1], [0], + # [0]]) , [0, 0]]) , [1]]) ) + # + # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) + # Precondition: tensors always has at least one element. + # + # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] + # for context behind this function. All the pre-conditions are guarded for + # in torch.autograd.functional.jacobian. 
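+ #
+ # Illustrative sketch of that contract (hypothetical tensors): for two
+ # inputs with numels (2, 1), concatenating the returned chunks along dim=1
+ # recovers a 3x3 identity matrix.
+ # >>> t = (torch.zeros(2), torch.zeros(1))
+ # >>> chunks = _construct_standard_basis_for(t, (2, 1))
+ # >>> torch.cat(chunks, dim=1)  # equals torch.eye(3)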
+ assert len(tensors) == len(tensor_numels) + assert len(tensors) > 0 + total_numel = sum(tensor_numels) + chunks = tuple( + tensor.new_zeros(total_numel, tensor_numel) + for tensor, tensor_numel in zip(tensors, tensor_numels) + ) + diag_start_idx = 0 + for chunk, numel in zip(chunks, tensor_numels): + chunk.diagonal(diag_start_idx).fill_(1) + diag_start_idx -= numel + return chunks + + +def _jacfwd(func, inputs, strict=False, vectorize=False): + if strict: + raise RuntimeError( + "torch.autograd.functional.jacobian: `strict=True` " + 'and `strategy="forward-mode"` are not supported together (yet). ' + "Please either set `strict=False` or " + '`strategy="reverse-mode"`.' + ) + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") + output_info = [] + + if vectorize: + # See NOTE: [Computing jacobian with vmap and grad for multiple outputs] + input_numels = tuple(input.numel() for input in inputs) + + # Step 1: Prepare tangents + tangents = _construct_standard_basis_for(inputs, input_numels) + + # Step 2: Compute vmap over computation with dual tensors + def jvp(tangents): + with fwAD.dual_level(): + dual_inputs = tuple( + fwAD.make_dual(input, tangent.view_as(input)) + for input, tangent in zip(inputs, tangents) + ) + _is_outputs_tuple, dual_outputs = _as_tuple( + func(*dual_inputs), "outputs" + ) + output_info.append(_is_outputs_tuple) + jv = [] + primal_outs = [] + for dual_out in dual_outputs: + primal, tangent = fwAD.unpack_dual(dual_out) + primal_outs.append(primal) + if tangent is not None: + jv.append(tangent) + else: + jv.append(torch.zeros_like(primal)) + output_info.append(primal_outs) + return tuple(jv) + + outputs_before_split = _vmap(jvp)(tangents) + is_outputs_tuple, outputs = output_info + # Step 3: for each of the output tangents, split along dim 0 + jacobian_input_output = [] + for jac_output_i, output_i in zip(outputs_before_split, outputs): + jacobian_output_i_output = [] + for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs): + # We need to transpose the Jacobian because in forward AD, the + # batch dimension represents that of the inputs + jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape( + (*output_i.shape, *input_j.shape) + ) # noqa: C409 + + jacobian_output_i_output.append(jacobian_input_i_output_j) + jacobian_input_output.append(jacobian_output_i_output) + + # Omit [Step 4] because everything is already transposed w/ forward AD + return _tuple_postprocess( + jacobian_input_output, (is_outputs_tuple, is_inputs_tuple) + ) + else: + raise NotImplementedError( + "Computing Jacobian using forward-AD or forward-over-reverse Hessian is" + "only implemented for `vectorize=True`." + ) + + +def jacobian( + func, + inputs, + create_graph=False, + strict=False, + vectorize=False, + strategy="reverse-mode", +): + r"""Compute the Jacobian of a given function. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a tuple of Tensors or a Tensor. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + create_graph (bool, optional): If ``True``, the Jacobian will be + computed in a differentiable manner. Note that when ``strict`` is + ``False``, the result can not require gradients or be disconnected + from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. 
If ``False``, we return a Tensor of zeros as the + jacobian for said inputs, which is the expected mathematical value. + Defaults to ``False``. + vectorize (bool, optional): This feature is experimental. + Please consider using :func:`torch.func.jacrev` or + :func:`torch.func.jacfwd` instead if you are looking for something + less experimental and more performant. + When computing the jacobian, usually we invoke + ``autograd.grad`` once per row of the jacobian. If this flag is + ``True``, we perform only a single ``autograd.grad`` call with + ``batched_grad=True`` which uses the vmap prototype feature. + Though this should lead to performance improvements in many cases, + because this feature is still experimental, there may be performance + cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for + more information. + strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to + determine whether the Jacobian will be computed with forward or reverse + mode AD. Currently, ``"forward-mode"`` requires ``vectorized=True``. + Defaults to ``"reverse-mode"``. If ``func`` has more outputs than + inputs, ``"forward-mode"`` tends to be more performant. Otherwise, + prefer to use ``"reverse-mode"``. + + Returns: + Jacobian (Tensor or nested tuple of Tensors): if there is a single + input and output, this will be a single Tensor containing the + Jacobian for the linearized inputs and output. If one of the two is + a tuple, then the Jacobian will be a tuple of Tensors. If both of + them are tuples, then the Jacobian will be a tuple of tuple of + Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the + ``i``\th output and ``j``\th input and will have as size the + concatenation of the sizes of the corresponding output and the + corresponding input and will have same dtype and device as the + corresponding input. If strategy is ``forward-mode``, the dtype will be + that of the output; otherwise, the input. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def exp_reducer(x): + ... return x.exp().sum(dim=1) + >>> inputs = torch.rand(2, 2) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> jacobian(exp_reducer, inputs) + tensor([[[1.4917, 2.4352], + [0.0000, 0.0000]], + [[0.0000, 0.0000], + [2.4369, 2.3799]]]) + + >>> jacobian(exp_reducer, inputs, create_graph=True) + tensor([[[1.4917, 2.4352], + [0.0000, 0.0000]], + [[0.0000, 0.0000], + [2.4369, 2.3799]]], grad_fn=) + + >>> def exp_adder(x, y): + ... return 2 * x.exp() + 3 * y + >>> inputs = (torch.rand(2), torch.rand(2)) + >>> jacobian(exp_adder, inputs) + (tensor([[2.8052, 0.0000], + [0.0000, 3.3963]]), + tensor([[3., 0.], + [0., 3.]])) + + >>> def linear_model(x): + ... W = torch.tensor([[2.0, -1.0], [0.0, 1.0]]) + ... b = torch.tensor([1.0, 0.5]) + ... return x @ W.T + b + + >>> x = torch.randn(4, 2, requires_grad=True) + >>> jac = jacobian(linear_model, x, vectorize=True) + >>> jac.shape + torch.Size([4, 2, 4, 2]) + """ + assert strategy in ("forward-mode", "reverse-mode"), ( + 'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your ' + 'function has more outputs than inputs, "forward-mode" tends to be more performant. ' + 'Otherwise, prefer to use "reverse-mode".' + ) + if strategy == "forward-mode": + if create_graph: + raise NotImplementedError( + "torch.autograd.functional.jacobian: `create_graph=True` " + 'and `strategy="forward-mode"` are not supported together (yet). 
' + "Please either set `create_graph=False` or " + '`strategy="reverse-mode"`.' + ) + return _jacfwd(func, inputs, strict, vectorize) + + with torch.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "jacobian" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + + if vectorize: + if strict: + raise RuntimeError( + "torch.autograd.functional.jacobian: `strict=True` " + "and `vectorized=True` are not supported together. " + "Please either set `strict=False` or " + "`vectorize=False`." + ) + # NOTE: [Computing jacobian with vmap and grad for multiple outputs] + # + # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). + # It turns out we can compute the jacobian of this function with a single + # call to autograd.grad by using vmap over the correct grad_outputs. + # + # Firstly, one way to compute the jacobian is to stack x**2 and x.sum() + # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()]) + # + # To get the first row of the jacobian, we call + # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0])) + # To get the 2nd row of the jacobian, we call + # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0])) + # and so on. + # + # Using vmap, we can vectorize all 4 of these computations into one by + # passing the standard basis for R^4 as the grad_output. + # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)). + # + # Now, how do we compute the jacobian *without stacking the output*? + # We can just split the standard basis across the outputs. So to + # compute the jacobian of f(x), we'd use + # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) + # The grad_outputs looks like the following: + # ( torch.tensor([[1, 0, 0], + # [0, 1, 0], + # [0, 0, 1], + # [0, 0, 0]]), + # torch.tensor([[0], + # [0], + # [0], + # [1]]) ) + # + # But we're not done yet! + # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) + # returns a Tensor of shape [4, 3]. We have to remember to split the + # jacobian of shape [4, 3] into two: + # - one of shape [3, 3] for the first output + # - one of shape [ 3] for the second output + + # Step 1: Construct grad_outputs by splitting the standard basis + output_numels = tuple(output.numel() for output in outputs) + grad_outputs = _construct_standard_basis_for(outputs, output_numels) + flat_outputs = tuple(output.reshape(-1) for output in outputs) + + # Step 2: Call vmap + autograd.grad + def vjp(grad_output): + vj = list( + _autograd_grad( + flat_outputs, + inputs, + grad_output, + create_graph=create_graph, + is_grads_batched=True, + ) + ) + for el_idx, vj_el in enumerate(vj): + if vj_el is not None: + continue + vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand( + (sum(output_numels),) + inputs[el_idx].shape + ) + return tuple(vj) + + jacobians_of_flat_output = vjp(grad_outputs) + + # Step 3: The returned jacobian is one big tensor per input. In this step, + # we split each Tensor by output. 
+ jacobian_input_output = [] + for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs): + jacobian_input_i_output = [] + for jac, output_j in zip( + jac_input_i.split(output_numels, dim=0), outputs + ): + jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape) + jacobian_input_i_output.append(jacobian_input_i_output_j) + jacobian_input_output.append(jacobian_input_i_output) + + # Step 4: Right now, `jacobian` is a List[List[Tensor]]. + # The outer List corresponds to the number of inputs, + # the inner List corresponds to the number of outputs. + # We need to exchange the order of these and convert to tuples + # before returning. + jacobian_output_input = tuple(zip(*jacobian_input_output)) + + jacobian_output_input = _grad_postprocess( + jacobian_output_input, create_graph + ) + return _tuple_postprocess( + jacobian_output_input, (is_outputs_tuple, is_inputs_tuple) + ) + + jacobian: tuple[torch.Tensor, ...] = () + + for i, out in enumerate(outputs): + # mypy complains that expression and variable have different types due to the empty list + jac_i: tuple[list[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment] + for j in range(out.nelement()): + vj = _autograd_grad( + (out.reshape(-1)[j],), + inputs, + retain_graph=True, + create_graph=create_graph, + ) + + for el_idx, (jac_i_el, vj_el, inp_el) in enumerate( + zip(jac_i, vj, inputs) + ): + if vj_el is not None: + if strict and create_graph and not vj_el.requires_grad: + msg = ( + "The jacobian of the user-provided function is " + f"independent of input {i}. This is not allowed in " + "strict mode when create_graph=True." + ) + raise RuntimeError(msg) + jac_i_el.append(vj_el) + else: + if strict: + msg = ( + f"Output {i} of the user-provided function is " + f"independent of input {el_idx}. This is not allowed in " + "strict mode." + ) + raise RuntimeError(msg) + jac_i_el.append(torch.zeros_like(inp_el)) + + jacobian += ( + tuple( + torch.stack(jac_i_el, dim=0).view( + out.size() + inputs[el_idx].size() # type: ignore[operator] + ) + for (el_idx, jac_i_el) in enumerate(jac_i) + ), + ) + + jacobian = _grad_postprocess(jacobian, create_graph) + + return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple)) + + +def hessian( + func, + inputs, + create_graph=False, + strict=False, + vectorize=False, + outer_jacobian_strategy="reverse-mode", +): + r"""Compute the Hessian of a given scalar function. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a Tensor with a single element. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + create_graph (bool, optional): If ``True``, the Hessian will be computed in + a differentiable manner. Note that when ``strict`` is ``False``, the result can not + require gradients or be disconnected from the inputs. + Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input + such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the + hessian for said inputs, which is the expected mathematical value. + Defaults to ``False``. + vectorize (bool, optional): This feature is experimental. + Please consider using :func:`torch.func.hessian` + instead if you are looking for something less experimental and more performant. + When computing the hessian, usually we invoke + ``autograd.grad`` once per row of the hessian. 
If this flag is + ``True``, we use the vmap prototype feature as the backend to + vectorize calls to ``autograd.grad`` so we only invoke it once + instead of once per row. This should lead to performance + improvements in many use cases, however, due to this feature + being incomplete, there may be performance cliffs. Please + use `torch._C._debug_only_display_vmap_fallback_warnings(True)` + to show any performance warnings and file us issues if + warnings exist for your use case. Defaults to ``False``. + outer_jacobian_strategy (str, optional): The Hessian is computed by + computing the Jacobian of a Jacobian. The inner Jacobian is always + computed in reverse-mode AD. Setting strategy to ``"forward-mode"`` + or ``"reverse-mode"`` determines whether the outer Jacobian will be + computed with forward or reverse mode AD. Currently, computing the outer + Jacobian in ``"forward-mode"`` requires ``vectorized=True``. Defaults + to ``"reverse-mode"``. + + Returns: + Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input, + this will be a single Tensor containing the Hessian for the input. + If it is a tuple, then the Hessian will be a tuple of tuples where + ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input + and ``j``\th input with size the sum of the size of the ``i``\th input plus + the size of the ``j``\th input. ``Hessian[i][j]`` will have the same + dtype and device as the corresponding ``i``\th input. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def pow_reducer(x): + ... return x.pow(3).sum() + >>> inputs = torch.rand(2, 2) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> hessian(pow_reducer, inputs) + tensor([[[[5.2265, 0.0000], + [0.0000, 0.0000]], + [[0.0000, 4.8221], + [0.0000, 0.0000]]], + [[[0.0000, 0.0000], + [1.9456, 0.0000]], + [[0.0000, 0.0000], + [0.0000, 3.2550]]]]) + + >>> hessian(pow_reducer, inputs, create_graph=True) + tensor([[[[5.2265, 0.0000], + [0.0000, 0.0000]], + [[0.0000, 4.8221], + [0.0000, 0.0000]]], + [[[0.0000, 0.0000], + [1.9456, 0.0000]], + [[0.0000, 0.0000], + [0.0000, 3.2550]]]], grad_fn=) + + + >>> def pow_adder_reducer(x, y): + ... return (2 * x.pow(2) + 3 * y.pow(2)).sum() + >>> inputs = (torch.rand(2), torch.rand(2)) + >>> hessian(pow_adder_reducer, inputs) + ((tensor([[4., 0.], + [0., 4.]]), + tensor([[0., 0.], + [0., 0.]])), + (tensor([[0., 0.], + [0., 0.]]), + tensor([[6., 0.], + [0., 6.]]))) + """ + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian") + assert outer_jacobian_strategy in ( + "forward-mode", + "reverse-mode", + ), 'Expected strategy to be either "forward-mode" or "reverse-mode".' 
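+
+ # The Hessian below is obtained as jacobian(jac_func, inputs), i.e. a
+ # Jacobian of a Jacobian. Rough sketch (hypothetical scalar function):
+ # for f(x) = x.pow(3).sum(), the inner jacobian is 3 * x**2 and
+ # differentiating it once more gives a diagonal Hessian.
+ # >>> x = torch.rand(2)
+ # >>> hessian(lambda t: t.pow(3).sum(), x)  # -> torch.diag(6 * x)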
+ + def ensure_single_output_function(*inp): + out = func(*inp) + is_out_tuple, t_out = _as_tuple( + out, "outputs of the user-provided function", "hessian" + ) + _check_requires_grad(t_out, "outputs", strict=strict) + + if is_out_tuple or not isinstance(out, torch.Tensor): + raise RuntimeError( + "The function given to hessian should return a single Tensor" + ) + + if out.nelement() != 1: + raise RuntimeError( + "The Tensor returned by the function given to hessian should contain a single element" + ) + + return out.squeeze() + + def jac_func(*inp): + if outer_jacobian_strategy == "forward-mode": + # _grad_preprocess requires create_graph=True and input to require_grad + # or else the input will be detached + inp = tuple(t.requires_grad_(True) for t in inp) + jac = jacobian(ensure_single_output_function, inp, create_graph=True) + _check_requires_grad(jac, "jacobian", strict=strict) + return jac + + res = jacobian( + jac_func, + inputs, + create_graph=create_graph, + strict=strict, + vectorize=vectorize, + strategy=outer_jacobian_strategy, + ) + return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple)) + + +def vhp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between vector ``v`` and Hessian of a given scalar function at a specified point. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a Tensor with a single element. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the vector Hessian + product is computed. Must be the same size as the input of + ``func``. This argument is optional when ``func``'s input contains + a single element and (if it is not provided) will be set as a + Tensor containing a single ``1``. + create_graph (bool, optional): If ``True``, both the output and result + will be computed in a differentiable way. Note that when ``strict`` + is ``False``, the result can not require gradients or be + disconnected from the inputs. + Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + vhp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + vhp (tuple of Tensors or Tensor): result of the dot product with the + same shape as the inputs. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def pow_reducer(x): + ... return x.pow(3).sum() + >>> inputs = torch.rand(2, 2) + >>> v = torch.ones(2, 2) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> vhp(pow_reducer, inputs, v) + (tensor(0.5591), + tensor([[1.0689, 1.2431], + [3.0989, 4.4456]])) + >>> vhp(pow_reducer, inputs, v, create_graph=True) + (tensor(0.5591, grad_fn=), + tensor([[1.0689, 1.2431], + [3.0989, 4.4456]], grad_fn=)) + >>> def pow_adder_reducer(x, y): + ... 
return (2 * x.pow(2) + 3 * y.pow(2)).sum() + >>> inputs = (torch.rand(2), torch.rand(2)) + >>> v = (torch.zeros(2), torch.ones(2)) + >>> vhp(pow_adder_reducer, inputs, v) + (tensor(4.8053), + (tensor([0., 0.]), + tensor([6., 6.]))) + """ + with torch.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + if v is not None: + _, v = _as_tuple(v, "v", "vhp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, inputs, is_inputs_tuple) + else: + if len(inputs) != 1 or inputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the input to the user-provided function " + "is a single Tensor with a single element." + ) + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "vhp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + + if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor): + raise RuntimeError( + "The function given to vhp should return a single Tensor" + ) + + if outputs[0].nelement() != 1: + raise RuntimeError( + "The Tensor returned by the function given to vhp should contain a single element" + ) + + jac = _autograd_grad(outputs, inputs, create_graph=True) + _check_requires_grad(jac, "jacobian", strict=strict) + + enable_grad = True if create_graph else torch.is_grad_enabled() + with torch.set_grad_enabled(enable_grad): + grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph) + vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back") + + outputs = _grad_postprocess(outputs, create_graph) + vhp = _grad_postprocess(vhp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + vhp, is_inputs_tuple + ) + + +def hvp(func, inputs, v=None, create_graph=False, strict=False): + r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a Tensor with a single element. + inputs (tuple of Tensors or Tensor): inputs to the function ``func``. + v (tuple of Tensors or Tensor): The vector for which the Hessian vector + product is computed. Must be the same size as the input of + ``func``. This argument is optional when ``func``'s input contains + a single element and (if it is not provided) will be set as a + Tensor containing a single ``1``. + create_graph (bool, optional): If ``True``, both the output and result will be + computed in a differentiable way. Note that when ``strict`` is + ``False``, the result can not require gradients or be disconnected + from the inputs. Defaults to ``False``. + strict (bool, optional): If ``True``, an error will be raised when we + detect that there exists an input such that all the outputs are + independent of it. If ``False``, we return a Tensor of zeros as the + hvp for said inputs, which is the expected mathematical value. + Defaults to ``False``. + Returns: + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + hvp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the inputs. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def pow_reducer(x): + ... 
return x.pow(3).sum() + >>> inputs = torch.rand(2, 2) + >>> v = torch.ones(2, 2) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> hvp(pow_reducer, inputs, v) + (tensor(0.1448), + tensor([[2.0239, 1.6456], + [2.4988, 1.4310]])) + + >>> hvp(pow_reducer, inputs, v, create_graph=True) + (tensor(0.1448, grad_fn=), + tensor([[2.0239, 1.6456], + [2.4988, 1.4310]], grad_fn=)) + + + >>> def pow_adder_reducer(x, y): + ... return (2 * x.pow(2) + 3 * y.pow(2)).sum() + >>> inputs = (torch.rand(2), torch.rand(2)) + >>> v = (torch.zeros(2), torch.ones(2)) + >>> hvp(pow_adder_reducer, inputs, v) + (tensor(2.3030), + (tensor([0., 0.]), + tensor([6., 6.]))) + + Note: + + This function is significantly slower than `vhp` due to backward mode AD constraints. + If your functions is twice continuously differentiable, then hvp = vhp.t(). So if you + know that your function satisfies this condition, you should use vhp instead that is + much faster with the current implementation. + + """ + with torch.enable_grad(): + is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp") + inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) + + if v is not None: + _, v = _as_tuple(v, "v", "hvp") + v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) + _validate_v(v, inputs, is_inputs_tuple) + else: + if len(inputs) != 1 or inputs[0].nelement() != 1: + raise RuntimeError( + "The vector v can only be None if the input to the user-provided function " + "is a single Tensor with a single element." + ) + outputs = func(*inputs) + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "hvp" + ) + _check_requires_grad(outputs, "outputs", strict=strict) + + if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor): + raise RuntimeError( + "The function given to hvp should return a single Tensor" + ) + + if outputs[0].nelement() != 1: + raise RuntimeError( + "The Tensor returned by the function given to hvp should contain a single element" + ) + + jac = _autograd_grad(outputs, inputs, create_graph=True) + _check_requires_grad(jac, "jacobian", strict=strict) + + grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs) + + double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True) + _check_requires_grad(jac, "hessian", strict=strict) + + enable_grad = True if create_graph else torch.is_grad_enabled() + with torch.set_grad_enabled(enable_grad): + grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph) + hvp = _fill_in_zeros( + grad_res, inputs, strict, create_graph, "double_back_trick" + ) + + outputs = _grad_postprocess(outputs, create_graph) + hvp = _grad_postprocess(hvp, create_graph) + + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + hvp, is_inputs_tuple + ) diff --git a/phivenv/Lib/site-packages/torch/autograd/grad_mode.py b/phivenv/Lib/site-packages/torch/autograd/grad_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..ebba640d69eb1b71ffe00850e25afc7416c6e3c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/grad_mode.py @@ -0,0 +1,404 @@ +# mypy: allow-untyped-defs +from typing import Any, Union + +import torch +from torch.utils._contextlib import ( + _DecoratorContextManager, + _NoParamDecoratorContextManager, + F, +) + + +__all__ = [ + "no_grad", + "enable_grad", + "set_grad_enabled", + "inference_mode", + "set_multithreading_enabled", +] + + +class no_grad(_NoParamDecoratorContextManager): + r"""Context-manager that 
disables gradient calculation. + + Disabling gradient calculation is useful for inference, when you are sure + that you will not call :meth:`Tensor.backward()`. It will reduce memory + consumption for computations that would otherwise have `requires_grad=True`. + + In this mode, the result of every computation will have + `requires_grad=False`, even when the inputs have `requires_grad=True`. + There is an exception! All factory functions, or functions that create + a new Tensor and take a requires_grad kwarg, will NOT be affected by + this mode. + + This context manager is thread local; it will not affect computation + in other threads. + + Also functions as a decorator. + + .. note:: + No-grad is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. + + .. note:: + This API does not apply to :ref:`forward-mode AD `. + If you want to disable forward AD for a computation, you can unpack + your dual tensors. + + Example:: + >>> # xdoctest: +SKIP + >>> x = torch.tensor([1.], requires_grad=True) + >>> with torch.no_grad(): + ... y = x * 2 + >>> y.requires_grad + False + >>> @torch.no_grad() + ... def doubler(x): + ... return x * 2 + >>> z = doubler(x) + >>> z.requires_grad + False + >>> @torch.no_grad() + ... def tripler(x): + ... return x * 3 + >>> z = tripler(x) + >>> z.requires_grad + False + >>> # factory function exception + >>> with torch.no_grad(): + ... a = torch.nn.Parameter(torch.rand(10)) + >>> a.requires_grad + True + """ + + def __init__(self) -> None: + if not torch._jit_internal.is_scripting(): + super().__init__() + self.prev = False + + def __enter__(self) -> None: + self.prev = torch.is_grad_enabled() + torch.set_grad_enabled(False) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + torch.set_grad_enabled(self.prev) + + +class enable_grad(_NoParamDecoratorContextManager): + r"""Context-manager that enables gradient calculation. + + Enables gradient calculation, if it has been disabled via :class:`~no_grad` + or :class:`~set_grad_enabled`. + + This context manager is thread local; it will not affect computation + in other threads. + + Also functions as a decorator. + + .. note:: + enable_grad is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. + + .. note:: + This API does not apply to :ref:`forward-mode AD `. + + Example:: + >>> # xdoctest: +SKIP + >>> x = torch.tensor([1.], requires_grad=True) + >>> with torch.no_grad(): + ... with torch.enable_grad(): + ... y = x * 2 + >>> y.requires_grad + True + >>> y.backward() + >>> x.grad + tensor([2.]) + >>> @torch.enable_grad() + ... def doubler(x): + ... return x * 2 + >>> with torch.no_grad(): + ... z = doubler(x) + >>> z.requires_grad + True + >>> @torch.enable_grad() + ... def tripler(x): + ... return x * 3 + >>> with torch.no_grad(): + ... z = tripler(x) + >>> z.requires_grad + True + + """ + + def __enter__(self) -> None: + self.prev = torch.is_grad_enabled() + torch._C._set_grad_enabled(True) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + torch._C._set_grad_enabled(self.prev) + + +class set_grad_enabled(_DecoratorContextManager): + r"""Context-manager that sets gradient calculation on or off. + + ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`. + It can be used as a context-manager or as a function. 
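+
+ As a small additional sketch (in the spirit of the example below), calling it
+ as a plain function switches grad mode immediately, without a ``with`` block::
+
+ >>> # xdoctest: +SKIP
+ >>> _ = torch.set_grad_enabled(False)
+ >>> (torch.ones(1, requires_grad=True) * 2).requires_grad
+ False
+ >>> _ = torch.set_grad_enabled(True)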
+ + This context manager is thread local; it will not affect computation + in other threads. + + Args: + mode (bool): Flag whether to enable grad (``True``), or disable + (``False``). This can be used to conditionally enable + gradients. + + .. note:: + set_grad_enabled is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. + + .. note:: + This API does not apply to :ref:`forward-mode AD `. + + Example:: + >>> # xdoctest: +SKIP + >>> x = torch.tensor([1.], requires_grad=True) + >>> is_train = False + >>> with torch.set_grad_enabled(is_train): + ... y = x * 2 + >>> y.requires_grad + False + >>> _ = torch.set_grad_enabled(True) + >>> y = x * 2 + >>> y.requires_grad + True + >>> _ = torch.set_grad_enabled(False) + >>> y = x * 2 + >>> y.requires_grad + False + + """ + + def __init__(self, mode: bool) -> None: + self.prev = torch.is_grad_enabled() + self.mode = mode + torch._C._set_grad_enabled(mode) + + def __call__(self, orig_func: F) -> F: + torch._C._set_grad_enabled(self.prev) + return super().__call__(orig_func) + + def __enter__(self) -> None: + torch._C._set_grad_enabled(self.mode) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + torch._C._set_grad_enabled(self.prev) + + def __str__(self) -> str: + return f"{torch.typename(self)}(mode={self.mode})" + + def __repr__(self) -> str: + return str(self) + + def clone(self) -> "set_grad_enabled": + r""" + Create a copy of this class + """ + return self.__class__(self.mode) + + +class inference_mode(_DecoratorContextManager): + r"""Context-manager that enables or disables inference mode. + + InferenceMode is a context manager analogous to :class:`~no_grad` + to be used when you are certain your operations will have no interactions + with autograd (e.g., model training). Code run under this mode gets better + performance by disabling view tracking and version counter bumps. Note that + unlike some other mechanisms that locally enable or disable grad, + entering inference_mode also disables to :ref:`forward-mode AD `. + + This context manager is thread local; it will not affect computation + in other threads. + + Also functions as a decorator. + + .. note:: + Inference mode is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. + + Args: + mode (bool or function): Either a boolean flag whether to enable or + disable inference mode or a Python function to decorate with + inference mode enabled + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> import torch + >>> x = torch.ones(1, 2, 3, requires_grad=True) + >>> with torch.inference_mode(): + ... y = x * x + >>> y.requires_grad + False + >>> # xdoctest: +SKIP("want string isnt quite right") + >>> y._version + Traceback (most recent call last): + File "", line 1, in + RuntimeError: Inference tensors do not track version counter. + >>> @torch.inference_mode() + ... def func(x): + ... return x * x + >>> out = func(x) + >>> out.requires_grad + False + >>> @torch.inference_mode() + ... def doubler(x): + ... 
return x * 2 + >>> out = doubler(x) + >>> out.requires_grad + False + + """ + + def __init__(self, mode: bool = True) -> None: + if not torch._jit_internal.is_scripting(): + super().__init__() + self.mode = mode + + def __new__(cls, mode=True): + if isinstance(mode, bool): + return super().__new__(cls) + return cls()(mode) + + def __enter__(self) -> None: + self._inference_mode_context = torch._C._InferenceMode(self.mode) + self._inference_mode_context.__enter__() + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self._inference_mode_context.__exit__(exc_type, exc_value, traceback) + + def clone(self) -> "inference_mode": + r""" + Create a copy of this class + """ + return self.__class__(self.mode) + + +def _enter_inference_mode(mode): + mode_context = torch._C._InferenceMode(mode) + mode_context.__enter__() + return mode_context + + +def _exit_inference_mode(mode): + mode.__exit__(None, None, None) + + +class set_multithreading_enabled(_DecoratorContextManager): + r"""Context-manager that sets multithreaded backwards on or off. + + ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`. + It can be used as a context-manager or as a function. + + This context manager is thread local; it will not affect computation + in other threads. + + Args: + mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable + (``False``). + + .. note:: + This API does not apply to :ref:`forward-mode AD `. + + """ + + def __init__(self, mode: bool) -> None: + self.prev = torch._C._is_multithreading_enabled() + torch._C._set_multithreading_enabled(mode) + self.mode = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + torch._C._set_multithreading_enabled(self.prev) + + def clone(self) -> "set_multithreading_enabled": + r""" + Create a copy of this class + """ + return self.__class__(self.mode) + + +class _force_original_view_tracking(_DecoratorContextManager): + r"""Context-manager that sets whether or not to always enable view-replay in autograd. + + ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`. + It can be used as a context-manager or as a function. + + This context manager is thread local; it will not affect computation + in other threads. + + When a tensor view is mutated, the autograd engine needs to decide whether or not + to regenerate the "updated view" by either replaying the chain of views from the updated base, + or with a single call to as_strided. + + If set_view_replay_enabled is set to True, then autograd will always use view replay. + Otherwise, it will fall back to its existing logic. + + Args: + mode (bool): Flag whether to enable view-replay (``True``), or disable + (``False``). + + """ + + def __init__(self, mode: bool) -> None: + self.prev = torch._C._is_view_replay_enabled() + torch._C._set_view_replay_enabled(mode) + self.mode = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + torch._C._set_view_replay_enabled(self.prev) + + def clone(self): + return self.__class__(self.mode) + + +class _unsafe_preserve_version_counter(_DecoratorContextManager): + r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING. + + This context manager can lead to arbitrary silent-correctness issues in any other part of your code + (even the ones not touched directly by the context manager)! 
+ + Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute. + This is generally important for correctness, as for example, mutating a tensor that autograd has saved + for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect + and error out in this situation. + + However, there are rare instances where it might be useful to hide mutations from autograd. For example: + if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate + the tensor right before it is needed by autograd. + + Args: + tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of. + + .. note:: + This API does not apply to :ref:`forward-mode AD `. + + """ + + def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None: + self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors + assert isinstance(self.tensors, tuple) + self.prev_versions = tuple(t._version for t in self.tensors) + + def __enter__(self) -> None: + pass + + def __exit__(self, *args) -> None: + torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions) diff --git a/phivenv/Lib/site-packages/torch/autograd/gradcheck.py b/phivenv/Lib/site-packages/torch/autograd/gradcheck.py new file mode 100644 index 0000000000000000000000000000000000000000..01b9a9d78b7ae820d68c38f750ef8eb4eef9c45c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/gradcheck.py @@ -0,0 +1,2273 @@ +# mypy: allow-untyped-defs +import collections +import functools +import warnings +from collections.abc import Iterable +from itertools import product +from typing import Callable, Optional, Union +from typing_extensions import deprecated + +import torch +import torch.testing +from torch._vmap_internals import _vmap, vmap +from torch.overrides import is_tensor_like +from torch.types import _TensorOrTensors + + +# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public +# since they have been exposed from before we added `__all__` and we already maintain BC for them +# We should eventually deprecate them and remove them from `__all__` +__all__ = [ + "gradcheck", + "gradgradcheck", + "GradcheckError", + "get_numerical_jacobian", + "get_analytical_jacobian", + "get_numerical_jacobian_wrt_specific_input", +] + + +class GradcheckError(RuntimeError): + r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`.""" + + +def _is_sparse_compressed_tensor(obj: torch.Tensor): + return obj.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + + +def _is_sparse_any_tensor(obj: torch.Tensor): + return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo + + +def _is_float_or_complex_tensor(obj): + return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex()) + + +def _allocate_jacobians_with_inputs( + input_tensors: tuple, numel_output +) -> tuple[torch.Tensor, ...]: + # Makes zero-filled tensors from inputs. If `numel_output` is not None, for + # each tensor in `input_tensors`, returns a new zero-filled tensor with height + # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns + # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have + # the same dtype and device as those of the corresponding input. 
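+ #
+ # Illustrative sketch (hypothetical inputs): a float tensor of numel 3 that
+ # requires grad yields one (3, 4) zero block for numel_output=4, while an
+ # integer tensor is skipped entirely.
+ # >>> a = torch.zeros(3, requires_grad=True)
+ # >>> b = torch.zeros(2, dtype=torch.int64)
+ # >>> _allocate_jacobians_with_inputs((a, b), 4)  # -> (zeros of shape (3, 4),)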
+ out: list[torch.Tensor] = [ + t.new_zeros((t.numel(), numel_output), layout=torch.strided) + for t in input_tensors + if _is_float_or_complex_tensor(t) and t.requires_grad + ] + return tuple(out) + + +def _allocate_jacobians_with_outputs( + output_tensors: tuple, numel_input, dtype=None, device=None +) -> tuple[torch.Tensor, ...]: + # Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor + # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and + # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size + # (t.numel,). + options = {"dtype": dtype, "device": device, "layout": torch.strided} + out: list[torch.Tensor] = [ + t.new_zeros((numel_input, t.numel()), **options) + for t in output_tensors + if _is_float_or_complex_tensor(t) + ] + return tuple(out) + + +def _iter_tensors( + x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False +) -> Iterable[torch.Tensor]: + if is_tensor_like(x): + # mypy doesn't narrow type of `x` to torch.Tensor + if x.requires_grad or not only_requiring_grad: # type: ignore[union-attr] + yield x # type: ignore[misc] + elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + for elem in x: + yield from _iter_tensors(elem, only_requiring_grad) + + +def _densify(x): + # return a copy of sparse x with all unspecified elements + # "replaced" with zero-valued elements + if isinstance(x, (list, tuple)): + return type(x)(map(_densify, x)) + elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}: # type: ignore[attr-defined] # no attr _mkldnn + return x + elif x.layout is torch.sparse_coo: + device = x.device + indices_dtype = x._indices().dtype + tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device) + indices = tmp.nonzero().t().to(dtype=indices_dtype) + values = torch.zeros( + (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device + ) + x_coalesced = x.detach().coalesce() + if x_coalesced.numel() > 0: + stride = tmp.stride() + flat_indices = ( + x_coalesced.indices() + .mul( + torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze( + 1 + ) + ) + .sum(0) + ) + values[flat_indices] = x_coalesced.values() + return ( + torch.sparse_coo_tensor(indices, values, x.shape) + ._coalesced_(True) + .requires_grad_(x.requires_grad) + ) + elif _is_sparse_compressed_tensor(x): + blocksize = ( + x.values().shape[1:3] + if x.layout in {torch.sparse_bsr, torch.sparse_bsc} + else None + ) + compressed_indices = ( + x.crow_indices() + if x.layout in {torch.sparse_csr, torch.sparse_bsr} + else x.ccol_indices() + ) + # We'll use intermediate sparse COO for simplicity + r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse( + layout=x.layout, blocksize=blocksize + ) + # Check that all elements are specified also after `to_sparse` op: + dense_numel = r.values().numel() // max(1, r.values().shape[0]) + batch_numel = compressed_indices.numel() // compressed_indices.shape[-1] + sparse_numel = r.numel() // max(1, dense_numel * batch_numel) + if sparse_numel != r._nnz(): + raise AssertionError( + f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}" + ) + return r.requires_grad_(x.requires_grad) + elif _is_sparse_any_tensor(x): + raise NotImplementedError(x.layout) + return x + + +def _iter_tensor(x_tensor): + # (Only used for slow gradcheck) Returns a generator that yields the following + # elements at each iteration: + # 1) a tensor: the same tensor is returned across all 
iterations. The tensor + # is not the same as the original x_tensor as given as input - it is + # prepared so that it can be modified in-place. Depending on whether the + # input tensor is strided, sparse, or dense, the returned tensor may or may + # not share storage with x_tensor. + # 2) a tuple of indices that can be used with advanced indexing (yielded in + # dictionary order) + # 3) flattened index that will be used to index into the Jacobian tensor + # + # For a tensor t with size (2, 2), _iter_tensor yields: + # `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3` + # + # where x is the t.data of the original tensor. Perturbing the entry of x + # at index (1, 1) yields the 3rd column of the overall Jacobian matrix. + if _is_sparse_any_tensor(x_tensor): + + def get_stride(size): + dim = len(size) + tmp = 1 + stride = [0] * dim + for i in reversed(range(dim)): + stride[i] = tmp + tmp *= size[i] + return stride + + x_nnz = x_tensor._nnz() + x_size = list(x_tensor.size()) + if x_tensor.layout is torch.sparse_coo: + x_indices = x_tensor._indices().t() + x_values = x_tensor._values() + elif x_tensor.layout is torch.sparse_csr: + x_indices = torch._convert_indices_from_csr_to_coo( + x_tensor.crow_indices(), x_tensor.col_indices() + ).t() + x_values = x_tensor.values() + elif x_tensor.layout is torch.sparse_csc: + x_indices = torch._convert_indices_from_csr_to_coo( + x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True + ).t() + x_values = x_tensor.values() + elif x_tensor.layout is torch.sparse_bsr: + x_block_values = x_tensor.values() + x_blocksize = x_block_values.size()[1:3] + x_indices = ( + torch._convert_indices_from_csr_to_coo( + x_tensor.crow_indices(), x_tensor.col_indices() + ) + .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) + .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) + .add_( + torch.stack( + torch.where(torch.ones(x_blocksize, device=x_tensor.device)) + ).repeat(1, x_nnz) + ) + .t() + ) + x_values = x_block_values.flatten(0, 2) + x_nnz = x_values.size(0) + elif x_tensor.layout is torch.sparse_bsc: + x_block_values = x_tensor.values() + x_blocksize = x_block_values.size()[1:3] + x_indices = ( + torch._convert_indices_from_csr_to_coo( + x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True + ) + .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) + .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) + .add_( + torch.stack( + torch.where(torch.ones(x_blocksize, device=x_tensor.device)) + ).repeat(1, x_nnz) + ) + .t() + ) + x_values = x_block_values.flatten(0, 2) + x_nnz = x_values.size(0) + else: + raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input") + x_stride = get_stride(x_size) + # Use .data here to get around the version check + x_values = x_values.data + for i in range(x_nnz): + x_value = x_values[i] + for x_idx in product(*[range(m) for m in x_values.size()[1:]]): + indices = x_indices[i].tolist() + list(x_idx) + d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size))) + yield x_value, x_idx, d_idx + elif x_tensor.layout == torch._mkldnn: # type: ignore[attr-defined] + for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): + # this is really inefficient, but without indexing implemented, there's + # not really a better way than converting back and forth + x_tensor_dense = x_tensor.to_dense() + yield x_tensor_dense, x_idx, d_idx + else: + # Use .data here to get around the version check + x_tensor = x_tensor.data + for d_idx, x_idx in 
enumerate(product(*[range(m) for m in x_tensor.size()])): + yield x_tensor, x_idx, d_idx + + +def _get_numerical_jacobian( + fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False +) -> list[tuple[torch.Tensor, ...]]: + """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`. + + If not specified, targets are the input. Returns M * N Jacobians where N is the + number of tensors in target that require grad and M is the number of non-integral + outputs. + + Args: + fn: the function to compute the jacobian for + inputs: inputs to `fn` + outputs: provide precomputed outputs to avoid one extra invocation of fn + target: the Tensors wrt whom Jacobians are calculated (default=`inputs`) + eps: the magnitude of the perturbation during finite differencing + (default=`1e-3`) + is_forward_ad: if this numerical jacobian is computed to be checked wrt + forward AD gradients (this is used for error checking only) + + Returns: + A list of M N-tuples of tensors + + Note that `target` may not even be part of `input` to `fn`, so please be + **very careful** in this to not clone `target`. + """ + jacobians: list[tuple[torch.Tensor, ...]] = [] + if outputs is None: + outputs = _as_tuple(fn(*_as_tuple(inputs))) + if not is_forward_ad and any(o.is_complex() for o in outputs): + raise ValueError( + "Expected output to be non-complex. get_numerical_jacobian no " + "longer supports functions that return complex outputs." + ) + if target is None: + target = inputs + inp_indices = [ + i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad + ] + for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)): + jacobians += [ + get_numerical_jacobian_wrt_specific_input( + fn, + inp_idx, + inputs, + outputs, + eps, + input=inp, + is_forward_ad=is_forward_ad, + ) + ] + return jacobians + + +@deprecated( + "`get_numerical_jacobian` was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new", + category=FutureWarning, +) +def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): + """Compute the numerical Jacobian for a given fn and its inputs. + + This is a Deprecated API. + + Args: + fn: the function to compute the Jacobian for (must take inputs as a tuple) + inputs: input to `fn` + target: the Tensors wrt whom Jacobians are calculated (default=`input`) + eps: the magnitude of the perturbation during finite differencing + (default=`1e-3`) + grad_out: defaults to 1.0. + + Returns: + A list of Jacobians of `fn` (restricted to its first output) with respect to + each input or target, if provided. + + Note that `target` may not even be part of `input` to `fn`, so please be + **very careful** in this to not clone `target`. + """ + if ( + grad_out != 1.0 + ): # grad_out param is only kept for backward compatibility reasons + raise ValueError( + "Expected grad_out to be 1.0. get_numerical_jacobian no longer " + "supports values of grad_out != 1.0." 
+ ) + + def fn_pack_inps(*inps): + return fn(inps) + + jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps) + + return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians) + + +def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn): + # Computes numerical directional derivative as finite difference + # of function `fn` at input `entry`, perturbed by vector `v`. + if _is_sparse_compressed_tensor(entry): + # sparse compressed tensors don't implement sub/add/copy_ + # yet. However, in non-masked semantics context entry and v + # have the same sparse indices ... + assert entry.layout == v.layout, (entry.layout, v.layout) + assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape) + # ... the finite differencing can be performed on values only: + entry = entry.values() + v = v.values() + # we'll detach to avoid backward computations that sparse + # tensors have limited support for. + entry = entry.detach() + + orig = entry.clone() + entry.copy_(orig - v) + outa = fn() + entry.copy_(orig + v) + outb = fn() + entry.copy_(orig) + + def compute(a, b): + nbhd_checks_fn(a, b) + ret = (b - a) / (2 * norm_v) # use central difference approx + return ret.detach().reshape(-1) + + return tuple(compute(a, b) for (a, b) in zip(outa, outb)) + + +def _compute_numerical_jvps_wrt_specific_input( + jvp_fn, delta, input_is_complex, is_forward_ad=False +) -> list[torch.Tensor]: + # Computing the jacobian only works for real delta + # For details on the algorithm used here, refer: + # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf + # s = fn(z) where z = x for real valued input + # and z = x + yj for complex valued input + jvps: list[torch.Tensor] = [] + ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta) + + if input_is_complex: # C -> R + ds_dy_tup = ( + jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j) + ) + for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup): + assert not ds_dx.is_complex() + # conjugate wirtinger derivative + conj_w_d = ds_dx + ds_dy * 1j + jvps.append(conj_w_d) + else: + for ds_dx in ds_dx_tup: # R -> R or (R -> C for the forward AD case) + assert is_forward_ad or not ds_dx.is_complex() + jvps.append(ds_dx) + return jvps + + +def _combine_jacobian_cols( + jacobians_cols: dict[int, list[torch.Tensor]], outputs, input, numel +) -> tuple[torch.Tensor, ...]: + # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor + # we return a list that maps output_idx -> full jacobian Tensor + jacobians = _allocate_jacobians_with_outputs( + outputs, numel, dtype=input.dtype if input.dtype.is_complex else None + ) + for i, jacobian in enumerate(jacobians): + for k, v in jacobians_cols.items(): + jacobian[k] = v[i] + return jacobians + + +def _prepare_input( + input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False +) -> torch.Tensor: + # Prepares the inputs to be passed into the function while including the new + # modified input. 
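+ # Rough illustration of the branches below (here `p` is a hypothetical perturbed
+ # copy of the input produced elsewhere in this file, e.g. by _get_input_to_perturb):
+ #   strided input -> the original input is returned; `p` shares its storage, so
+ #                    in-place perturbations are already visible to `fn`
+ #   mkldnn input  -> `p` is a dense tensor and is converted back via `p.to_mkldnn()`
+ #   sparse input  -> in fast mode `p` is a clone of the input, so `p` itself is returned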
+ if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn + # Convert back to mkldnn + if maybe_perturbed_input is not None: + return maybe_perturbed_input.to_mkldnn() + else: + return input + elif _is_sparse_any_tensor(input): + if fast_mode and maybe_perturbed_input is not None: + # entry is already a "cloned" version of the original tensor + # thus changes to entry are not reflected in the input + return maybe_perturbed_input + else: + return input + else: + # We cannot use entry (input.data) if we want gradgrad to work because + # fn (in the gradgrad case) needs to compute grad wrt input + return input + + +def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None: + # Check that the returned outputs don't have different dtype or shape when you + # perturb the input + on_index = "on index {idx} " if idx is not None else "" + assert output1.shape == output2.shape, ( + f"Expected `func` to return outputs with the same shape" + f" when inputs are perturbed {on_index}by {eps}, but got:" + f" shapes {output1.shape} and {output2.shape}." + ) + assert output1.dtype == output2.dtype, ( + f"Expected `func` to return outputs with the same dtype" + f" when inputs are perturbed {on_index}by {eps}, but got:" + f" dtypes {output1.dtype} and {output2.dtype}." + ) + + +def get_numerical_jacobian_wrt_specific_input( + fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False +) -> tuple[torch.Tensor, ...]: + # Computes the numerical jacobians wrt to a single input. Returns N jacobian + # tensors, where N is the number of outputs. We use a dictionary for + # jacobian_cols because indices aren't necessarily consecutive for sparse inputs + # When we perturb only a single element of the input tensor at a time, the jvp + # is equivalent to a single col of the Jacobian matrix of fn. + jacobian_cols: dict[int, list[torch.Tensor]] = {} + input = inputs[input_idx] if input is None else input + assert input.requires_grad + for x, idx, d_idx in _iter_tensor(input): + wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x) + input_to_perturb = x[idx] + nbhd_checks_fn = functools.partial( + _check_outputs_same_dtype_and_shape, idx=idx, eps=eps + ) + jvp_fn = _get_numerical_jvp_fn( + wrapped_fn, input_to_perturb, eps, nbhd_checks_fn + ) + jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input( + jvp_fn, eps, x.is_complex(), is_forward_ad + ) + return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel()) + + +def _get_analytical_jacobian_forward_ad( + fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None +) -> tuple[tuple[torch.Tensor, ...], ...]: + """Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`. + + Return N * M Jacobians where N is the number of tensors in target that require grad and + M is the number of non-integral outputs. + Contrary to other functions here, this function requires "inputs" to actually be used by the function. + The computed value is expected to be wrong if the function captures the inputs by side effect instead of + using the passed ones (many torch.nn tests do this). 
+ + Args: + fn: the function to compute the jacobian for + inputs: inputs to `fn` + outputs: provide precomputed outputs to avoid one extra invocation of fn + check_grad_dtypes: if True, will check that the gradient dtype are valid + all_u (optional): if provided, the Jacobian will be right multiplied with this vector + + Returns: + A tuple of M N-tuples of tensors + """ + # To avoid early import issues + fwAD = torch.autograd.forward_ad + + tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad) + + if any(i.is_complex() for i in tensor_inputs): + raise ValueError( + "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad." + ) + + if all_u: + jacobians = tuple( + _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs + ) + else: + jacobians = tuple( + _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs + ) + + with fwAD.dual_level(): + fw_grads = [] + dual_inputs = [] + for i, inp in enumerate(inputs): + if is_tensor_like(inp) and inp.requires_grad: + if inp.layout == torch._mkldnn: # type: ignore[attr-defined] + raise ValueError( + "MKLDNN inputs are not support for forward AD gradcheck." + ) + + inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) + # If inp is a differentiable view, the dual might not be the tangent given to + # make_dual, so read it explicitly from the dual tensor + fw_grads.append(fwAD.unpack_dual(inp)[1]) + dual_inputs.append(inp) + + if all_u: + # Do the full reduction in one pass + # To be consistent with numerical evaluation, we actually compute one reduction per input + for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)): + fw_grad.copy_(u.view_as(fw_grad)) + raw_outputs = _as_tuple(fn(*dual_inputs)) + dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs) + for index_o, d_o in enumerate(dual_outputs): + val, res = fwAD.unpack_dual(d_o) + if ( + check_grad_dtypes + and res is not None + and val.is_complex() != res.is_complex() + ): + raise GradcheckError("Forward AD gradient has dtype mismatch.") + + # Remove extra dimension of size 1 corresponding to the reduced input + jacobians[i][index_o].squeeze_(0) + if res is None: + jacobians[i][index_o].zero_() + else: + jacobians[i][index_o].copy_(res.reshape(-1)) + fw_grad.zero_() + else: + # Reconstruct the full Jacobian column by column + for i, fw_grad in enumerate(fw_grads): + for lin_idx, grad_idx in enumerate( + product(*[range(m) for m in fw_grad.size()]) + ): + fw_grad[grad_idx] = 1.0 + raw_outputs = _as_tuple(fn(*dual_inputs)) + dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs) + for index_o, d_o in enumerate(dual_outputs): + val, res = fwAD.unpack_dual(d_o) + if ( + check_grad_dtypes + and res is not None + and val.is_complex() != res.is_complex() + ): + raise GradcheckError( + "Forward AD gradient has dtype mismatch." + ) + + if res is None: + jacobians[i][index_o][lin_idx].zero_() + else: + jacobians[i][index_o][lin_idx].copy_(res.reshape(-1)) + fw_grad[grad_idx] = 0.0 + + return jacobians + + +def _get_input_to_perturb(input): + # Prepare the input so that it can be modified in-place and do certain + # operations that require the tensor to have strides. 
If fast_mode=False, + # _iter_tensor would handle the below cases: + if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn + # Convert to dense so we can perform operations that require strided tensors + input_to_perturb = input.to_dense() + elif _is_sparse_any_tensor(input): + # Clone because input may require grad, and copy_ calls resize_, + # which is not allowed for .data + input_to_perturb = input.clone() + else: + input_to_perturb = input.data + return input_to_perturb + + +def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False): + # Wraps `fn` so that its inputs are already supplied + def wrapped_fn(): + inp = tuple( + _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode) + if is_tensor_like(a) + else a + for i, a in enumerate(_as_tuple(inputs)) + ) + return tuple(a.clone() for a in _as_tuple(fn(*inp))) + + return wrapped_fn + + +def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn): + # Wraps jvp_fn so that certain arguments are already supplied + def jvp_fn(delta): + return _compute_numerical_gradient( + wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn + ) + + return jvp_fn + + +def _reshape_tensor_or_tuple(u, shape): + # We don't need to reshape when input corresponding to u is sparse + if isinstance(u, tuple): + if not _is_sparse_any_tensor(u[0]): + return (u[0].reshape(shape), u[1].reshape(shape)) + else: + if not _is_sparse_any_tensor(u): + return u.reshape(shape) + return u + + +def _mul_tensor_or_tuple(u, k): + if isinstance(u, tuple): + return (k * u[0], k * u[1]) + else: + return k * u + + +def _get_numerical_jvp_wrt_specific_input( + fn, input_idx, inputs, u, eps, is_forward_ad=False +) -> list[torch.Tensor]: + input = inputs[input_idx] + input_to_perturb = _get_input_to_perturb(input) + wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True) + nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps) + jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn) + u = _reshape_tensor_or_tuple(u, input_to_perturb.shape) + u = _mul_tensor_or_tuple(u, eps) + return _compute_numerical_jvps_wrt_specific_input( + jvp_fn, u, input.is_complex(), is_forward_ad + ) + + +def _get_numerical_vJu( + fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad +): + # Note that all_v can also be None, in that case, this function only computes Ju. + reduced_jacobians: list[list[torch.Tensor]] = [] + for inp_idx, u in zip(inp_indices, all_u): + all_Ju = _get_numerical_jvp_wrt_specific_input( + fn, inp_idx, inputs, u, eps, is_forward_ad + ) + # Filter out the Ju for non floating point outputs + filtered_Ju = [] + func_out = _as_tuple(func_out) + assert len(all_Ju) == len(func_out) + for Ju, output in zip(all_Ju, func_out): + if _is_float_or_complex_tensor(output): + filtered_Ju.append(Ju) + else: + # TODO: handle the other Ju + pass + if all_v is not None: + jacobian_scalars: list[torch.Tensor] = [] + for v, Ju in zip(all_v, filtered_Ju): + jacobian_scalars.append(_dot_with_type_promotion(v, Ju)) + reduced_jacobians.append(jacobian_scalars) + else: + reduced_jacobians.append(filtered_Ju) + return reduced_jacobians + + +def _check_jacobians_equal(j1, j2, atol): + # Check whether the max difference between two Jacobian tensors are within some + # tolerance `atol`. 
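+ # For example, with atol=1e-3 and j1 = (tensor([1.0, 2.0]),):
+ #   j2 = (tensor([1.0, 2.0005]),) -> True  (max abs difference 5e-4 <= atol)
+ #   j2 = (tensor([1.0, 2.0100]),) -> False (max abs difference 1e-2  > atol)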
+ for j1_x, j2_x in zip(j1, j2): + if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol: + return False + return True + + +def _stack_and_check_tensors( + list_of_list_of_tensors, inputs, numel_outputs +) -> tuple[tuple[torch.Tensor, ...], bool, bool]: + # For the ith tensor in the inner list checks whether it has the same size and + # dtype as the ith differentiable input. + out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs) + diff_input_list = list(_iter_tensors(inputs, True)) + correct_grad_sizes = True + correct_grad_types = True + for i, tensor_list in enumerate(list_of_list_of_tensors): + inp = diff_input_list[i] + out_jacobian = out_jacobians[i] + for j, tensor in enumerate(tensor_list): + if tensor is not None and tensor.size() != inp.size(): + correct_grad_sizes = False + elif tensor is not None and tensor.dtype != inp.dtype: + correct_grad_types = False + if tensor is None: + out_jacobian[:, j].zero_() + else: + dense = ( + tensor.to_dense() if not tensor.layout == torch.strided else tensor + ) + assert out_jacobian[:, j].numel() == dense.numel() + out_jacobian[:, j] = dense.reshape(-1) + return out_jacobians, correct_grad_sizes, correct_grad_types + + +FAILED_NONDET_MSG = """\n +NOTE: If your op relies on non-deterministic operations i.e., it is listed here: +https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html +this failure might be expected. + +If you are adding a new operator, please file an issue and then use one of the +workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck. +If the test +- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck + with `nondet_tol=` as a keyword argument. +- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test + to have `gradcheck_nondet_tol=`. +- is a Module test (e.g., in common_nn.py), then modify the corresponding + module_test entry to have `gradcheck_nondet_tol=` +""" + + +def _check_analytical_jacobian_attributes( + inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None +) -> tuple[torch.Tensor, ...]: + # This is used by both fast and slow mode: + # - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith + # input. 
+ # - For fast mode, vjps[i][0] is a linear combination of the rows + # of the Jacobian wrt the ith input + diff_input_list = list(_iter_tensors(inputs, True)) + + def vjp_fn(grad_output): + return torch.autograd.grad( + output, diff_input_list, grad_output, retain_graph=True, allow_unused=True + ) + + # Compute everything twice to check for nondeterminism (which we call reentrancy) + if fast_mode: + vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v) + vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v) + else: + vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) + vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) + + output_numel = output.numel() if not fast_mode else 1 + jacobians1, types_ok, sizes_ok = _stack_and_check_tensors( + vjps1, inputs, output_numel + ) + jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel) + reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol) + + if not types_ok and check_grad_dtypes: + raise GradcheckError("Gradient has dtype mismatch") + if not sizes_ok: + raise GradcheckError("Analytical gradient has incorrect size") + if not reentrant: + raise GradcheckError( + "Backward is not reentrant, i.e., running backward with " + "same input and grad_output multiple times gives different values, " + "although analytical gradient matches numerical gradient." + f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG + ) + return jacobians1 + + +def _get_analytical_vJu_backward_mode( + inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u +): + reduced_jacobians: list[list[torch.Tensor]] = [] + for output, v in zip(outputs, all_v): + all_vJ = _check_analytical_jacobian_attributes( + inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v + ) + jacobian_scalars: list[torch.Tensor] = [] + for vJ, u in zip(all_vJ, all_u): + # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse + # the error checking logic from slow mode + vJ = vJ.T.squeeze(0) + if vJ.is_complex(): # C -> R + tv = torch.view_as_real(vJ.resolve_conj()) + tr = tv.select(-1, 0) + ti = tv.select(-1, 1) + jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1])) + else: # R -> R + jacobian_scalars.append(vJ.dot(u)) + reduced_jacobians.append(jacobian_scalars) + return reduced_jacobians + + +@deprecated( + "`get_analytical_jacobian` was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new", + category=FutureWarning, +) +def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0): + # Replicates the behavior of the old get_analytical_jacobian before the refactor + # This shares much of its code with _check_analytical_jacobian_attributes + if ( + grad_out != 1.0 + ): # grad_out param is only kept for backward compatibility reasons + raise ValueError( + "Expected grad_out to be 1.0. get_analytical_jacobian no longer " + "supports values of grad_out != 1.0." + ) + if output.is_complex(): + raise ValueError( + "Expected output to be non-complex. get_analytical_jacobian no " + "longer supports functions that return complex outputs." 
+ ) + diff_input_list = list(_iter_tensors(inputs, True)) + + def vjp_fn(grad_output): + return torch.autograd.grad( + output, diff_input_list, grad_output, retain_graph=True, allow_unused=True + ) + + # Compute everything twice to check for nondeterminism (which we call reentrancy) + vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) + vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) + + output_numel = output.numel() + jacobians1, types_ok, sizes_ok = _stack_and_check_tensors( + vjps1, inputs, output_numel + ) + jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel) + reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol) + + return jacobians1, reentrant, sizes_ok, types_ok + + +def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx): + # Computes the analytical Jacobian in slow mode for a single input-output pair. + # Forgoes performing checks on dtype, shape, and reentrancy. + jacobians = _check_analytical_jacobian_attributes( + inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False + ) + return jacobians[input_idx] + + +def _compute_analytical_jacobian_rows( + vjp_fn, sample_output +) -> list[list[Optional[torch.Tensor]]]: + # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis + # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian. + # NB: this function does not assume vjp_fn(v) to return tensors with the same + # number of elements for different v. This is checked when we later combine the + # rows into a single tensor. + grad_out_base = torch.zeros_like( + sample_output, memory_format=torch.legacy_contiguous_format + ) + flat_grad_out = grad_out_base.view(-1) + # jacobians_rows[i][j] is the Jacobian jth row for the ith input + jacobians_rows: list[list[Optional[torch.Tensor]]] = [] + for j in range(flat_grad_out.numel()): + flat_grad_out.zero_() + flat_grad_out[j] = 1.0 # projection for jth row of Jacobian + grad_inputs = vjp_fn(grad_out_base) + for i, d_x in enumerate(grad_inputs): + if j == 0: + jacobians_rows.append([]) + jacobians_rows[i] += [ + d_x.clone() if isinstance(d_x, torch.Tensor) else None + ] + return jacobians_rows + + +def _get_analytical_vjps_wrt_specific_output( + vjp_fn, sample_output, v +) -> list[list[Optional[torch.Tensor]]]: + grad_inputs = vjp_fn(v.reshape(sample_output.shape)) + vjps: list[list[Optional[torch.Tensor]]] = [ + [vjp.clone() if isinstance(vjp, torch.Tensor) else None] for vjp in grad_inputs + ] + return vjps + + +def _check_inputs(tupled_inputs) -> bool: + # Make sure that gradients are saved for at least one input + any_input_requiring_grad = False + for idx, inp in enumerate(tupled_inputs): + if is_tensor_like(inp) and inp.requires_grad: + if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128): + warnings.warn( + f"Input #{idx} requires gradient and " + "is not a double precision floating point or complex. " + "This check will likely fail if all the inputs are " + "not of double precision floating point or complex. " + ) + if inp.is_sparse: + content = inp._values() + elif _is_sparse_compressed_tensor(inp): + content = inp.values() + else: + content = inp + # TODO: To cover more problematic cases, replace stride = 0 check with + # "any overlap in memory" once we have a proper function to check it. 
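+ # For example, torch.zeros(1, dtype=torch.float64, requires_grad=True).expand(3)
+ # has size (3,) and stride (0,): all three indices alias a single storage element,
+ # so perturbing one "entry" during finite differencing would silently change the
+ # other two as well; the check below rejects such inputs.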
+ if content.layout is not torch._mkldnn: # type: ignore[attr-defined] + if not all( + st > 0 or sz <= 1 + for st, sz in zip(content.stride(), content.size()) + ): + raise RuntimeError( + f"The {idx}th input has a dimension with stride 0. gradcheck only " + "supports inputs that are non-overlapping to be able to " + "compute the numerical gradients correctly. You should call " + ".contiguous on the input before passing it to gradcheck." + ) + any_input_requiring_grad = True + + if not any_input_requiring_grad: + raise ValueError( + "gradcheck expects at least one input tensor to require gradient, " + "but none of the them have requires_grad=True." + ) + return True + + +def _check_outputs(outputs) -> None: + if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)): + # it is easier to call to_dense() on the sparse output than + # to modify analytical jacobian + raise ValueError( + "Sparse output is not supported at gradcheck yet. " + "Please call to_dense(masked_grad=...) on the output of fn for gradcheck." + ) + if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined] + raise ValueError( + "MKLDNN output is not supported at gradcheck yet. " + "Please call to_dense(masked_grad=...) on the output of fn for gradcheck." + ) + + +def _check_no_differentiable_outputs( + func, inputs, func_out, eps, *, is_forward_ad +) -> bool: + # When there are no differentiable outputs, numerical gradient for a function is + # expected to be zero. + jacobians_all_inputs_outputs = _get_numerical_jacobian( + func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad + ) + for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs: + for jacobian in jacobians_all_outputs_and_fixed_input: + if torch.ne(jacobian, 0).sum() > 0: + raise GradcheckError( + "Numerical gradient for function expected to be zero" + ) + return True + + +def _check_no_differentiable_outputs_fast( + func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol +): + for inp_idx, u in zip(inputs_indices, all_u): + jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps) + for jvp in jvps: + if jvp.numel() == 0: + continue + if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol: + raise GradcheckError( + "Numerical gradient for function expected to be zero" + ) + return True + + +FAILED_BATCHED_GRAD_MSG = """ +gradcheck or gradgradcheck failed while testing batched gradient computation. +This could have been invoked in a number of ways (via a test that calls +gradcheck/gradgradcheck directly or via an autogenerated test). + +If you are adding a new operator, please file an issue and then use one of the +workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck. +If the test +- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck + with `check_batched_grad=False` as a keyword argument. +- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test + to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`. + +If you're modifying an existing operator that supports batched grad computation, +or wish to make a new operator work with batched grad computation, please read +the following. + +To compute batched grads (e.g., jacobians, hessians), we vmap over the backward +computation. The most common failure case is if there is a 'vmap-incompatible +operation' in the backward pass. 
Please see +NOTE: [How to write vmap-compatible backward formulas] +in the codebase for an explanation of how to fix this. +""".strip() + +FAILED_BATCHED_GRAD_MSG_FWD_AD = """ +gradcheck failed while testing batched gradient computation with forward-mode AD. +This test is enabled automatically when both `check_batched_grad=True` +and `check_forward_ad=True`, but can be disabled in the following ways +dependong on how the test was invoked (via a test that calls gradcheck +directly or via an autogenerated test). + +If you are adding a new operator, please file an issue and then use one of the +workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck. +If the test +- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck + with `check_batched_forward_grad=False` as a keyword argument. +- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test + to have `check_batched_forward_grad=False` +""" + + +def _get_failed_batched_grad_test_msg( + output_idx, input_idx, res, exp, is_forward_ad=False +): + return f""" +For output {output_idx} and input {input_idx}: + +{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG} + +Got: +{res} + +Expected: +{exp} +""".strip() + + +def _test_batched_grad_forward_ad(func, inputs) -> bool: + fwAD = torch.autograd.forward_ad # To avoid early import issues (do we need this?) + assert isinstance(inputs, tuple) + + for input_idx, current_input in enumerate(inputs): + if not (is_tensor_like(current_input) and current_input.requires_grad): + continue + + def jvp(tangent: torch.Tensor): + with fwAD.dual_level(): + dual = fwAD.make_dual(current_input.detach(), tangent) + inputs_with_dual = tuple( + dual + if idx == input_idx + else (inp.detach() if is_tensor_like(inp) else inp) + for idx, inp in enumerate(inputs) + ) + dual_outputs = _as_tuple(func(*inputs_with_dual)) + ret = [] + for dual_output in dual_outputs: + if dual_output is None: + continue + primal_out, tangent_out = fwAD.unpack_dual(dual_output) + if tangent_out is not None: + ret.append(tangent_out) + else: + ret.append( + torch.zeros( + [], dtype=primal_out.dtype, device=primal_out.device + ).expand(primal_out.shape) + ) + return tuple(ret) + + if not _is_float_or_complex_tensor(current_input): + continue + + tangents = [torch.randn_like(current_input) for _ in range(2)] + expected = [jvp(t) for t in tangents] + expected = [torch.stack(shards) for shards in zip(*expected)] + + try: + result = _vmap(jvp)(torch.stack(tangents)) + except RuntimeError as ex: + # Rethrow to provide a better error message + raise GradcheckError( + f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}" + ) from ex + + for input_idx, (res, exp) in enumerate(zip(result, expected)): + if torch.allclose(res, exp): + continue + raise GradcheckError( + _get_failed_batched_grad_test_msg( + input_idx, input_idx, res, exp, is_forward_ad=True + ) + ) + return True + + +def _test_batched_grad(input, output, output_idx) -> bool: + # NB: _test_batched_grad compares two autograd.grad invocations with a single + # vmap(autograd.grad) invocation. 
It's not exactly a "gradcheck" in the + # sense that we're not comparing an analytical jacobian with a numeric one, + # but it is morally similar (we could have computed a full analytic jac + # via vmap, but that is potentially slow) + diff_input_list = list(_iter_tensors(input, True)) + grad = functools.partial( + torch.autograd.grad, + output, + diff_input_list, + retain_graph=True, + allow_unused=True, + ) + + def vjp(v): + results = grad(v) + results = tuple( + grad + if grad is not None + else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape) + for grad, inp in zip(results, diff_input_list) + ) + return results + + grad_outputs = [torch.randn_like(output) for _ in range(2)] + + expected = [vjp(gO) for gO in grad_outputs] + expected = [torch.stack(shards) for shards in zip(*expected)] + + # Squash warnings since these are expected to happen in most cases + # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="There is a performance drop") + warnings.filterwarnings("ignore", message="Please use `torch.vmap`") + try: + result = vmap(vjp)(torch.stack(grad_outputs)) + except RuntimeError as ex: + # It's OK that we're not raising the error at the correct callsite. + # That's because the callsite is always going to inside the Python + # autograd.grad instead of the C++ traceback of what line in the + # backward formula + raise GradcheckError( + f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}" + ) from ex + + for input_idx, (res, exp) in enumerate(zip(result, expected)): + if torch.allclose(res, exp): + continue + raise GradcheckError( + _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp) + ) + return True + + +def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool: + # Tests that backward is multiplied by grad_output + diff_input_list: list[torch.Tensor] = list(_iter_tensors(inputs, True)) + if not diff_input_list: + raise GradcheckError("no Tensors requiring grad found in input") + grads_input = torch.autograd.grad( + outputs, + diff_input_list, + [ + torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) + for o in outputs + ], + allow_unused=True, + ) + for gi, di in zip(grads_input, diff_input_list): + if gi is None: + continue + if isinstance(gi, torch.Tensor) and gi.layout != torch.strided: + if gi.layout != di.layout: + raise GradcheckError( + "grad is incorrect layout (" + + str(gi.layout) + + " is not " + + str(di.layout) + + ")" + ) + if _is_sparse_any_tensor(gi): + sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "") + if gi.sparse_dim() != di.sparse_dim(): + raise GradcheckError( + f"grad is {sparse_kind} tensor, but has incorrect sparse_dim" + f" {gi.sparse_dim()}, expected {di.sparse_dim()}" + ) + if gi.dense_dim() != di.dense_dim(): + raise GradcheckError( + f"grad is {sparse_kind} tensor, but has incorrect dense_dim" + f" {gi.dense_dim()}, expected {di.dense_dim()}" + ) + gi = gi.to_dense() + di = di.to_dense() + if masked: + if not torch.allclose(gi, torch.zeros_like(gi)): + raise GradcheckError("backward not multiplied by grad_output") + elif not gi.eq(0).all(): + raise GradcheckError("backward not multiplied by grad_output") + if gi.dtype != di.dtype: + raise GradcheckError("grad is incorrect type") + if gi.device != di.device: + raise GradcheckError("grad is incorrect device") + if gi.size() != di.size(): + raise GradcheckError("grad is incorrect size") + 
return True + + +def _test_undefined_forward_mode(func, outputs, inputs): + fwAD = torch.autograd.forward_ad + + _inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) + _all_v, all_u, _all_u_dense = _make_vectors( + inp_tensors, outputs, use_forward_ad=True + ) + + with fwAD.dual_level(): + fw_grads = [] + dual_inputs = [] + tensor_indices = set() + for i, inp in enumerate(inputs): + if is_tensor_like(inp) and inp.requires_grad: + if inp.layout == torch._mkldnn: # type: ignore[attr-defined] + raise ValueError( + "MKLDNN inputs are not support for forward AD gradcheck." + ) + + inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) + # If inp is a differentiable view, the dual might not be the tangent given to + # make_dual, so read it explicitly from the dual tensor + fw_grads.append(fwAD.unpack_dual(inp)[1]) + tensor_indices.add(i) + dual_inputs.append(inp) + + for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)): + fw_grad.copy_(u.view_as(fw_grad)) + + for idx, inp in enumerate(inputs): + if idx not in tensor_indices: + continue + dual_inp_obj = dual_inputs[idx] + + # case 1 (Materialized Zero Tensor Tangent) + dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) + raw_outputs = _as_tuple(func(*dual_inputs)) + dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs) + + # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor) + dual_inputs[idx] = inp.detach() + raw_outputs = _as_tuple(func(*dual_inputs)) + dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs) + + # reset + dual_inputs[idx] = dual_inp_obj + + for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)): + _val1, res1 = fwAD.unpack_dual(d_o1) + _val2, res2 = fwAD.unpack_dual(d_o2) + + if not (res1 is None or res2 is None): + if not torch.allclose(res1, res2): + raise GradcheckError( + "Mismatch in tangent values for output with index: ", + index_o, + " when input: ", + inp, + " has an undefined tangent value. ", + " Got: ", + res1, + " but expected: ", + res2, + ) + return True + + +def _test_undefined_backward_mode(func, outputs, inputs) -> bool: + diff_input_list: list[torch.Tensor] = list(_iter_tensors(inputs, True)) + if not diff_input_list: + raise GradcheckError("no Tensors requiring grad found in input") + + def warn_bc_breaking(): + warnings.warn( + "Backwards compatibility: New undefined gradient support checking " + "feature is enabled by default, but it may break existing callers " + "of this function. If this is true for you, you can call this " + 'function with "check_undefined_grad=False" to disable the feature' + ) + + def check_undefined_grad_support(output_to_check): + grads_output = [ + torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) + for o in output_to_check + ] + try: + grads_input = torch.autograd.grad( + output_to_check, diff_input_list, grads_output, allow_unused=True + ) + except RuntimeError as e: + warn_bc_breaking() + raise GradcheckError( + "Expected backward function to handle undefined output grads. " + 'Please look at "Notes about undefined output gradients" in ' + '"tools/autograd/derivatives.yaml"' + ) from e + + for gi in grads_input: + if (gi is not None) and (not gi.eq(0).all()): + warn_bc_breaking() + raise GradcheckError( + "Expected all input grads to be undefined or zero when all output grads are undefined " + 'or zero. 
Please look at "Notes about undefined output gradients" in ' + '"tools/autograd/derivatives.yaml"' + ) + return True + + # All backward functions must work properly if all output grads are undefined + outputs_to_check = [ + [ + torch._C._functions.UndefinedGrad()(o) + for o in _differentiable_outputs(func(*inputs)) + # This check filters out Tensor-likes that aren't instances of Tensor. + if isinstance(o, torch.Tensor) + ] + ] + + # If there are multiple output grads, we should be able to undef one at a time without error + if len(outputs_to_check[0]) > 1: + for undef_grad_idx in range(len(outputs)): + output_to_check = _differentiable_outputs(func(*inputs)) + outputs_to_check.append( + [ + torch._C._functions.UndefinedGrad()(o) + if idx == undef_grad_idx + else o + for idx, o in enumerate(output_to_check) + ] + ) + + return all(check_undefined_grad_support(output) for output in outputs_to_check) + + +def _as_tuple(x): + if isinstance(x, tuple): + return x + elif isinstance(x, list): + return tuple(x) + else: + return (x,) + + +def _differentiable_outputs(x): + return tuple(o for o in _as_tuple(x) if o.requires_grad) + + +def _get_notallclose_msg( + analytical, + numerical, + output_idx, + input_idx, + complex_indices, + test_imag=False, + is_forward_ad=False, +) -> str: + out_is_complex = ( + (not is_forward_ad) and complex_indices and output_idx in complex_indices + ) + inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices + part = "imaginary" if test_imag else "real" + element = "inputs" if is_forward_ad else "outputs" + prefix = ( + "" + if not (out_is_complex or inp_is_complex) + else f"While considering the {part} part of complex {element} only, " + ) + mode = "computed with forward mode " if is_forward_ad else "" + return ( + prefix + + f"Jacobian {mode}mismatch for output {output_idx:d} with respect to input {input_idx:d},\n" + f"numerical:{numerical}\nanalytical:{analytical}\n" + ) + + +def _transpose(matrix_of_tensors): + # returns list of tuples + return list(zip(*matrix_of_tensors)) + + +def _real_and_imag_output(fn): + # returns new functions real(fn), and imag(fn) where real(fn) and imag(fn) behave the same as + # the original fn, except torch.real or torch.imag are applied to the complex outputs + def apply_to_c_outs(fn, fn_to_apply): + def wrapped_fn(*inputs): + outs = _as_tuple(fn(*inputs)) + return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs) + + return wrapped_fn + + return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag) + + +def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs): + # returns new functions that take real inputs instead of complex inputs as + # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and inp -> fn(x + inp * 1j). + # In each case, the other part is considered constant. + # We do not use 0 for the constant here to make sure we always call the user function with a valid input. 
+ def apply_to_c_inps(fn, fn_to_apply): + def wrapped_fn(*inputs): + new_inputs = list(inputs) + for should_be_complex in complex_inp_indices: + new_inputs[should_be_complex] = fn_to_apply( + new_inputs[should_be_complex], tupled_inputs[should_be_complex] + ) + return _as_tuple(fn(*new_inputs)) + + return wrapped_fn + + real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j) + imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j) + return real_fn, imag_fn + + +def _gradcheck_real_imag( + gradcheck_fn, + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + check_forward_ad, + check_backward_ad, + nondet_tol, + check_undefined_grad, +): + complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()] + has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out)) + if check_backward_ad: + if has_any_complex_output: + real_fn, imag_fn = _real_and_imag_output(func) + + imag_func_out = imag_fn(*tupled_inputs) + imag_outputs = _differentiable_outputs(imag_func_out) + gradcheck_fn( + imag_fn, + imag_func_out, + tupled_inputs, + imag_outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_out_indices, + test_imag=True, + ) + + real_func_out = real_fn(*tupled_inputs) + real_outputs = _differentiable_outputs(real_func_out) + gradcheck_fn( + real_fn, + real_func_out, + tupled_inputs, + real_outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_out_indices, + ) + else: + gradcheck_fn( + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + ) + + if check_forward_ad: + complex_inp_indices = [ + i + for i, inp in enumerate(tupled_inputs) + if is_tensor_like(inp) and inp.is_complex() + ] + if complex_inp_indices: + real_fn, imag_fn = _real_and_imag_input( + func, complex_inp_indices, tupled_inputs + ) + + imag_inputs = [ + inp.imag if is_tensor_like(inp) and inp.is_complex() else inp + for inp in tupled_inputs + ] + imag_func_out = imag_fn(*imag_inputs) + diff_imag_func_out = _differentiable_outputs(imag_func_out) + gradcheck_fn( + imag_fn, + imag_func_out, + imag_inputs, + diff_imag_func_out, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_inp_indices, + test_imag=True, + use_forward_ad=True, + ) + + real_inputs = [ + inp.real if is_tensor_like(inp) and inp.is_complex() else inp + for inp in tupled_inputs + ] + real_func_out = real_fn(*real_inputs) + diff_real_func_out = _differentiable_outputs(real_func_out) + gradcheck_fn( + real_fn, + real_func_out, + real_inputs, + diff_real_func_out, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_inp_indices, + use_forward_ad=True, + ) + if check_undefined_grad: + _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs) + _test_undefined_forward_mode(real_fn, real_func_out, real_inputs) + else: + gradcheck_fn( + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + use_forward_ad=True, + ) + if check_undefined_grad: + _test_undefined_forward_mode(func, outputs, tupled_inputs) + + +def _slow_gradcheck( + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + *, + use_forward_ad=False, + complex_indices=None, + test_imag=False, + masked=False, +): + func_out = _as_tuple(func_out) + if not outputs: + return _check_no_differentiable_outputs( + func, tupled_inputs, 
func_out, eps=eps, is_forward_ad=use_forward_ad + ) + tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs) + + numerical = _transpose( + _get_numerical_jacobian( + func, + tupled_inputs_numerical, + func_out, + eps=eps, + is_forward_ad=use_forward_ad, + ) + ) + # Note: [numerical vs analytical output length] + # The numerical path returns jacobian quantity for all outputs, even if requires_grad of that + # output is False. This behavior is necessary for _check_no_differentiable_outputs to work. + numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad] + if use_forward_ad: + analytical_forward = _get_analytical_jacobian_forward_ad( + func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes + ) + + for i, n_per_out in enumerate(numerical): + for j, n in enumerate(n_per_out): + a = analytical_forward[j][i] + if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): + raise GradcheckError( + _get_notallclose_msg( + a, n, i, j, complex_indices, test_imag, is_forward_ad=True + ) + ) + else: + for i, o in enumerate(outputs): + analytical = _check_analytical_jacobian_attributes( + tupled_inputs, o, nondet_tol, check_grad_dtypes + ) + + for j, (a, n) in enumerate(zip(analytical, numerical[i])): + if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): + raise GradcheckError( + _get_notallclose_msg(a, n, i, j, complex_indices, test_imag) + ) + + return True + + +def _dot_with_type_promotion(u, v): + assert u.dim() == 1 and v.dim() == 1 + return (u * v).sum() + + +def _allclose_with_type_promotion(a, b, rtol, atol): + promoted_type = torch.promote_types(a.dtype, b.dtype) + a = a.to(dtype=promoted_type) + b = b.to(dtype=promoted_type) + return torch.allclose(a, b, rtol, atol) + + +def _to_real_dtype(dtype): + if dtype == torch.complex128: + return torch.float64 + elif dtype == torch.complex64: + return torch.float32 + else: + return dtype + + +def _vec_from_tensor(x, generator, downcast_complex=False): + # Create a random vector with the same number of elements as x and the same + # dtype/device. If x is complex and downcast_complex is False, we create a + # complex tensor with only real component. + if x.layout == torch.sparse_coo: + # For sparse, create a random sparse vec with random values in the same + # indices. Make sure size is set so that it isn't inferred to be smaller. 
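+ # e.g. for a COO tensor with three specified elements (nnz == 3), draw three random
+ # values, scale them to unit norm, and rebuild a sparse vector on the original
+ # indices with the explicitly passed size.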
+ x_values = x._values() + dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype + values = ( + torch.rand(x_values.numel(), generator=generator) + .to(dtype=dtype, device=x.device) + .view(x_values.shape) + ) + values /= values.norm() + vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device) + elif _is_sparse_compressed_tensor(x): + if x.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = x.crow_indices(), x.col_indices() + else: + compressed_indices, plain_indices = x.ccol_indices(), x.row_indices() + x_values = x.values() + dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype + values = ( + torch.rand(x_values.numel(), generator=generator) + .to(dtype=dtype, device=x.device) + .view(x_values.shape) + ) + values /= values.norm() + vec = torch.sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + x.size(), + layout=x.layout, + device=x.device, + ) + else: + dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype + vec = torch.rand(x.numel(), generator=generator).to( + dtype=dtype, device=x.device + ) + vec /= vec.norm() + return vec + + +def _get_inp_tensors(tupled_inputs): + inp_idx_tup = [ + (i, t) + for i, t in enumerate(tupled_inputs) + if is_tensor_like(t) and t.requires_grad + ] + return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup] + + +def _adjusted_atol(atol, u, v): + # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we + # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and + # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is + # the correctly sized matrix in which each entry is atol. + # + # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N + # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{i} v_i) + # TODO: properly handle case when u is tuple instead of only taking first element + u = u[0] if isinstance(u, tuple) else u + sum_u = u.sum() + sum_v = 1.0 if v is None else v.sum() + return atol * float(sum_u) * float(sum_v) + + +FAST_FAIL_SLOW_OK_MSG = """ +Fast gradcheck failed but element-wise differences are small. This means that the +test might've passed in slow_mode! + +If you are adding a new operator, please file an issue and then use one of the +workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck: + +If the test +- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck + with `fast_mode=False` as a keyword argument. 
+- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test + to have `gradcheck_fast_mode=False` +- is a Module test (e.g., in common_nn.py), then modify the corresponding + module_test entry to have `gradcheck_fast_mode=False` +""".strip() + + +def _run_slow_mode_and_get_error( + func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad +): + # Compute jacobians in slow mode for better error message + slow_numerical = _get_numerical_jacobian( + func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad + )[input_idx][output_idx] + if is_forward_ad: + + def new_fn(inp): + new_inputs = list(tupled_inputs) + new_inputs[input_idx] = inp + return _as_tuple(func(*new_inputs))[output_idx] + + slow_analytical = _get_analytical_jacobian_forward_ad( + new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],) + )[0][0] + else: + slow_analytical = _get_analytical_jacobian( + tupled_inputs, outputs, input_idx, output_idx + ) + + # Assume jacobians are non-empty and have the same shape + slow_max_diff = (slow_numerical - slow_analytical).abs().max() + + slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol) + msg = ( + "\nThe above quantities relating the numerical and analytical jacobians are computed \n" + "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n" + "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n" + f"Numerical:\n {slow_numerical}\n" + f"Analytical:\n{slow_analytical}\n\n" + f"The max per-element difference (slow mode) is: {slow_max_diff}.\n" + ) + if slow_allclose: + # Slow gradcheck would've passed! + msg += FAST_FAIL_SLOW_OK_MSG + return msg + + +def _to_flat_dense_if_sparse(tensor): + if _is_sparse_any_tensor(tensor): + return tensor.to_dense().reshape(-1) + else: + return tensor + + +def _make_vectors(inp_tensors, outputs, *, use_forward_ad): + # Use our own generator to avoid messing with the user's RNG state + g_cpu = torch.Generator() + + def _vec_from_tensor_cpu(*args): + # Default allocate all tensors on CPU, so they are on the same device as the generator + # even if the user specified a default device + with torch.device("cpu"): + return _vec_from_tensor(*args) + + all_u = [] + all_u_dense = [] + for inp in inp_tensors: + ur = _vec_from_tensor_cpu(inp, g_cpu, True) + ur_dense = _to_flat_dense_if_sparse(ur) + if inp.is_complex(): + ui = _vec_from_tensor_cpu(inp, g_cpu, True) + all_u.append((ur, ui)) + ui_dense = _to_flat_dense_if_sparse(ui) + all_u_dense.append((ur_dense, ui_dense)) + else: + all_u.append(ur) + all_u_dense.append(ur_dense) + all_v = ( + None + if use_forward_ad + else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs] + ) + return all_v, all_u, all_u_dense + + +def _check_analytical_numerical_equal( + all_analytical, + all_numerical, + complex_indices, + tupled_inputs, + outputs, + func, + all_v, + all_u, + rtol, + atol, + eps, + test_imag, + *, + is_forward_ad=False, +): + for i, all_numerical_for_input_i in enumerate(all_numerical): + for j, n in enumerate(all_numerical_for_input_i): + # Forward AD generates the transpose of what this function expects + if is_forward_ad: + a = all_analytical[i][j] + else: + a = all_analytical[j][i] + n = n.to(device=a.device) + updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None) + if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol): + jacobians_str = _run_slow_mode_and_get_error( + func, tupled_inputs, outputs, 
i, j, rtol, atol, eps, is_forward_ad + ) + raise GradcheckError( + _get_notallclose_msg( + a, n, j, i, complex_indices, test_imag, is_forward_ad + ) + + jacobians_str + ) + + +def _fast_gradcheck( + func, + func_out, + inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + *, + use_forward_ad=False, + complex_indices=None, + test_imag=False, + masked=False, +): + # See https://github.com/pytorch/pytorch/issues/53876 for details + inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) + # Backward mode computes v^T * J (VJP) + # Since we computed J * u (JVP) through finite difference method, we perform an equality check + # between VJP * u, v * JVP + # ---- + # Forward mode computes J * u (JVP) + # Since we already compute JVP through finite difference method, + # we don't need v for correctness check here as asserted below + all_v, all_u, all_u_dense = _make_vectors( + inp_tensors, outputs, use_forward_ad=use_forward_ad + ) + + inputs_numerical, all_u_numerical, all_v_numerical = ( + (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v)) + ) + + numerical_vJu = _get_numerical_vJu( + func, + inputs_numerical, + inp_tensors_idx, + func_out, + all_u_numerical, + all_v_numerical, + eps, + is_forward_ad=use_forward_ad, + ) + # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well + if use_forward_ad: + assert all_v is None + analytical_vJu = _get_analytical_jacobian_forward_ad( + func, + inputs, + _as_tuple(func_out), + all_u=all_u, + check_grad_dtypes=check_grad_dtypes, + ) + else: + if not outputs: + _check_no_differentiable_outputs_fast( + func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol + ) + + analytical_vJu = _get_analytical_vJu_backward_mode( + inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense + ) + + _check_analytical_numerical_equal( + analytical_vJu, + numerical_vJu, + complex_indices, + inputs, + outputs, + func, + all_v, + all_u, + rtol, + atol, + eps, + test_imag, + is_forward_ad=use_forward_ad, + ) + + return True + + +# Note [VarArg of Tensors] +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment. +# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted, +# the '...' first argument of Callable can be replaced with VarArg(Tensor). +# For now, we permit any input. +def gradcheck( + func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors] + inputs: _TensorOrTensors, + *, + eps: float = 1e-6, + atol: float = 1e-5, + rtol: float = 1e-3, + raise_exception: bool = True, + nondet_tol: float = 0.0, + check_undefined_grad: bool = True, + check_grad_dtypes: bool = False, + check_batched_grad: bool = False, + check_batched_forward_grad: bool = False, + check_forward_ad: bool = False, + check_backward_ad: bool = True, + fast_mode: bool = False, + masked: Optional[bool] = None, +) -> bool: # noqa: D400,D205 + r"""Check gradients computed via small finite differences against analytical + gradients wrt tensors in :attr:`inputs` that are of floating point or complex type + and with ``requires_grad=True``. + + The check between numerical and analytical gradients uses :func:`~torch.allclose`. + + For most of the complex functions we consider for optimization purposes, no notion of + Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of + the Wirtinger and Conjugate Wirtinger derivatives are consistent. 
Because the gradient + computation is done under the assumption that the overall function has a real-valued + output, we treat functions with complex output in a special way. For these functions, + gradcheck is applied to two real-valued functions corresponding to taking the real + components of the complex outputs for the first, and taking the imaginary components + of the complex outputs for the second. For more details, check out + :ref:`complex_autograd-doc`. + + .. note:: + The default values are designed for :attr:`input` of double precision. + This check will likely fail if :attr:`input` is of less precision, e.g., + ``FloatTensor``. + + .. note:: + Gradcheck may fail when evaluated on non-differentiable points + because the numerically computed gradients via finite differencing may differ + those computed analytically (not necessarily because either is incorrect). + For more context, see :ref:`non-differentiable-func-grad`. + + .. warning:: + If any checked tensor in :attr:`input` has overlapping memory, i.e., + different indices pointing to the same memory address (e.g., from + :func:`torch.Tensor.expand`), this check will likely fail because the numerical + gradients computed by point perturbation at such indices will change + values at all other indices that share the same memory address. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a Tensor or a tuple of Tensors + inputs (tuple of Tensor or Tensor): inputs to the function + eps (float, optional): perturbation for finite differences + atol (float, optional): absolute tolerance + rtol (float, optional): relative tolerance + raise_exception (bool, optional): indicating whether to raise an exception if + the check fails. The exception gives more information about the + exact nature of the failure. This is helpful when debugging gradchecks. + nondet_tol (float, optional): tolerance for non-determinism. When running + identical inputs through the differentiation, the results must either match + exactly (default, 0.0) or be within this tolerance. + check_undefined_grad (bool, optional): if ``True``, check if undefined output grads + are supported and treated as zeros, for ``Tensor`` outputs. + check_batched_grad (bool, optional): if ``True``, check if we can compute + batched gradients using prototype vmap support. Defaults to False. + check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute + batched forward gradients using forward ad and prototype vmap support. Defaults to ``False``. + check_forward_ad (bool, optional): if ``True``, check that the gradients computed with forward + mode AD match the numerical ones. Defaults to ``False``. + check_backward_ad (bool, optional): if ``False``, do not perform any checks that rely on + backward mode AD to be implemented. Defaults to ``True``. + fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently only + implemented for R to R functions. If none of the inputs and outputs are complex + a faster implementation of gradcheck that no longer computes the entire jacobian + is run; otherwise, we fall back to the slow implementation. + masked (bool, optional): if ``True``, the gradients of unspecified elements of + sparse tensors are ignored. Defaults to ``False``. 
+ Returns: + ``True`` if all differences satisfy allclose condition + + """ + assert ( + check_forward_ad or check_backward_ad + ), "Expected at least one of check_forward_ad or check_backward_ad to be True" + assert not ( + check_batched_grad and not check_backward_ad + ), "Setting check_batched_grad=True requires check_backward_ad to be True" + assert not ( + check_batched_forward_grad and not check_forward_ad + ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True" + args = locals().copy() + args.pop("raise_exception") + if not raise_exception: + try: + return _gradcheck_helper(**args) + except GradcheckError: + return False + else: + return _gradcheck_helper(**args) + + +def _gradcheck_helper( + func, + inputs, + eps, + atol, + rtol, + nondet_tol, + check_undefined_grad, + check_grad_dtypes, + check_batched_grad, + check_batched_forward_grad, + check_forward_ad, + check_backward_ad, + fast_mode, + masked, +): + tupled_inputs = _as_tuple(inputs) + _check_inputs(tupled_inputs) + + func_out = func(*tupled_inputs) + outputs = _differentiable_outputs(func_out) + _check_outputs(outputs) + + gradcheck_fn = functools.partial( + _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked + ) + _gradcheck_real_imag( + gradcheck_fn, + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + check_forward_ad=check_forward_ad, + check_backward_ad=check_backward_ad, + nondet_tol=nondet_tol, + check_undefined_grad=check_undefined_grad, + ) + + if check_batched_forward_grad: + _test_batched_grad_forward_ad(func, tupled_inputs) + + # Short circuit because remaining tests rely on backward AD to be implemented + if not check_backward_ad: + return True + + for i, o in enumerate(outputs): + if check_batched_grad: + _test_batched_grad(tupled_inputs, o, i) + + _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked) + + if check_undefined_grad and check_backward_ad: + _test_undefined_backward_mode(func, outputs, tupled_inputs) + return True + + +def gradgradcheck( + func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors] + inputs: _TensorOrTensors, + grad_outputs: Optional[_TensorOrTensors] = None, + *, + eps: float = 1e-6, + atol: float = 1e-5, + rtol: float = 1e-3, + gen_non_contig_grad_outputs: bool = False, + raise_exception: bool = True, + nondet_tol: float = 0.0, + check_undefined_grad: bool = True, + check_grad_dtypes: bool = False, + check_batched_grad: bool = False, + check_fwd_over_rev: bool = False, + check_rev_over_rev: bool = True, + fast_mode: bool = False, + masked: bool = False, +) -> bool: # noqa: D400,D205 + r"""Check gradients of gradients computed via small finite differences + against analytical gradients wrt tensors in :attr:`inputs` and + :attr:`grad_outputs` that are of floating point or complex type and with + ``requires_grad=True``. + + This function checks that backpropagating through the gradients computed + to the given :attr:`grad_outputs` are correct. + + The check between numerical and analytical gradients uses :func:`~torch.allclose`. + + .. note:: + The default values are designed for :attr:`input` and + :attr:`grad_outputs` of double precision. This check will likely fail if + they are of less precision, e.g., ``FloatTensor``. + + .. 
warning:: + If any checked tensor in :attr:`input` and :attr:`grad_outputs` has + overlapping memory, i.e., different indices pointing to the same memory + address (e.g., from :func:`torch.Tensor.expand`), this check will likely fail + because the numerical gradients computed by point perturbation at such + indices will change values at all other indices that share the same + memory address. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a Tensor or a tuple of Tensors + inputs (tuple of Tensor or Tensor): inputs to the function + grad_outputs (tuple of Tensor or Tensor, optional): The gradients with + respect to the function's outputs. + eps (float, optional): perturbation for finite differences + atol (float, optional): absolute tolerance + rtol (float, optional): relative tolerance + gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is + ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the + randomly generated gradient outputs are made to be noncontiguous + raise_exception (bool, optional): indicating whether to raise an exception if + the check fails. The exception gives more information about the + exact nature of the failure. This is helpful when debugging gradchecks. + nondet_tol (float, optional): tolerance for non-determinism. When running + identical inputs through the differentiation, the results must either match + exactly (default, 0.0) or be within this tolerance. Note that a small amount + of nondeterminism in the gradient will lead to larger inaccuracies in + the second derivative. + check_undefined_grad (bool, optional): if True, check if undefined output grads + are supported and treated as zeros + check_batched_grad (bool, optional): if True, check if we can compute + batched gradients using prototype vmap support. Defaults to False. + fast_mode (bool, optional): if True, run a faster implementation of gradgradcheck that + no longer computes the entire jacobian. + masked (bool, optional): if True, the gradients of unspecified elements of + sparse tensors are ignored (default, False). + Returns: + True if all differences satisfy allclose condition + """ + assert ( + check_fwd_over_rev or check_rev_over_rev + ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" + assert not ( + check_undefined_grad and not check_rev_over_rev + ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True" + assert not ( + check_batched_grad and not check_rev_over_rev + ), "Setting check_batched_grad=True requires check_rev_over_rev to be True" + # TODO: do we want to test this too? 
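A minimal usage sketch of the double-backward check (the cubic ``fn`` below is only an illustration; any twice-differentiable function works), again with double-precision inputs:

    >>> import torch
    >>> from torch.autograd import gradgradcheck
    >>> def fn(x):
    ...     return (x ** 3).sum()
    >>> x = torch.randn(4, dtype=torch.double, requires_grad=True)
    >>> gradgradcheck(fn, (x,))
    True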
+ # assert not (check_batched_forward_grad and not check_fwd_over_rev), ( + # "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True") + tupled_inputs = _as_tuple(inputs) + + if grad_outputs is None: + # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs + + outputs = _differentiable_outputs(func(*tupled_inputs)) + tupled_grad_outputs = tuple( + torch.testing.make_tensor( + x.shape, + dtype=x.dtype + if x.is_floating_point() or x.is_complex() + else torch.double, + device=x.device, + low=-1, + high=1, + requires_grad=True, + noncontiguous=gen_non_contig_grad_outputs, + ) + for x in outputs + ) + else: + tupled_grad_outputs = _as_tuple(grad_outputs) + + num_outputs = len(tupled_grad_outputs) + + # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs + # before running forward mode AD + diff_input_args_indices = { + i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad + } + diff_grad_output_indices = { + i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad + } + + def new_func(*args): + # Restore the requires_grad information + input_args = tuple( + x.requires_grad_() if i in diff_input_args_indices else x + for i, x in enumerate(args[:-num_outputs]) + ) + outputs = _differentiable_outputs(func(*input_args)) + grad_outputs = tuple( + x.requires_grad_() if i in diff_grad_output_indices else x + for i, x in enumerate(args[-num_outputs:]) + ) + diff_input_args = tuple( + x for i, x in enumerate(input_args) if i in diff_input_args_indices + ) + grad_inputs = torch.autograd.grad( + outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True + ) + grad_inputs = tuple(g for g in grad_inputs if g is not None) + return grad_inputs + + return gradcheck( + new_func, + tupled_inputs + tupled_grad_outputs, + eps=eps, + atol=atol, + rtol=rtol, + raise_exception=raise_exception, + nondet_tol=nondet_tol, + check_undefined_grad=check_undefined_grad, + check_grad_dtypes=check_grad_dtypes, + check_batched_grad=check_batched_grad, + fast_mode=fast_mode, + check_forward_ad=check_fwd_over_rev, + check_backward_ad=check_rev_over_rev, + masked=masked, + ) diff --git a/phivenv/Lib/site-packages/torch/autograd/graph.py b/phivenv/Lib/site-packages/torch/autograd/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1c971929bb11754749eea1c3e93d2e83ca432947 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/graph.py @@ -0,0 +1,834 @@ +import abc +import contextlib +import functools +import logging +import threading +from collections import defaultdict, deque +from collections.abc import Generator, Iterable, Iterator, MutableMapping, Sequence +from typing import ( + Any, + Callable, + cast, + Literal, + NamedTuple, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import TypeAlias +from weakref import WeakKeyDictionary, WeakValueDictionary + +import torch +from torch.autograd.variable import Variable +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.hooks import RemovableHandle + + +if TYPE_CHECKING: + from torch._ops import OpOverload + + +__all__ = [ + "saved_tensors_hooks", + "save_on_cpu", + "disable_saved_tensors_hooks", + "register_multi_grad_hook", + "allow_mutation_on_saved_tensors", + "Node", + "GradientEdge", + "get_gradient_edge", + "increment_version", +] + + +log = logging.getLogger(__name__) + + +class Node(abc.ABC): + @abc.abstractmethod + def 
name(self) -> str: + r"""Return the name. + + Example:: + + >>> import torch + >>> a = torch.tensor([0., 0., 0.], requires_grad=True) + >>> b = a.clone() + >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node) + >>> print(b.grad_fn.name()) + CloneBackward0 + """ + raise NotImplementedError + + @property + @abc.abstractmethod + def next_functions(self) -> tuple[tuple[Optional["Node"], int], ...]: + raise NotImplementedError + + @abc.abstractmethod + def metadata(self) -> dict: + r"""Return the metadata.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def _input_metadata(self) -> list[Any]: + raise NotImplementedError + + @abc.abstractmethod + def _register_hook_dict(self, tensor: torch.Tensor) -> None: + raise NotImplementedError + + @abc.abstractmethod + def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle: + r"""Register a backward hook. + + The hook will be called every time a gradient with respect to the + Node is computed. The hook should have the following signature:: + + hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None + + + The hook should not modify its argument, but it can optionally return + a new gradient which will be used in place of :attr:`grad_inputs`. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. + + .. note:: + In the rare case where the hook is registered while the Node has already + begun execution, there is no longer any guarantee on :attr:`grad_outputs` + content (it might be as usual or empty depending on other factors). The + hook can still optionally return a new gradient to be used in place of + :attr:`grad_inputs` independent of :attr:`grad_outputs`. + + Example:: + + >>> import torch + >>> a = torch.tensor([0., 0., 0.], requires_grad=True) + >>> b = a.clone() + >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node) + >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,)) + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([2., 2., 2.]) + >>> handle.remove() # Removes the hook + >>> a.grad = None + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([1., 1., 1.]) + """ + raise NotImplementedError + + @abc.abstractmethod + def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle: + r"""Register a backward pre-hook. + + The hook will be called every time a gradient with respect to the + Node is computed. The hook should have the following signature:: + + hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None + + The hook should not modify its argument, but it can optionally return + a new gradient which will be used in place of :attr:`grad_outputs`. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. 
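``next_functions`` can also be used to walk the recorded graph by hand; a small sketch (the tensors are illustrative, and only the first edge of each node is followed):

    >>> import torch
    >>> a = torch.ones(3, requires_grad=True)
    >>> out = a.exp().sum()
    >>> node = out.grad_fn
    >>> while node is not None:
    ...     print(node.name())  # SumBackward0, ExpBackward0, then the leaf accumulate-grad node
    ...     node = node.next_functions[0][0] if node.next_functions else None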
+ + Example:: + + >>> a = torch.tensor([0., 0., 0.], requires_grad=True) + >>> b = a.clone() + >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node) + >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,)) + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([2., 2., 2.]) + >>> handle.remove() + >>> a.grad = None + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([1., 1., 1.]) + """ + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, subclass: type) -> bool: + if cls is Node and ( + ( + subclass is not None + and subclass is getattr(torch._C._functions, subclass.__name__, None) + ) + or issubclass(subclass, torch.autograd.function.BackwardCFunction) + ): + return True + return NotImplemented + + +def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node: + if isinstance(t, GradientEdge): + return t.node + if t.requires_grad and t.grad_fn is None: + with torch.enable_grad(): + node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] + else: + node = t.grad_fn + assert node is not None + return node + + +class GradientEdge(NamedTuple): + """Object representing a given gradient edge within the autograd graph. + + To get the gradient edge where a given Tensor gradient will be computed, + you can do ``edge = autograd.graph.get_gradient_edge(tensor)``. + """ + + node: Node + output_nr: int + + +def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: + """Get the gradient edge for computing the gradient of the given Tensor. + + In particular, it is equivalent to call + ``g = autograd.grad(loss, input)`` and ``g = autograd.grad(loss, get_gradient_edge(input))``. + """ + if not tensor.requires_grad: + raise RuntimeError( + "It is not possible to get the gradient edge for a Tensor " + "that does not require gradients", + ) + grad_fn = _get_grad_fn_or_grad_acc(tensor) + + # Note that output_nr default to 0 which is the right value + # for the AccumulateGrad node. + return GradientEdge(grad_fn, tensor.output_nr) + + +def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None: + """Update autograd metadata tracking whether the given Tensor was modified in place. + + This is to enable more accurate error checking within the autograd engine. + It is already done automatically by PyTorch functions and within custom Function + when mark_dirty() is called appropriately so you only need to call this explicitly + if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't + know about. For example a custom kernel that reads the Tensor data_ptr and modifies + the memory inplace based on this pointer. Can accept either a tensor, or a list of tensors. + + Note that incrementing the version counter multiple times for a single inplace operation + is not problematic. + + Note that if you pass in tensor constructed under torch.inference_mode(), + we will not bump its version counter (because your tensor does not have one). + """ + if isinstance(tensor, torch.Tensor): + tensor = (tensor,) + torch._C._increment_version(tensor) + + +class saved_tensors_hooks: + """Context-manager that sets a pair of pack / unpack hooks for saved tensors. + + Use this context-manager to define how intermediary results of an operation + should be packed before saving, and unpacked on retrieval. 
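A small sketch of the equivalence stated above for ``get_gradient_edge`` (tensor names are illustrative); both calls should produce the same gradient for ``x``:

    >>> import torch
    >>> from torch.autograd.graph import get_gradient_edge
    >>> x = torch.randn(3, requires_grad=True)
    >>> loss = (x * x).sum()
    >>> g1, = torch.autograd.grad(loss, x, retain_graph=True)
    >>> g2, = torch.autograd.grad(loss, (get_gradient_edge(x),))
    >>> torch.equal(g1, g2)
    True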
+ + In that context, the ``pack_hook`` function will be called everytime an + operation saves a tensor for backward (this includes intermediary results + saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). The output of + ``pack_hook`` is then stored in the computation graph instead of the + original tensor. + + The ``unpack_hook`` is called when the saved tensor needs to be accessed, + namely when executing :func:`torch.Tensor.backward()` or + :func:`torch.autograd.grad()`. It takes as argument the *packed* object + returned by ``pack_hook`` and should return a tensor which has the same + content as the original tensor (passed as input to the corresponding + ``pack_hook``). + + The hooks should have the following signatures: + + pack_hook(tensor: Tensor) -> Any + + unpack_hook(Any) -> Tensor + + where the return value of ``pack_hook`` is a valid input to ``unpack_hook``. + + In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms + of value, size, dtype and device. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def pack_hook(x): + ... print("Packing", x) + ... return x.detach() + >>> + >>> def unpack_hook(x): + ... print("Unpacking", x) + ... return x + >>> + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + ... y = a * b + Packing tensor([1., 1., 1., 1., 1.], requires_grad=True) + Packing tensor([2., 2., 2., 2., 2.], grad_fn=) + >>> y.sum().backward() + Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True) + Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=) + + .. warning :: + Performing an inplace operation on the input to either hooks may lead + to undefined behavior. + + .. warning :: + Only one pair of hooks is allowed at a time. When recursively nesting this + context-manager, only the inner-most pair of hooks will be applied. + + .. warning :: + To avoid reference cycle, the return value of ``pack_hook`` cannot hold a + reference to the input tensor. For example, use `lambda x: x.detach()` + instead of `lambda x: x` as the pack hook. + """ + + def __init__( + self, + pack_hook: Callable[[torch.Tensor], Any], + unpack_hook: Callable[[Any], torch.Tensor], + ) -> None: + self.pack_hook = pack_hook + self.unpack_hook = unpack_hook + + def __enter__(self) -> None: + torch._C._autograd._push_saved_tensors_default_hooks( + self.pack_hook, self.unpack_hook + ) + + def __exit__(self, *args: object) -> None: + torch._C._autograd._pop_saved_tensors_default_hooks() + + +class save_on_cpu(saved_tensors_hooks): + """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward. + + When performing operations within this context manager, intermediary + results saved in the graph during the forward pass will be moved to CPU, + then copied back to the original device when needed for the backward pass. + If the graph was already on CPU, no tensor copy is performed. + + Use this context-manager to trade compute for GPU memory usage (e.g. + when your model doesn't fit in GPU memory during training). + + Args: + pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory + during packing and copied to GPU asynchronously during unpacking. + Defaults to ``False``. + Also see :ref:`cuda-memory-pinning`. 
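The same pack/unpack mechanism can implement storage policies other than moving tensors to CPU; a sketch that instead keeps saved activations in half precision (lossy, so gradients become approximate; ``pack_fp16`` and ``unpack_fp16`` are illustrative names, not part of this module):

    >>> import torch
    >>> def pack_fp16(t):
    ...     return (t.dtype, t.detach().to(torch.float16))
    >>> def unpack_fp16(packed):
    ...     dtype, t = packed
    ...     return t.to(dtype)
    >>> a = torch.randn(8, requires_grad=True)
    >>> with torch.autograd.graph.saved_tensors_hooks(pack_fp16, unpack_fp16):
    ...     out = (a * a).sum()
    >>> out.backward()  # backward sees the fp16 copies, so a.grad is approximately 2 * a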
+ + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> a = torch.randn(5, requires_grad=True, device="cuda") + >>> b = torch.randn(5, requires_grad=True, device="cuda") + >>> c = torch.randn(5, requires_grad=True, device="cuda") + >>> + >>> def f(a, b, c): + ... prod_1 = a * b # a and b are saved on GPU + ... with torch.autograd.graph.save_on_cpu(): + ... prod_2 = prod_1 * c # prod_1 and c are saved on CPU + ... y = prod_2 * a # prod_2 and a are saved on GPU + ... return y + >>> + >>> y = f(a, b, c) + >>> del a, b, c # for illustration only + >>> # the content of a, b, and prod_2 are still alive on GPU + >>> # the content of prod_1 and c only live on CPU + >>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward + >>> # all intermediary tensors are released (deleted) after the call to backward + """ + + def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None: + device_module = getattr(torch, device_type, torch.cuda) + + def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]: + if not pin_memory: + return (tensor.device, tensor.cpu()) + packed = torch.empty( + tensor.size(), + dtype=tensor.dtype, + layout=tensor.layout, + pin_memory=(device_module.is_available() and not tensor.is_sparse), + ) + packed.copy_(tensor) + return (tensor.device, packed) + + def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor: + device, tensor = packed + return tensor.to(device, non_blocking=pin_memory) + + super().__init__(pack_to_cpu, unpack_from_cpu) + + +@contextlib.contextmanager +def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, None]: + """Context-manager that disables the saved tensors default hooks feature. + + Useful for if you are creating a feature that does not work with saved + tensors default hooks. + + Args: + error_message (str): When saved tensors default hooks are used when they + have been are disabled, a RuntimeError with this + error message gets raised. + + Example:: + + >>> # xdoctest: +SKIP(failing) + >>> message = "saved tensors default hooks are disabled" + >>> with torch.autograd.graph.disable_saved_tensors_hooks(message): + ... # Raises RuntimeError: saved tensors default hooks are disabled + ... with torch.autograd.graph.save_on_cpu(): + ... pass + """ + maybe_prev_message = None + try: + maybe_prev_message = ( + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ) + torch._C._autograd._saved_tensors_hooks_disable(error_message) + yield + finally: + # See NOTE: [disabled_error_message invariant] + if maybe_prev_message is None: + torch._C._autograd._saved_tensors_hooks_enable() + else: + torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message) + + +class _MultiHandle(RemovableHandle): + handles: tuple[RemovableHandle, ...] + + def __init__(self, handles: tuple[RemovableHandle, ...]) -> None: + self.handles = handles + + def remove(self) -> None: + for handle in self.handles: + handle.remove() + + def __getstate__(self) -> tuple[RemovableHandle, ...]: + return self.handles + + def __setstate__(self, state: tuple[RemovableHandle, ...]) -> None: + self.handles = state + + +def register_multi_grad_hook( + tensors: Sequence[torch.Tensor], + fn: Union[ + Callable[[Sequence[Optional[torch.Tensor]]], None], + Callable[[torch.Tensor], None], + ], + *, + mode: Literal["all", "any"] = "all", +) -> RemovableHandle: + r"""Register a multi-grad backward hook. 
+ + There are two supported modes: ``"all"`` and ``"any"``. + + Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in + :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but + is not part of the graph, or if a tensor is not needed to compute the gradients + for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call, + this tensor will be ignored and the hook will not wait for its gradient to be + computed. + + After every non-ignored tensor's gradient has been computed, :attr:`fn` will be + called with those gradients. ``None`` will be passed for tensors that did not + have their gradients computed. + + Under the ``"any"`` mode, the hook will be called after the first gradient + with respect to a tensor in :attr:`tensors` has been computed. The hook + will be called with that gradient as its argument. + + The hook should not modify its arguments. + + This function returns a handle with a method ``handle.remove()`` that removes the hook. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. + + Example:: + + >>> import torch + >>> + >>> a = torch.rand(2, 3, requires_grad=True) + >>> b = torch.rand(2, 3, requires_grad=True) + >>> c = a * b + >>> d = a * b + >>> + >>> def fn(grads): + ... print([g is not None for g in grads]) + ... + >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn) + >>> + >>> c.sum().backward(retain_graph=True) + [True, True, True, False] + >>> c.sum().backward(inputs=(a,), retain_graph=True) + [True, False, True, False] + >>> + """ + supported_modes = ("all", "any") + lock = threading.Lock() + + if mode not in supported_modes: + raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}") + + if mode == "all": + count: dict[int, int] = {} + nb_calls = None + buffer: dict[int, list[Optional[torch.Tensor]]] = {} + + grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors)) + len_tensors = len(tensors) + + def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]: + def inner_hook(grad: torch.Tensor) -> None: + nonlocal count, nb_calls, buffer, fn + id = torch._C._current_graph_task_id() + assert ( + id != -1 + ), "expected this hook to be called inside a backward call" + count[id] = count.get(id, 0) + buffer[id] = buffer.get(id, [None] * len_tensors) + + with lock: + curr_count, count[id] = count[id], count[id] + 1 + + if curr_count == 0: + # On the first call, compute the actual nb_calls and buffer + nb_calls = sum( + map(torch._C._will_engine_execute_node, grad_fns) + ) + + buffer[id][idx] = grad + + assert nb_calls is not None + if curr_count == nb_calls - 1: + fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn) + fn(buffer[id]) + del count[id] + del buffer[id] + + return inner_hook + + handles = tuple( + t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors) + ) + elif mode == "any": + fn = cast(Callable[[torch.Tensor], None], fn) + ran_hook: dict[int, bool] = defaultdict(bool) + + @functools.wraps(fn) + def wrapped_fn(grad: torch.Tensor) -> None: + nonlocal ran_hook + id = torch._C._current_graph_task_id() + assert id != -1, "expected this hook to be called inside a backward call" + with lock: + prev, ran_hook[id] = ran_hook[id], True + if prev: + return + fn(grad) + + handles = tuple( + tensor.register_hook(wrapped_fn) + for tensor in tensors + if tensor.requires_grad + ) + + return _MultiHandle(handles) # 
type: ignore[possibly-undefined] + + +# NOTE [Allow mutation on tensors saved for backward] +# +# 1. Tensor gets saved for backward +# - remember the python object id and the version of the tensor +# - remember aliasing information (data_ptr of base + version) +# - save the original so we control its lifetime +# 2. Any time a tensor gets in-placed +# - for each tensor aliased to it: +# - check using its object id and version to see if it has been saved +# - if it has been saved, clone it +# - delete the reference to the original +# 3. during backward +# - if the clone exists, the tensor must've been modified in-place +_allow_mutation_on_saved_tensors_enabled: bool = False + + +_TID: TypeAlias = tuple[int, int, int] +_SID: TypeAlias = tuple[int, int] + + +def _get_tid(tensor: torch.Tensor) -> _TID: + # FIXME: This is almost definitely a bug. + if isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ): + data_ptr = 0 + else: + data_ptr = tensor.data_ptr() + return (id(tensor), data_ptr, tensor._version) + + +def _get_sid(tensor: torch.Tensor) -> _SID: + # FIXME: This is almost definitely a bug. + if isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ): + data_ptr = 0 + else: + data_ptr = tensor.data_ptr() + return (data_ptr, tensor._version) + + +class _Handle: + pass + + +class _swap_with_cloned(saved_tensors_hooks): + def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None: + def pack_hook(tensor: torch.Tensor) -> _Handle: + tid = _get_tid(tensor) + sid = _get_sid(tensor) + # Tensors saved for backward have an entry in _tid_to_weakhandle + handle: Optional[_Handle] = None + + # Save aliasing information + ctx.sid_to_tid[sid].add(tid) + + # NB: The same tensor (of the same version) can be saved multiple times + if tid not in ctx.tid_to_weakhandle: + handle = _Handle() + ctx.tid_to_weakhandle[tid] = handle + ctx.original[handle] = tensor + else: + # Store an additional strong reference to the handle + handle = ctx.tid_to_weakhandle[tid] + return handle + + def unpack_hook(handle: _Handle) -> torch.Tensor: + error_msg = ( + "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + "in which the graph was originally recorded." + ) + assert _allow_mutation_on_saved_tensors_enabled, error_msg + if handle in ctx.cloned: + res = ctx.cloned[handle] + else: + assert handle in ctx.original, error_msg + res = ctx.original[handle] + return res + + super().__init__(pack_hook, unpack_hook) + + +class _CloneArgBeforeMutateMode(TorchDispatchMode): + def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None: + self.ctx = ctx + + def __torch_dispatch__( + self, + func: "OpOverload", + types: Iterable[type], + args: tuple[Any, ...] = (), + kwargs: Optional[dict[Any, Any]] = None, + ) -> Any: + kwargs = kwargs or {} + + def maybe_clone(t: torch.Tensor) -> None: + tid = _get_tid(t) + sid = _get_sid(t) + ctx = self.ctx + if sid in ctx.sid_to_tid: + for tid in ctx.sid_to_tid[sid]: + if tid not in ctx.tid_to_weakhandle: + # We know that if tid is in sid_to_tid, then it must also be in + # tid_to_weakhandle. However, it is possible for the tensor to be + # saved at one point, but cleared by backward before it is modified + # in-place. 
Consider the following example: + # + # >>> a = torch.randn(2, 3, requires_grad=True).clone() + # >>> out = (a**2).sum() + # >>> out.backward() + # >>> a.sin_() + continue + handle = ctx.tid_to_weakhandle[tid] + if handle in ctx.cloned: + # The same exact tensor has been cloned already + continue + ctx.cloned[handle] = ctx.original[handle].clone() + del ctx.original[handle] + + for idx, arg in enumerate(func._schema.arguments): + if arg.alias_info is not None and arg.alias_info.is_write: + if arg.is_out: + maybe_clone(kwargs["out"]) + elif isinstance(args[idx], list): + # Foreach case. (Possible optimization: if most of the + # tensors need to be cloned, use a for each clone?) + for t in args[idx]: + maybe_clone(t) + else: + maybe_clone(args[idx]) + + return func(*args, **kwargs) + + +class _AllowMutationOnSavedContext: + def __init__(self) -> None: + self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary() + self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary() + self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary() + self.sid_to_tid: dict[_SID, set[_TID]] = defaultdict(set) + + def clear(self) -> None: + self.cloned.clear() + self.original.clear() + self.tid_to_weakhandle.clear() + self.sid_to_tid.clear() + + +@contextlib.contextmanager +def allow_mutation_on_saved_tensors() -> ( + Generator[_AllowMutationOnSavedContext, None, None] +): + """Context manager under which mutating tensors saved for backward is allowed. + + Under this context manager, tensors saved for backward are cloned on mutation, + so the original version can still be used during backward. Normally, mutating a tensor + saved for backward will result in an error raised when it's used during backward. + + To ensure the correct behavior, both the forward and backward should be run under + the same context manager. + + Returns: + An _AllowMutationOnSavedContext object storing the state managed by this + context manager. This object can be useful for debugging purposes. The state + managed by the context manager is automatically cleared upon exiting. + + Example:: + + >>> import torch + >>> with torch.autograd.graph.allow_mutation_on_saved_tensors(): + ... # forward + ... a = torch.ones(2, 3, requires_grad=True) + ... b = a.clone() + ... out = (b**2).sum() + ... b.sin_() + ... # backward + ... out.sum().backward() + ... 
+ tensor([[0.8415, 0.8415, 0.8415], + [0.8415, 0.8415, 0.8415]], grad_fn=) + """ + global _allow_mutation_on_saved_tensors_enabled + + ctx = _AllowMutationOnSavedContext() + + with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx): + try: + if _allow_mutation_on_saved_tensors_enabled: + raise RuntimeError( + "allow_mutation_on_saved_tensors contexts cannot be nested" + ) + _allow_mutation_on_saved_tensors_enabled = True + yield ctx + finally: + ctx.clear() + _allow_mutation_on_saved_tensors_enabled = False + + +def _register_logging_hooks_on_whole_graph( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], +) -> Callable[[], None]: + grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs)) + + def iter_graph(roots: list[Node]) -> Iterator[Node]: + if not roots: + return + seen: set[Node] = set() + q: deque[Node] = deque() + for node in roots: + if node is not None: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _ in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def fmt(t: Optional[torch.Tensor]) -> str: + # Avoid circular import + from torch.utils._dtype_abbrs import dtype_abbrs + + if t is None: + return "None" + return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]" + + def prehook(grad_outputs: Sequence[Optional[torch.Tensor]]) -> None: + node = torch._C._current_autograd_node() + grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]" + log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}" + log.debug(log_str) + + handles = [node.register_prehook(prehook) for node in iter_graph(grad_fns)] + + def unregister_hooks() -> None: + for handle in handles: + handle.remove() + + return unregister_hooks + + +def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, +) -> tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: + return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass + finally: + if attach_logging_hooks: + unregister_hooks() # type: ignore[possibly-undefined] diff --git a/phivenv/Lib/site-packages/torch/autograd/profiler.py b/phivenv/Lib/site-packages/torch/autograd/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c68bdf7149512c2ff2bd7f4207c9e755b6c2a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/profiler.py @@ -0,0 +1,1213 @@ +# mypy: allow-untyped-defs +import uuid +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from time import perf_counter_ns +from typing import Any, Optional +from warnings import warn + +import torch +import torch.cuda +from torch._C import _get_privateuse1_backend_name +from torch._C._profiler import _ExperimentalConfig +from torch.autograd import ( + _disable_profiler, + _enable_profiler, + _kineto_step, + _prepare_profiler, + _ProfilerResult, + _supported_activities, + _toggle_collection_dynamic, + DeviceType, + kineto_available, + ProfilerActivity, + ProfilerConfig, + ProfilerState, +) +from torch.autograd.profiler_util import ( + _filter_name, + _filter_stack_entry, + _rewrite_name, + EventList, + FunctionEvent, + MEMORY_EVENT_NAME, + MemRecordsAcc, + OUT_OF_MEMORY_EVENT_NAME, +) 
+from torch.futures import Future + + +__all__ = [ + "profile", + "record_function", + "emit_itt", + "emit_nvtx", + "load_nvprof", + "EnforceUnique", + "parse_nvprof_trace", + "KinetoStepTracker", + "EventList", + "FunctionEvent", + "MemRecordsAcc", +] + +try: + # Available in Python >= 3.2 + from contextlib import ContextDecorator as _ContextDecorator +except ImportError: + import functools + + class _ContextDecorator: # type: ignore[no-redef] + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + + def __call__(self, func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapped + + +# global python state - whether profiler is currently enabled +# useful for fast python checks to reduce latency +_is_profiler_enabled: bool = False + + +def _set_is_profiler_enabled(enable: bool): + global _is_profiler_enabled + _is_profiler_enabled = enable + + +def _run_on_profiler_start(): + _set_is_profiler_enabled(True) + + +def _run_on_profiler_stop(): + _set_is_profiler_enabled(False) + + +@dataclass +class _ProfilerStats: + "Profiler timing and stats used by developers to catch issues/regressions" + profiling_window_duration_sec: float = 0 + number_of_events: int = 0 + profiler_prepare_call_duration_us: int = 0 + profiler_enable_call_duration_us: int = 0 + profiler_disable_call_duration_us: int = 0 + parse_kineto_call_duration_us: int = 0 + function_events_build_tree_call_duration_us: int = 0 + + +class profile: + """Context manager that manages autograd profiler state and holds a summary of results. + + Under the hood it just records events of functions being executed in C++ and + exposes those events to Python. You can wrap any code into it and it will + only report runtime of PyTorch functions. + Note: profiler is thread local and is automatically propagated into the async tasks + + Args: + enabled (bool, optional): Setting this to False makes this context manager a no-op. + + use_cuda (bool, optional): Enables timing of CUDA events as well + using the cudaEvent API. (will be deprecated) + + use_device (str, optional): Enables timing of device events. + Adds approximately 4us of overhead to each tensor operation when use cuda. + The valid devices options are 'cuda', 'xpu', 'mtia' and 'privateuseone'. + + record_shapes (bool, optional): If shapes recording is set, information + about input dimensions will be collected. This allows one to see which + dimensions have been used under the hood and further group by them + using prof.key_averages(group_by_input_shape=True). Please note that + shape recording might skew your profiling data. It is recommended to + use separate runs with and without shape recording to validate the timing. + Most likely the skew will be negligible for bottom most events (in a case + of nested function calls). But for higher level functions the total + self cpu time might be artificially increased because of the shape + collection. + + with_flops (bool, optional): If with_flops is set, the profiler will estimate + the FLOPs (floating point operations) value using the operator's input shape. + This allows one to estimate the hardware performance. Currently, + this option only works for the matrix multiplication and 2D convolution operators. + + profile_memory (bool, optional): track tensor memory allocation/deallocation. + + with_stack (bool, optional): record source information (file and line number) for the ops. 
+ + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + + use_kineto (bool, optional): experimental, enable profiling with Kineto profiler. + + use_cpu (bool, optional): profile CPU events; setting to ``False`` requires + ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling. + + experimental_config (_ExperimentalConfig) : A set of experimental options + used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed. + + acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles + + + .. warning:: + Enabling memory profiling or source attribution incurs additional profiler + overhead + + .. warning:: + This context managers should not be called recursively, i.e. no nested + instances are allowed + + .. warning:: + Due to some CUDA multiprocessing limitations (see :ref:`multiprocessing-cuda-note`), + one cannot use the profiler with ``use_device = 'cuda'`` to benchmark + DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, + please use ``use_device = None`` or ``num_workers = 0``. + + Example: + >>> # xdoctest: +SKIP + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) + >>> x = torch.randn((1, 1), requires_grad=True) + >>> with torch.autograd.profiler.profile() as prof: + >>> for _ in range(100): # any normal python code, really! + >>> y = x ** 2 + >>> y.backward() + >>> # NOTE: some columns were removed for brevity + >>> print(prof.key_averages().table(sort_by="self_cpu_time_total")) + ----------------------------------- --------------- --------------- --------------- + Name Self CPU total CPU time avg Number of Calls + ----------------------------------- --------------- --------------- --------------- + mul 32.048ms 32.048ms 200 + pow 27.041ms 27.041ms 200 + PowBackward0 9.727ms 55.483ms 100 + torch::autograd::AccumulateGrad 9.148ms 9.148ms 100 + torch::autograd::GraphRoot 691.816us 691.816us 100 + ----------------------------------- --------------- --------------- --------------- + + """ + + def __init__( + self, + enabled=True, + *, + use_cuda=False, # Deprecated + use_device=None, + record_shapes=False, + with_flops=False, + profile_memory=False, + with_stack=False, + with_modules=False, + use_kineto=False, + use_cpu=True, + experimental_config=None, + acc_events=False, + custom_trace_id_callback=None, + ): + self.enabled: bool = enabled + if not self.enabled: + return + self.use_cuda = use_cuda + if self.use_cuda: + warn( + "The attribute `use_cuda` will be deprecated soon, " + "please use ``use_device = 'cuda'`` instead.", + FutureWarning, + stacklevel=2, + ) + self.use_device: Optional[str] = "cuda" + else: + self.use_device = use_device + # TODO Consider changing _function_events into data structure with size cap + self._function_events: Optional[EventList] = None + self._old_function_events: Optional[EventList] = None + # Function event processing is done lazily + self._needs_processing = False + self.entered = False + self.record_shapes = record_shapes + self.with_flops = with_flops + self.record_shapes |= self.with_flops + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_modules = with_modules + self.use_cpu = use_cpu + 
self.acc_events = acc_events + if experimental_config is None: + experimental_config = _ExperimentalConfig() + self.experimental_config = experimental_config + self.kineto_results: Optional[_ProfilerResult] = None + self.profiling_start_time_ns = 0 + self.profiling_end_time_ns = 0 + self._stats = _ProfilerStats() + self.custom_trace_id_callback = custom_trace_id_callback + self.trace_id = "" + if not self.use_cpu: + assert ( + use_kineto + ), "Device-only events supported only with Kineto (use_kineto=True)" + + if self.use_device is not None: + VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia", "hpu"] + if _get_privateuse1_backend_name() != "privateuseone": + VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name()) + if self.use_device not in VALID_DEVICE_OPTIONS: + warn(f"The {self.use_device} is not a valid device option.") + self.use_device = None + + if self.use_device == "cuda" and not torch.cuda.is_available(): + warn("CUDA is not available, disabling CUDA profiling") + self.use_cuda = False + self.use_device = None + + if self.use_device == "xpu" and not torch.xpu.is_available(): + warn("XPU is not available, disabling XPU profiling") + self.use_device = None + + if self.use_device == "hpu" and not ( + hasattr(torch, "hpu") and torch.hpu.is_available() + ): + warn("HPU is not available, disabling HPU profiling") + self.use_device = None + + self.kineto_activities = set() + if self.use_cpu: + self.kineto_activities.add(ProfilerActivity.CPU) + + self.profiler_kind = ProfilerState.KINETO + if self.use_device == "cuda": + if not use_kineto or ProfilerActivity.CUDA not in _supported_activities(): + assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True" + self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK + else: + self.kineto_activities.add(ProfilerActivity.CUDA) + elif self.use_device == "xpu": + assert ( + use_kineto and ProfilerActivity.XPU in _supported_activities() + ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." + self.kineto_activities.add(ProfilerActivity.XPU) + elif self.use_device == "mtia": + assert ( + use_kineto and ProfilerActivity.MTIA in _supported_activities() + ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + self.kineto_activities.add(ProfilerActivity.MTIA) + elif self.use_device == "hpu": + assert ( + use_kineto and ProfilerActivity.HPU in _supported_activities() + ), "Legacy HPU profiling is not supported. Requires use_kineto=True on HPU devices." 
+ self.kineto_activities.add(ProfilerActivity.HPU) + elif self.use_device is not None and self.use_device != "privateuseone": + if ( + not use_kineto + or ProfilerActivity.PrivateUse1 not in _supported_activities() + ): + assert ( + self.use_cpu + ), "Legacy custombackend profiling requires use_cpu=True" + self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK + else: + self.kineto_activities.add(ProfilerActivity.PrivateUse1) + + assert ( + len(self.kineto_activities) > 0 + ), "No activities specified for the profiler" + + def default_trace_id(self): + # Generate a UUID + uuid_raw = uuid.uuid4() + + return f"{uuid_raw.int:032X}" + + def create_trace_id(self): + if self.custom_trace_id_callback: + return self.custom_trace_id_callback() + return self.default_trace_id() + + def config(self, create_trace_id=False): + # only need to generate new trace id upon prepare trace not start trace + if create_trace_id: + trace_id = self.create_trace_id() + self.trace_id = trace_id + return ProfilerConfig( + self.profiler_kind, + self.record_shapes, + self.profile_memory, + self.with_stack, + self.with_flops, + self.with_modules, + self.experimental_config, + self.trace_id, + ) + + def __enter__(self): + if not self.enabled: + return + if self.entered: + raise RuntimeError("Profiler context manager is not reentrant") + self._prepare_trace() + self._start_trace() + return self + + def _prepare_trace(self): + self.entered = True + t0 = perf_counter_ns() + _prepare_profiler(self.config(create_trace_id=True), self.kineto_activities) + t1 = perf_counter_ns() + self._stats.profiler_prepare_call_duration_us = int((t1 - t0) / 1000) + + def _start_trace(self): + self.entered = True + _run_on_profiler_start() + t0 = perf_counter_ns() + _enable_profiler(self.config(create_trace_id=False), self.kineto_activities) + t1 = perf_counter_ns() + self._stats.profiler_enable_call_duration_us = int((t1 - t0) / 1000) + self.profiling_start_time_ns = t1 + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return + if self.use_device and hasattr(torch, self.use_device): + device_module = getattr(torch, self.use_device) + if hasattr(device_module, "synchronize"): + device_module.synchronize() + + if self._function_events and self.acc_events: + self._old_function_events = self._function_events + self._function_events = None + self._needs_processing = True + + t0 = perf_counter_ns() + + self.kineto_results = _disable_profiler() + t1 = perf_counter_ns() + self._stats.profiler_disable_call_duration_us = int((t1 - t0) / 1000) + self.profiling_end_time_ns = t0 + + _run_on_profiler_stop() + + self._stats.profiling_window_duration_sec = ( + (self.profiling_end_time_ns - self.profiling_start_time_ns) * 1.0 / 1e9 + ) + + # If we plan to accumulate events we should post process the function events + # right away to retain the state across mulitple start/stop calls + if self.acc_events: + self._ensure_function_events() + return False + + def __repr__(self): + if self._needs_processing: + self._ensure_function_events() + if self._function_events is None: + return "" + return repr(self._function_events) + + def __str__(self): + if self._needs_processing: + self._ensure_function_events() + if self._function_events is None: + return "" + return str(self._function_events) + + def _ensure_function_events(self): + """Process function events lazily if required""" + if self._function_events is not None: + return + self._needs_processing = False + + t0 = perf_counter_ns() + parsed_results = [] + if self.kineto_results: 
+ parsed_results = self._parse_kineto_results(self.kineto_results) + t1 = perf_counter_ns() + self._stats.parse_kineto_call_duration_us = int((t1 - t0) / 1000) + + self._function_events = EventList( + parsed_results, + use_device=self.use_device, + profile_memory=self.profile_memory, + with_flops=self.with_flops, + ) + t0 = perf_counter_ns() + self._function_events._build_tree() + t1 = perf_counter_ns() + self._stats.function_events_build_tree_call_duration_us = int((t1 - t0) / 1000) + self._stats.number_of_events = len(self._function_events) + + if self._old_function_events and self.acc_events: + for evt in self._old_function_events: + self._function_events.append(evt) + self._old_function_events = None + + if self._function_events is None: + raise RuntimeError("Profiler didn't finish running") + + @property + def function_events(self): + if self._function_events is None or self._needs_processing: + self._ensure_function_events() + return self._function_events + + def table( + self, + sort_by=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + header=None, + top_level_events_only=False, + ): + self._ensure_function_events() + assert self._function_events is not None + return self._function_events.table( + sort_by=sort_by, + row_limit=row_limit, + max_src_column_width=max_src_column_width, + max_name_column_width=max_name_column_width, + max_shapes_column_width=max_shapes_column_width, + header=header, + top_level_events_only=top_level_events_only, + ) + + table.__doc__ = EventList.table.__doc__ + + def export_chrome_trace(self, path): + """ + Exports the collected trace in Chrome JSON format. If kineto is enabled, only + last cycle in schedule is exported. + """ + if kineto_available(): + self.kineto_results.save(path) # type: ignore[union-attr] + else: + self._ensure_function_events() + return self._function_events.export_chrome_trace(path) # type: ignore[union-attr] + + export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ + + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + self._ensure_function_events() + assert self._function_events is not None, "Expected profiling results" + assert self.with_stack, "export_stacks() requires with_stack=True" + return self._function_events.export_stacks(path, metric) + + def toggle_collection_dynamic( + self, enabled: bool, activities: Iterable[ProfilerActivity] + ): + """ + Toggles the collection of activities for the current profiler instance. + """ + return _toggle_collection_dynamic(enabled, set(activities)) + + def key_averages( + self, + group_by_input_shape=False, + group_by_stack_n=0, + group_by_overload_name=False, + ): + self._ensure_function_events() + assert self._function_events is not None, "Expected profiling results" + return self._function_events.key_averages( + group_by_input_shape, group_by_stack_n, group_by_overload_name + ) + + key_averages.__doc__ = EventList.key_averages.__doc__ + + def total_average(self): + self._ensure_function_events() + assert self._function_events is not None, "Expected profiling results" + return self._function_events.total_average() + + total_average.__doc__ = EventList.total_average.__doc__ + + @property + def self_cpu_time_total(self): + """Returns total time spent on CPU. + + The total time is a sum of all self times across all the events. 
+ """ + self._ensure_function_events() + assert self._function_events is not None + return self._function_events.self_cpu_time_total + + def _parse_kineto_results(self, result: _ProfilerResult): + # result.events() has most of the events - PyTorch op-level and device-level events + + trace_start_ns = result.trace_start_ns() + mem_records = [ + [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME + ] + oom_records = [ + evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME + ] + mem_records_acc = MemRecordsAcc(mem_records) + + def _cpu_memory_usage(mem_record): + return ( + mem_record.nbytes() + if mem_record.device_type() + in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] + else 0 + ) + + def _device_memory_usage(mem_record): + return ( + mem_record.nbytes() + if mem_record.device_type() + in [ + DeviceType.CUDA, + DeviceType.PrivateUse1, + DeviceType.HIP, + DeviceType.XPU, + ] + else 0 + ) + + # Create and return FunctionEvent list, which contains all function events + # Here 2 function events are created: + # all_function_events contains all events associated with each kineto event from result + all_function_events = [] + # frontend_function_events contains the events in aten or torch frontend level, + # whose correlation id is 0 + frontend_function_events = [] + device_corr_map: dict[int, list[FunctionEvent]] = {} + max_evt_id = 0 + for kineto_event in result.events(): + if _filter_name(kineto_event.name()): + continue + rel_start_ns = kineto_event.start_ns() - trace_start_ns + rel_end_ns = kineto_event.end_ns() - trace_start_ns + abs_end_ns = kineto_event.end_ns() + + cpu_memory_usage = 0 + device_memory_usage = 0 + if kineto_event.device_type() == DeviceType.CPU: + # find the corresponding memory allocation events + for mem_record in mem_records_acc.in_interval( + kineto_event.start_ns() / 1000, abs_end_ns / 1000 + ): + cpu_memory_usage += _cpu_memory_usage(mem_record[0]) + device_memory_usage += _device_memory_usage(mem_record[0]) + mem_record[1] = True + + is_async = kineto_event.is_async() or ( + kineto_event.start_thread_id() != kineto_event.end_thread_id() + ) + + fe = FunctionEvent( + id=kineto_event.correlation_id(), + name=_rewrite_name(name=kineto_event.name(), with_wildcard=True), + overload_name=kineto_event.overload_name(), + trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False), + thread=kineto_event.start_thread_id(), + start_us=rel_start_ns / 1000, + end_us=rel_end_ns / 1000, + fwd_thread=kineto_event.fwd_thread_id(), + input_shapes=kineto_event.shapes(), + concrete_inputs=kineto_event.concrete_inputs(), + kwinputs=kineto_event.kwinputs(), + stack=[ + entry + for entry in kineto_event.stack() + if _filter_stack_entry(entry) + ], + scope=kineto_event.scope(), + use_device=self.use_device, + cpu_memory_usage=cpu_memory_usage, + device_memory_usage=device_memory_usage, + is_async=is_async, + sequence_nr=kineto_event.sequence_nr(), + device_type=kineto_event.device_type(), + device_index=kineto_event.device_index(), + device_resource_id=kineto_event.device_resource_id(), + flops=kineto_event.flops(), + is_user_annotation=kineto_event.is_user_annotation(), + ) + max_evt_id = max(max_evt_id, fe.id) + if fe.device_type == DeviceType.CPU and not fe.is_async: + if self.use_device == _get_privateuse1_backend_name(): + privateuse1_time = kineto_event.privateuse1_elapsed_us() + if privateuse1_time > 0: + fe.append_kernel(fe.name, fe.device_index, privateuse1_time) + fe.is_legacy = True + elif self.use_device == 
"cuda": + # Check if we have CUDA time as a fallback + cuda_time = kineto_event.cuda_elapsed_us() + if cuda_time > 0: + fe.append_kernel(fe.name, fe.device_index, cuda_time) + fe.is_legacy = True + all_function_events.append(fe) + corr_id = kineto_event.linked_correlation_id() + if corr_id > 0: + if corr_id not in device_corr_map: + device_corr_map[corr_id] = [] + device_corr_map[corr_id].append(fe) + elif corr_id == 0: + frontend_function_events.append(fe) + else: + raise RuntimeError( + f"Got negative correlation id {corr_id} in profiler post processing" + ) + + # associate device kernels and device runtime (CPU) with CPU events + for fe in frontend_function_events: + if ( + fe.device_type == DeviceType.CPU + and not fe.is_async + and fe.id in device_corr_map + ): + for f_evt in device_corr_map[fe.id]: + if ( + f_evt.device_type == DeviceType.CUDA + or f_evt.device_type == DeviceType.PrivateUse1 + ): + fe.append_kernel( + f_evt.name, + f_evt.device_index, + f_evt.time_range.end - f_evt.time_range.start, + ) + elif f_evt.device_type == DeviceType.CPU: + # make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated + # with the 'thread' of the corresponding linked PyTorch event to properly track + # parents and children + f_evt.thread = fe.thread + + def createFunctionEventForMemoryEvents(evt): + rel_start_ns = evt.start_ns() - trace_start_ns + fe = FunctionEvent( + id=max_evt_id, + name=evt.name(), + overload_name="", + trace_name=None, # not outputting in the trace + thread=evt.start_thread_id(), + start_us=rel_start_ns / 1000, + end_us=rel_start_ns / 1000, # no duration + fwd_thread=evt.start_thread_id(), + input_shapes=[], + stack=[], + scope=0, # RecordScope::FUNCTION + use_device=self.use_device, + cpu_memory_usage=_cpu_memory_usage(evt), + device_memory_usage=_device_memory_usage(evt), + is_async=False, + sequence_nr=-1, + device_type=DeviceType.CPU, + device_index=0, + ) + return fe + + # output top-level memory events + for mem_record in mem_records: + if not mem_record[1]: + max_evt_id += 1 + fe = createFunctionEventForMemoryEvents(mem_record[0]) + all_function_events.append(fe) + + for oom_record in oom_records: + max_evt_id += 1 + fe = createFunctionEventForMemoryEvents(oom_record) + all_function_events.append(fe) + + all_function_events.sort( + key=lambda evt: [evt.time_range.start, -evt.time_range.end] + ) + return all_function_events + + +class record_function(_ContextDecorator): + """Context manager/function decorator that adds a label to a code block/function when running autograd profiler. + Label will only appear if CPU activity tracing is enabled. + + It is useful when tracing the code profile. + + Args: + name (str): Label assigned to the block of code. + node_id (int): ID of node, for distributed profiling. Unset in + non-distributed cases. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) + >>> x = torch.randn((1, 1), requires_grad=True) + >>> with torch.autograd.profiler.profile() as prof: + ... y = x ** 2 + ... with torch.autograd.profiler.record_function("label-z"): # label the block + ... z = y ** 3 + ... y.backward() + ... 
+ >>> # xdoctest: +IGNORE_WANT + >>> # NOTE: some columns were removed for brevity + >>> print(prof.key_averages().table(sort_by="self_cpu_time_total")) + ----------------------------------- --------------- --------------- --------------- + Name Self CPU total % CPU time avg Number of Calls + ----------------------------------- --------------- --------------- --------------- + pow 60.77% 47.470us 3 + mul 21.73% 25.465us 2 + PowBackward0 12.03% 121.891us 1 + torch::autograd::AccumulateGrad 2.70% 6.324us 1 + label-z 2.13% 12.421us 1 + torch::autograd::GraphRoot 0.64% 1.503us 1 + ----------------------------------- --------------- --------------- --------------- + Self CPU time total: 234.344us + CUDA time total: 0.000us + + """ + + def __init__(self, name: str, args: Optional[str] = None): + self.name: str = name + self.args: Optional[str] = args + # Whether or not we should run record function's end callbacks when exiting. + self.run_callbacks_on_exit: bool = True + # TODO: TorchScript ignores standard type annotation here + # self.record: Optional["torch.classes.profiler._RecordFunction"] = None + self.record = torch.jit.annotate( + Optional["torch.classes.profiler._RecordFunction"], None + ) + + def __enter__(self): + self.record = torch.ops.profiler._record_function_enter_new( + self.name, self.args + ) + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + if not self.run_callbacks_on_exit: + return + + # Local variable is needed by TorchScript to refine Optional[T] to T + record = self.record + assert record is not None + + # TODO: Too slow with __torch_function__ handling enabled + # See https://github.com/pytorch/pytorch/issues/76410 + if not torch.jit.is_scripting(): + with torch._C.DisableTorchFunctionSubclass(): + torch.ops.profiler._record_function_exit._RecordFunction(record) + else: + torch.ops.profiler._record_function_exit(record) + + def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]: + """Use for profiling async calls that return a future. + + Calling this function will extend recording beyond this scope, until the future is + satisfied. It is useful for profiling the end to end time of asynchronous calls. + This function should only be called once to attach the callback onto the future, and + will throw if called multiple times. + + Args: + fut: (torch._C.Future): future for which to schedule + callback for. + + Returns: + A future that completes with the value of the passed in future when + the profiling callbacks have ran. + + """ + # Throw if we have already attached a callback onto the future. + if not self.run_callbacks_on_exit: + raise RuntimeError("_call_end_callbacks_on_future can only be called once.") + + # We are scheduling to run this RecordFunction's end callbacks when the + # passed in future completes, so don't run end callbacks on exit. 
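# Editorial aside (not part of the upstream source): a minimal sketch of how the
# async-future path below is meant to be used. `launch_async_work` is a
# hypothetical helper that returns a torch Future; the rest is the API defined above.
#
#     with record_function("async-label") as rf:
#         fut = launch_async_work()                      # returns a torch.futures.Future
#         profiled_fut = rf._call_end_callbacks_on_future(fut)
#     # exiting the block is now a no-op; the "async-label" range ends when
#     # `fut` completes, and `profiled_fut` completes once the end callbacks have run.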
+ self.run_callbacks_on_exit = False + + # Local variable is needed by TorchScript to refine Optional[T] to T + record = self.record + assert record is not None + + # TODO: Too slow with __torch_function__ handling enabled + # See https://github.com/pytorch/pytorch/issues/76410 + if not torch.jit.is_scripting(): + with torch._C.DisableTorchFunctionSubclass(): + profiled_future = ( + torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction( + record, fut + ) + ) + else: + profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut( + record, fut + ) + return profiled_future + + +class emit_itt: + """Context manager that makes every autograd operation emit an ITT range. + + It is useful when running the program under Intel(R) VTune Profiler:: + + vtune <--vtune-flags> + + The Instrumentation and Tracing Technology (ITT) API enables your application to generate and + control the collection of trace data during its execution across different Intel tools. + This context manager is to annotate Intel(R) VTune Profiling trace. With help of this context manager, + you will be able to see labled ranges in Intel(R) VTune Profiler GUI. + + .. warning: + This context manager should not be called recursively, i.e. at most one + instance should be enabled at any given time. + + Args: + enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op. + Default: ``True``. + record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping + each autograd op will append information about the sizes of Tensor arguments received + by that op, in the following format: + ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]`` + Non-tensor arguments will be represented by ``[]``. + Arguments will be listed in the order they are received by the backend op. + Please note that this order may not match the order in which those arguments were passed + on the Python side. Also note that shape recording may increase the overhead of itt range creation. + Default: ``False`` + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) + >>> with torch.autograd.profiler.emit_itt(): + ... model(x) + + """ + + def __init__(self, enabled=True, record_shapes=False): + self.enabled = enabled + self.entered = False + self.record_shapes = record_shapes + + def __enter__(self): + if not self.enabled: + return + if self.entered: + raise RuntimeError("ITT annotation context manager is not reentrant") + self.entered = True + _run_on_profiler_start() + _enable_profiler( + ProfilerConfig( + ProfilerState.ITT, + self.record_shapes, + False, + False, + False, + False, + _ExperimentalConfig(), + ), + set(), + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return + _disable_profiler() + _run_on_profiler_stop() + return False + + +class emit_nvtx: + """Context manager that makes every autograd operation emit an NVTX range. + + It is useful when running the program under nvprof:: + + nvprof --profile-from-start off -o trace_name.prof -- + + Unfortunately, there's no way to force nvprof to flush the data it collected + to disk, so for CUDA profiling one has to use this context manager to annotate + nvprof traces and wait for the process to exit before inspecting them. + Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or + :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection + e.g. in Python REPL. 
+ + .. warning: + This context manager should not be called recursively, i.e. at most one + instance should be enabled at any given time. + + Args: + enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op. + Default: ``True``. + record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping + each autograd op will append information about the sizes of Tensor arguments received + by that op, in the following format: + ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]`` + Non-tensor arguments will be represented by ``[]``. + Arguments will be listed in the order they are received by the backend op. + Please note that this order may not match the order in which those arguments were passed + on the Python side. Also note that shape recording may increase the overhead of nvtx range creation. + Default: ``False`` + + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) + >>> with torch.cuda.profiler.profile(): + ... model(x) # Warmup CUDA memory allocator and profiler + ... with torch.autograd.profiler.emit_nvtx(): + ... model(x) + + **Forward-backward correlation** + + When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler, + correlating each backward-pass op with the corresponding forward-pass op can be difficult. + To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it + generates. + + During the forward pass, each function range is decorated with ``seq=``. ``seq`` is a running + counter, incremented each time a new backward Function object is created and stashed for backward. + Thus, the ``seq=`` annotation associated with each forward function range tells you that + if a backward Function object is created by this forward function, + the backward object will receive sequence number N. + During the backward pass, the top-level range wrapping each C++ backward Function's + ``apply()`` call is decorated with ``stashed seq=``. ``M`` is the sequence number that + the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq`` + numbers in forward, you can track down which forward op created each backward Function. + + Any functions executed during the backward pass are also decorated with ``seq=``. During + default backward (with ``create_graph=False``) this information is irrelevant, and in fact, + ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with + backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function + objects with the earlier forward pass. + + **Double-backward** + + If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words, + if you are setting up for a double-backward), each function's execution during backward + is given a nonzero, useful ``seq=``. Those functions may themselves create Function objects + to be executed later during double-backward, just as the original functions in the forward pass did. 
+ The relationship between backward and double-backward is conceptually the same as the relationship + between forward and backward: The functions still emit current-sequence-number-tagged ranges, + the Function objects they create still stash those sequence numbers, and during the eventual + double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq`` + numbers, which can be compared to `seq` numbers from the backward pass. + + .. warning: + The sequence number is thread-local, and some forward functions don't create an associated + backward Function object (instead delegating that to sub-functions further down the call chain). + For these reasons, the correspondence of stashed sequence numbers in + backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is + not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully + disambiguate which forward function created which + backward Function object. You may need to make a judgment based on analytic knowledge of what + the expected correspondence should be. + """ + + def __init__(self, enabled=True, record_shapes=False): + self.enabled = enabled + self.entered = False + self.record_shapes = record_shapes + + def __enter__(self): + if not self.enabled: + return + if self.entered: + raise RuntimeError("NVTX annotation context manager is not reentrant") + self.entered = True + torch.cuda.synchronize() + _run_on_profiler_start() + _enable_profiler( + ProfilerConfig( + ProfilerState.NVTX, + self.record_shapes, + False, + False, + False, + False, + _ExperimentalConfig(), + ), + set(), + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return + torch.cuda.synchronize() + _disable_profiler() + _run_on_profiler_stop() + return False + + +def load_nvprof(path): + """Open an nvprof trace file and parses autograd annotations. + + Args: + path (str): path to nvprof trace + """ + return EventList(parse_nvprof_trace(path)) + + +class EnforceUnique: + """Raises an error if a key is seen more than once.""" + + def __init__(self): + self.seen = set() + + def see(self, *key): + r""" + Observe a key and raise an error if it is seen multiple times. + """ + if key in self.seen: + raise RuntimeError("duplicate key: " + str(key)) + self.seen.add(key) + + +def parse_nvprof_trace(path): + import sqlite3 + + conn = sqlite3.connect(path) + conn.row_factory = sqlite3.Row + + # Parse strings table + strings = {} + for r in conn.execute("SELECT _id_ as id, value FROM StringTable"): + strings[r["id"]] = torch._C._demangle(r["value"]) + + # First, find all functions and create FunctionEvents for them + marker_query = """ + SELECT + start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time + FROM + CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end + ON start.id = end.id + WHERE + start.name != 0 AND end.name = 0 + """ + functions = [] + functions_map = {} + unique = EnforceUnique() + for row in conn.execute(marker_query): + unique.see(row["marker_id"]) + evt = FunctionEvent( + id=row["marker_id"], + node_id=0, # missing a node_id when calling FunctionEvent. 
This is just to ensure + # that pytorch doesn't crash when creating a FunctionEvent() object + name=strings[row["name"]], + start_us=row["start_time"], + end_us=row["end_time"], + thread=0, + ) # TODO: find in sqlite database + functions.append(evt) + functions_map[evt.id] = evt + + # Now, correlate all kernels with FunctionEvents + kernel_query = """ + SELECT + start.id AS marker_id, start.name, start.timestamp, end.timestamp, + runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end, + kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name + FROM + CUPTI_ACTIVITY_KIND_MARKER AS start + INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end + ON start.id = end.id + INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime + ON (start.timestamp < runtime.start AND runtime.end < end.timestamp) + INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel + ON kernel.correlationId = runtime.correlationId + """ + unique = EnforceUnique() + for row in conn.execute(kernel_query): + unique.see(row["marker_id"], row["runtime_id"]) + # 211 is cudaKernelLaunch for cuda >= 9.2 + assert row["cbid"] == 211 + evt = functions_map[row["marker_id"]] + evt.append_kernel( + row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"] + ) + + functions.sort(key=lambda evt: evt.time_range.start) + return functions + + +class KinetoStepTracker: + """Provides an abstraction for incrementing the step count globally. + + Previously, we only had one place to mark that a step() has occurred + in the program via pytorch profiler step(). We will now add step hooks + in the Optimizer class https://github.com/pytorch/pytorch/issues/88446 + + - This could mean programs that already call profiler.step() every + iteration can end up double incrementing step count. + - If a model uses multiple optimizers we can also have double or more + counting of the step. + + We fix this by adding a layer of abstraction before calling step() + to the kineto library. The idea is to maintain steps per requester in a dict: + + .. code-block:: + + { + "ProfilerStep": 100, # triggered by profiler step() call + "Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc + "Optimizer2Step": 100, + } + + To figure out the global step count just take the max of dict values (100). + + If one of the count increments the max will go up. + + .. code-block:: + + { + "ProfilerStep": 100, + "Optimizer1Step": 101, # Optimizer1 got incremented first say + "Optimizer2Step": 100, + } + + Then global step count is 101 + We only call the kineto step() function when global count increments. + + NOTE: Please do not use the KinetoStepTracker in modules beside the Optimizer + for now. The result could be incorrect increments of the step count. + """ + + _current_step = 0 + _step_dict: dict[str, int] = defaultdict(int) + + @classmethod + def init_step_count(cls, requester: str): + r""" + Initialize for a given requester. + """ + cls._step_dict[requester] = cls._current_step + + @classmethod + def erase_step_count(cls, requester: str) -> bool: + r""" + Remove a given requester. + """ + return cls._step_dict.pop(requester, None) is not None + + @classmethod + def increment_step(cls, requester: str) -> int: + """Increments the step count for the requester. 
+ + Additionally if the max over all step counts has incremented then + trigger the _kineto_step() returns global step count + """ + if requester not in cls._step_dict: + cls.init_step_count(requester) + cls._step_dict[requester] += 1 + + new_step = max(cls._step_dict.values()) + if new_step > cls._current_step: + delta = new_step - cls._current_step + if delta > 1: + warn( + "Profiler step count has increased more than 1 - " + f"current_step = {cls._current_step} step dict = {cls._step_dict}" + ) + for _ in range(0, delta): + _kineto_step() + cls._current_step = new_step + return cls._current_step + + @classmethod + def current_step(cls) -> int: + r""" + Get the latest step for any requester + """ + return cls._current_step diff --git a/phivenv/Lib/site-packages/torch/autograd/profiler_legacy.py b/phivenv/Lib/site-packages/torch/autograd/profiler_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..f64df9d0f60775506e0eb14d606c775220611706 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/profiler_legacy.py @@ -0,0 +1,312 @@ +# mypy: allow-untyped-defs +import itertools +import warnings +from typing_extensions import deprecated + +import torch +import torch.cuda +from torch.autograd import ( + _disable_profiler_legacy, + _enable_profiler_legacy, + DeviceType, + ProfilerConfig, + ProfilerState, +) +from torch.autograd.profiler_util import ( + _filter_name, + _filter_stack_entry, + _rewrite_name, + EventList, + FunctionEvent, + MEMORY_EVENT_NAME, +) + + +__all__ = ["profile"] + + +@deprecated( + "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. " + "Please use `torch.profiler` instead.", + category=None, # TODO: change to `FutureWarning` +) +class profile: + """DEPRECATED: use torch.profiler instead.""" + + def __init__( + self, + enabled=True, + *, + use_cuda=False, + record_shapes=False, + with_flops=False, + profile_memory=False, + with_stack=False, + with_modules=False, + ): + self.enabled: bool = enabled + if not self.enabled: + return + self.use_cuda = use_cuda + self.function_events = None + self.entered = False + self.record_shapes = record_shapes + self.with_flops = with_flops + self.record_shapes |= self.with_flops + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_modules = with_modules + + if self.use_cuda and not torch.cuda.is_available(): + warnings.warn( + "CUDA is not available, disabling CUDA profiling", + stacklevel=2, + ) + self.use_cuda = False + + if self.use_cuda: + self.profiler_kind = ProfilerState.CUDA + else: + self.profiler_kind = ProfilerState.CPU + + def config(self): + return ProfilerConfig( + self.profiler_kind, + self.record_shapes, + self.profile_memory, + self.with_stack, + self.with_flops, + self.with_modules, + # avoid exposing _ExperimentalConfig this in legacy public API + torch._C._profiler._ExperimentalConfig(), + ) + + def __enter__(self): + if not self.enabled: + return + if self.entered: + raise RuntimeError("Profiler context manager is not reentrant") + self.entered = True + self._start_trace() + return self + + def _start_trace(self): + _enable_profiler_legacy(self.config()) + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return + if self.use_cuda: + torch.cuda.synchronize() + + records = _disable_profiler_legacy() + parsed_results = _parse_legacy_records(records) + self.function_events = EventList( + parsed_results, + use_device="cuda" if self.use_cuda else None, + profile_memory=self.profile_memory, + 
with_flops=self.with_flops, + ) + self.function_events._build_tree() + return False + + def __repr__(self): + if self.function_events is None: + return "" + return repr(self.function_events) + + def __str__(self): + if self.function_events is None: + return "" + return str(self.function_events) + + def _check_finish(self): + if self.function_events is None: + raise RuntimeError("Profiler didn't finish running") + + def table( + self, + sort_by=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + header=None, + top_level_events_only=False, + ): + self._check_finish() + assert self.function_events is not None + return self.function_events.table( + sort_by=sort_by, + row_limit=row_limit, + max_src_column_width=max_src_column_width, + max_name_column_width=max_name_column_width, + max_shapes_column_width=max_shapes_column_width, + header=header, + top_level_events_only=top_level_events_only, + ) + + table.__doc__ = EventList.table.__doc__ + + def export_chrome_trace(self, path): + self._check_finish() + assert self.function_events is not None + return self.function_events.export_chrome_trace(path) + + export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ + + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + assert self.with_stack, "export_stacks() requires with_stack=True" + return self.function_events.export_stacks(path, metric) + + def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + return self.function_events.key_averages(group_by_input_shape, group_by_stack_n) + + key_averages.__doc__ = EventList.key_averages.__doc__ + + def total_average(self): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + return self.function_events.total_average() + + total_average.__doc__ = EventList.total_average.__doc__ + + @property + def self_cpu_time_total(self): + """Return CPU time as the sum of self times across all events.""" + self._check_finish() + assert self.function_events is not None + return self.function_events.self_cpu_time_total + + +def _parse_legacy_records(thread_records): + def _get_record_key(record): + """Return a tuple for correlating start and end records in `_parse_legacy_records`.""" + return (record.handle(), record.node_id()) + + start_record = None + functions = [] + + # '__start_profile' is not guaranteed to be first, so we must find it here + for record in itertools.chain.from_iterable(thread_records): + name = record.name() + if start_record is None and name == "__start_profile": + start_record = record + + assert start_record is not None and not start_record.is_remote() + + for thread_record_list in thread_records: + # accumulated memory allocations per handle + cpu_memory_allocs = {} + cuda_memory_allocs = {} + # ranges per handle + range_starts = {} + + filtered_handles = set() + prev_record = None + for record in thread_record_list: + record_key = _get_record_key(record) + if _filter_name(record.name()) or record_key in filtered_handles: + filtered_handles.add(record_key) + continue + + if record.kind() == "push": + # workaround to reduce double logging from operator + # wrappers and redispatch + if prev_record is not None: + duplicate = ( + prev_record.name() == record.name() + and prev_record.kind() == record.kind() + and 
prev_record.node_id() == record.node_id() + ) + if duplicate: + filtered_handles.add(record_key) + continue + + range_starts[record_key] = record + cpu_memory_allocs[record_key] = 0 + cuda_memory_allocs[record_key] = 0 + elif record.kind() == "pop": + assert ( + record_key in range_starts + ), f"""Expected record with key {record_key} to exist in range_starts. + This means that the pop event did not have a corresponding push.""" + + start = range_starts[record_key] + + cpu_memory_usage = cpu_memory_allocs[record_key] + cuda_memory_usage = cuda_memory_allocs[record_key] + is_async = start.is_async() or (start.thread_id() != record.thread_id()) + is_remote_event = record.is_remote() + start_flops = start.flops() + + fe = FunctionEvent( + id=record.handle(), + node_id=record.node_id(), + name=_rewrite_name(name=start.name(), with_wildcard=True), + trace_name=_rewrite_name(name=start.name(), with_wildcard=False), + thread=start.thread_id(), + start_us=start_record.cpu_elapsed_us(start), + end_us=start_record.cpu_elapsed_us(record), + fwd_thread=start.fwd_thread_id(), + input_shapes=start.shapes(), + stack=[ + entry for entry in start.stack() if _filter_stack_entry(entry) + ], + scope=start.scope(), + use_device="cuda" if start.has_cuda() else None, + cpu_memory_usage=cpu_memory_usage, + device_memory_usage=cuda_memory_usage, + is_async=is_async, + is_remote=is_remote_event, + sequence_nr=start.sequence_nr(), + device_type=DeviceType.CPU, + is_legacy=True, + flops=start_flops, + ) + # note: async events have only cpu total time + if not is_async and start.has_cuda(): + duration = start.cuda_elapsed_us(record) + if duration > 0: + fe.append_kernel(start.name(), start.device(), duration) + functions.append(fe) + del range_starts[record_key] + del cpu_memory_allocs[record_key] + del cuda_memory_allocs[record_key] + elif record.kind() == "memory_alloc": + num_open_handles_cpu = len(cpu_memory_allocs) + num_open_handles_cuda = len(cuda_memory_allocs) + assert num_open_handles_cpu == num_open_handles_cuda + for handle in cpu_memory_allocs.keys(): + cpu_memory_allocs[handle] += record.cpu_memory_usage() + for handle in cuda_memory_allocs.keys(): + cuda_memory_allocs[handle] += record.cuda_memory_usage() + if num_open_handles_cpu == 0: + # output event as a top-level memory event + fe = FunctionEvent( + id=0, + name=MEMORY_EVENT_NAME, + trace_name=None, + thread=0, + start_us=0, + end_us=0, + stack=[], + cpu_memory_usage=record.cpu_memory_usage(), + device_memory_usage=record.cuda_memory_usage(), + is_legacy=True, + ) + functions.append(fe) + prev_record = record + + # Sort functions by start time then by end time ascending. + # This ensures that--in the case of nested events which + # have the same start time (which may happen due to the + # granularity of the given clock tick)--we always show + # the outermost nested call first. 
This adds stability + # in how FunctionEvents appear + functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) + return functions diff --git a/phivenv/Lib/site-packages/torch/autograd/profiler_util.py b/phivenv/Lib/site-packages/torch/autograd/profiler_util.py new file mode 100644 index 0000000000000000000000000000000000000000..76a3a43b1e67661345bcd29e44b502c95d35f6e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/profiler_util.py @@ -0,0 +1,1150 @@ +# mypy: allow-untyped-defs +import bisect +import itertools +import math +from collections import defaultdict, namedtuple +from operator import attrgetter +from typing import Any, Optional +from typing_extensions import deprecated + +import torch +from torch.autograd import DeviceType + + +__all__ = [ + "EventList", + "FormattedTimesMixin", + "Interval", + "Kernel", + "FunctionEvent", + "FunctionEventAvg", + "StringTable", + "MemRecordsAcc", +] + + +class EventList(list): + """A list of Events (for pretty printing).""" + + def __init__(self, *args, **kwargs): + use_device = kwargs.pop("use_device", None) + profile_memory = kwargs.pop("profile_memory", False) + with_flops = kwargs.pop("with_flops", False) + super().__init__(*args, **kwargs) + self._use_device = use_device + self._profile_memory = profile_memory + self._tree_built = False + self._with_flops = with_flops + + def _build_tree(self): + self._populate_cpu_children() + self._remove_dup_nodes() + self._set_backward_stacktraces() + self._tree_built = True + + def __str__(self): + return self.table() + + def _remove_dup_nodes(self): + while True: + to_delete = set() + for idx in range(len(self)): + if ( + self[idx].cpu_parent is not None + and self[idx].cpu_parent.name == self[idx].name + and len(self[idx].cpu_parent.cpu_children) == 1 + ): + self[idx].cpu_parent.cpu_children = self[idx].cpu_children + self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up + for ch in self[idx].cpu_children: + ch.cpu_parent = self[idx].cpu_parent + to_delete.add(idx) + if len(to_delete) == 0: + break + new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete] + self.clear() + self.extend(new_evts) + + def _populate_cpu_children(self): + """Populate child events into each underlying FunctionEvent object. + + One event is a child of another if [s1, e1) is inside [s2, e2). Where + s1 and e1 would be start and end of the child event's interval. And + s2 and e2 start and end of the parent event's interval + + Example: In event list [[0, 10], [1, 3], [3, 4]] would have make [0, 10] + be a parent of two other intervals. + + If for any reason two intervals intersect only partially, this function + will not record a parent child relationship between then. + """ + # Some events can be async (i.e. start and end on different threads), + # since it's generally undefined how to attribute children ranges to + # async ranges, we do not use them when calculating nested ranges and stats + sync_events = [ + evt + for evt in self + if not evt.is_async and evt.device_type == DeviceType.CPU + ] + events = sorted( + sync_events, + key=attrgetter("thread"), + ) + # Group by both thread and node_id, so that events that happen to have + # the same thread_id but are from different nodes aren't incorrectly + # grouped together. + threads = itertools.groupby( + events, key=lambda event: (event.thread, event.node_id) + ) + + # For each thread we keep a stack of current nested parents. 
+ # We maintain the invariant that each interval is a subset of all other + # intervals lower in the stack. + # + # First we sort the intervals by their start time. Then we iterate over them. + # Every time we see a new interval we remove several parents from + # the top until we restore the invariant. Then parent child relationship + # if recorded if the stack is not empty. + # Finally we add new interval to the list + # + # Algorithm has O(N * log(N)) complexity where N is number of + # intervals + for _thread_id, thread_events in threads: + thread_events_ = sorted( + thread_events, + key=lambda event: [event.time_range.start, -event.time_range.end], + ) + current_events: list[FunctionEvent] = [] + for event in thread_events_: + while len(current_events) > 0: + parent = current_events[-1] + if ( + event.time_range.start >= parent.time_range.end + or event.time_range.end > parent.time_range.end + ): + # this can't be a parent + current_events.pop() + else: + parent.append_cpu_child(event) + assert ( + event.cpu_parent is None + ), f"There is already a CPU parent event for {event.key}" + event.set_cpu_parent(parent) + break + + current_events.append(event) + + def _set_backward_stacktraces(self): + def bw_parent(evt): + if evt is None: + return None + elif evt.scope == 1: # BACKWARD_FUNCTION + return evt + else: + return bw_parent(evt.cpu_parent) + + fwd_stacks = {} + for evt in self: + if bw_parent(evt) is None and evt.stack is not None: + t = (evt.sequence_nr, evt.thread) + if t not in fwd_stacks: + fwd_stacks[t] = evt.stack + + for evt in self: + p = bw_parent(evt) + if p is not None: + assert p.fwd_thread is not None + t = (p.sequence_nr, p.fwd_thread) + if t in fwd_stacks: + evt.stack = fwd_stacks[t] + else: + evt.stack = [] + + @property + def self_cpu_time_total(self): + return sum(event.self_cpu_time_total for event in self) + + def table( + self, + sort_by=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + header=None, + top_level_events_only=False, + ): + """Print an EventList as a nicely formatted table. + + Args: + sort_by (str, optional): Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``xpu_time``, + ``cpu_time_total``, ``cuda_time_total``, ``xpu_time_total``, + ``cpu_memory_usage``, ``cuda_memory_usage``, ``xpu_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, + ``self_xpu_memory_usage``, ``count``. + top_level_events_only(bool, optional): Boolean flag to determine the + selection of events to display. If true, the profiler will only + display events at top level like top-level invocation of python + `lstm`, python `add` or other functions, nested events like low-level + cpu/cuda/xpu ops events are omitted for profiler result readability. + + Returns: + A string containing the table. + """ + return _build_table( + self, + sort_by=sort_by, + row_limit=row_limit, + max_src_column_width=max_src_column_width, + max_name_column_width=max_name_column_width, + max_shapes_column_width=max_shapes_column_width, + header=header, + profile_memory=self._profile_memory, + with_flops=self._with_flops, + top_level_events_only=top_level_events_only, + ) + + def export_chrome_trace(self, path): + """Export an EventList as a Chrome tracing tools file. + + The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL. + + Args: + path (str): Path where the trace will be written. 
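        Example (illustrative sketch; ``trace.json`` is an arbitrary file name)::

            >>> # xdoctest: +SKIP("writes a file to disk")
            >>> with torch.autograd.profiler.profile() as prof:
            ...     y = torch.ones(3) + torch.ones(3)
            >>> prof.export_chrome_trace("trace.json")
            >>> # open the file via the chrome://tracing URL to inspect it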
+ """ + import os + + device_name = "cuda" if not self._use_device else self._use_device + with open(path, "w") as f: + next_id = 0 + # Use file IO over using json.dump since JSON dumping is very slow and + # this technique is proven to give a 4x speedup. + f.write("[") + for evt in self: + if evt.trace_name is None: + continue + f.write( + '{{"name": "{}", ' + '"ph": "X", ' + '"ts": {}, ' + '"dur": {}, ' + '"tid": {}, ' + '"pid": "CPU functions", ' + '"args": {{}}}}, '.format( + evt.trace_name, + evt.time_range.start, + evt.time_range.elapsed_us(), + evt.thread + if not evt.is_remote + else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "', + ) + ) + for _ in evt.kernels: + # 's' and 'f' draw Flow arrows from + # the CPU launch to the GPU kernel + f.write( + f'{{"name": "{evt.trace_name}", ' + '"ph": "s", ' + f'"ts": {evt.time_range.start}, ' + f'"tid": {evt.thread}, ' + '"pid": "CPU functions", ' + f'"id": {next_id}, ' + f'"cat": "cpu_to_{device_name}", ' + '"args": {}}, ' + ) + # Note: use torch.profiler to get device kernel trace + next_id += 1 + if len(self) > 0: + # remove trailing whitespace and comma + f.seek(f.tell() - 2, os.SEEK_SET) + f.truncate() + f.write("]") + + def supported_export_stacks_metrics(self): + return [ + "self_cpu_time_total", + "self_cuda_time_total", + "self_xpu_time_total", + "self_privateuse1_time_total", + ] + + def export_stacks(self, path: str, metric: str): + if metric not in self.supported_export_stacks_metrics(): + raise ValueError( + "metric should be one of: " + + str(self.supported_export_stacks_metrics()) + ) + translate_table = str.maketrans(" ;\t\n", "____") + with open(path, "w") as f: + for evt in self: + if evt.stack and len(evt.stack) > 0: + metric_value = getattr( + evt, + metric.replace("cuda", "device") + .replace("xpu", "device") + .replace("privateuse1", "device"), + ) + if int(metric_value) > 0: + stack_str = "" + for entry in reversed(evt.stack): + stack_str += entry.translate(translate_table) + stack_str += ";" + stack_str = stack_str[:-1] + " " + str(int(metric_value)) + f.write(stack_str + "\n") + + def key_averages( + self, + group_by_input_shapes=False, + group_by_stack_n=0, + group_by_overload_name=False, + ): + """Averages all function events over their keys. + + Args: + group_by_input_shapes: group entries by + (event name, input shapes) rather than just event name. + This is useful to see which input shapes contribute to the runtime + the most and may help with size-specific optimizations or + choosing the best candidates for quantization (aka fitting a roof line) + + group_by_stack_n: group by top n stack trace entries + + group_by_overload_name: Differentiate operators by their overload name e.g. aten::add.Tensor + and aten::add.out will be aggregated separately + + Returns: + An EventList containing FunctionEventAvg objects. 
+ """ + assert self._tree_built + stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) + + def get_key( + event, group_by_input_shapes, group_by_stack_n, group_by_overload_name + ) -> tuple[str, ...]: + key = [ + str(event.key), + str(event.node_id), + str(event.device_type), + str(event.is_legacy), + str(event.is_user_annotation), + ] + if group_by_overload_name: + key.append(evt.overload_name) + if group_by_input_shapes: + key.append(str(event.input_shapes)) + if group_by_stack_n > 0: + key += event.stack[:group_by_stack_n] + return tuple(key) + + for evt in self: + stats[ + get_key( + evt, group_by_input_shapes, group_by_stack_n, group_by_overload_name + ) + ].add(evt) + + avg_list = EventList( + stats.values(), + use_device=self._use_device, + profile_memory=self._profile_memory, + with_flops=self._with_flops, + ) + for evt in avg_list: + evt.stack = evt.stack[:group_by_stack_n] + if not group_by_input_shapes: + evt.input_shapes = "" + if not group_by_overload_name: + evt.overload_name = "" + return avg_list + + def total_average(self): + """Averages all events. + + Returns: + A FunctionEventAvg object. + """ + total_stat = FunctionEventAvg() + for evt in self: + total_stat += evt + total_stat.key = None + total_stat.key = "Total" + return total_stat + + +def _format_time(time_us): + """Define how to format time in FunctionEvent.""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return f"{time_us / US_IN_SECOND:.3f}s" + if time_us >= US_IN_MS: + return f"{time_us / US_IN_MS:.3f}ms" + return f"{time_us:.3f}us" + + +def _format_time_share(time_us, total_time_us): + """Define how to format time in FunctionEvent.""" + if total_time_us == 0: + assert time_us == 0, f"Expected time_us == 0 but got {time_us}" + return "NaN" + return f"{time_us * 100.0 / total_time_us:.2f}%" + + +def _format_memory(nbytes): + """Return a formatted memory size string.""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if abs(nbytes) >= GB: + return f"{nbytes * 1.0 / GB:.2f} GB" + elif abs(nbytes) >= MB: + return f"{nbytes * 1.0 / MB:.2f} MB" + elif abs(nbytes) >= KB: + return f"{nbytes * 1.0 / KB:.2f} KB" + else: + return str(nbytes) + " B" + + +def _attr_formatter(name): + return property(lambda self: _format_time(getattr(self, name))) + + +class FormattedTimesMixin: + """Helpers for FunctionEvent and FunctionEventAvg. + + The subclass should define `*_time_total` and `count` attributes. 
+ """ + + cpu_time_str = _attr_formatter("cpu_time") + device_time_str = _attr_formatter("device_time") + cpu_time_total_str = _attr_formatter("cpu_time_total") + device_time_total_str = _attr_formatter("device_time_total") + self_cpu_time_total_str = _attr_formatter("self_cpu_time_total") + self_device_time_total_str = _attr_formatter("self_device_time_total") + + @property + def cpu_time(self): + return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore[attr-defined] + + @property + def device_time(self): + return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count # type: ignore[attr-defined] + + @property + @deprecated( + "`cuda_time` is deprecated, please use `device_time` instead.", + category=FutureWarning, + ) + def cuda_time(self): # To be deprecated + return self.device_time + + +class Interval: + def __init__(self, start, end): + self.start = start + self.end = end + + def elapsed_us(self): + r""" + Returns the length of the interval + """ + return self.end - self.start + + +Kernel = namedtuple("Kernel", ["name", "device", "duration"]) + + +class FunctionEvent(FormattedTimesMixin): + """Profiling information about a single function.""" + + def __init__( + self, + id, + name, + thread, + start_us, + end_us, + overload_name=None, + fwd_thread=None, + input_shapes=None, + stack=None, + scope=0, + use_device=None, + cpu_memory_usage=0, + device_memory_usage=0, + is_async=False, + is_remote=False, + sequence_nr=-1, + node_id=-1, + device_type=DeviceType.CPU, + device_index=0, + device_resource_id=None, + is_legacy=False, + flops=None, + trace_name=None, + concrete_inputs=None, + kwinputs=None, + is_user_annotation=False, + ): + self.id: int = id + self.node_id: int = node_id + self.name: str = name + self.overload_name: str = overload_name + self.trace_name: str = trace_name + self.time_range: Interval = Interval(start_us, end_us) + self.thread: int = thread + self.fwd_thread: Optional[int] = fwd_thread + self.kernels: list[Kernel] = [] + self.count: int = 1 + self.cpu_children: list[FunctionEvent] = [] + self.cpu_parent: Optional[FunctionEvent] = None + self.input_shapes: tuple[int, ...] = input_shapes + self.concrete_inputs: list[Any] = concrete_inputs + self.kwinputs: dict[str, Any] = kwinputs + self.stack: list = stack + self.scope: int = scope + self.use_device: Optional[str] = use_device + self.cpu_memory_usage: int = cpu_memory_usage + self.device_memory_usage: int = device_memory_usage + self.is_async: bool = is_async + self.is_remote: bool = is_remote + self.sequence_nr: int = sequence_nr + self.device_type: DeviceType = device_type + self.device_index: int = device_index + self.device_resource_id: int = ( + thread if device_resource_id is None else device_resource_id + ) + self.is_legacy: bool = is_legacy + self.flops: Optional[int] = flops + self.is_user_annotation: Optional[bool] = is_user_annotation + self.self_cpu_percent = -1 + self.total_cpu_percent = -1 + self.total_device_percent = -1 + + def append_kernel(self, name, device, duration): + assert self.device_type == DeviceType.CPU + self.kernels.append(Kernel(name, device, duration)) + + def append_cpu_child(self, child): + """Append a CPU child of type FunctionEvent. + + One is supposed to append only direct children to the event to have + correct self cpu time being reported. 
+ """ + assert self.device_type == DeviceType.CPU + assert isinstance(child, FunctionEvent) + assert child.device_type == DeviceType.CPU + self.cpu_children.append(child) + + def set_cpu_parent(self, parent): + """Set the immediate CPU parent of type FunctionEvent. + + One profiling FunctionEvent should have only one CPU parent such that + the child's range interval is completely inside the parent's. We use + this connection to determine the event is from top-level op or not. + """ + assert self.device_type == DeviceType.CPU + assert isinstance(parent, FunctionEvent) + assert parent.device_type == DeviceType.CPU + self.cpu_parent = parent + + # Note: async events don't have children, are not used when computing 'self' + # metrics of other events, have only total cpu time + @property + def self_cpu_memory_usage(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.cpu_memory_usage - sum( + child.cpu_memory_usage for child in self.cpu_children + ) + + @property + def self_device_memory_usage(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.device_memory_usage - sum( + child.device_memory_usage for child in self.cpu_children + ) + + @property + @deprecated( + "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.", + category=FutureWarning, + ) + def self_cuda_memory_usage(self): # To be deprecated + return self.self_device_memory_usage + + @property + def cpu_time_total(self): + if self.device_type == DeviceType.CPU: + return self.time_range.elapsed_us() + else: + return 0 + + @property + def self_cpu_time_total(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.cpu_time_total - sum( + child.cpu_time_total for child in self.cpu_children + ) + + @property + def device_time_total(self): + if self.is_async or not self.use_device: + return 0 + if self.device_type == DeviceType.CPU: + if not self.is_legacy: + # account for the kernels in the children ops + return sum(kinfo.duration for kinfo in self.kernels) + sum( + ch.device_time_total for ch in self.cpu_children + ) + else: + # each legacy cpu events has a single (fake) kernel + return sum(kinfo.duration for kinfo in self.kernels) + else: + assert self.device_type in [ + DeviceType.CUDA, + DeviceType.PrivateUse1, + DeviceType.MTIA, + DeviceType.HPU, + ] + return self.time_range.elapsed_us() + + @property + @deprecated( + "`cuda_time_total` is deprecated. Use `device_time_total` instead.", + category=FutureWarning, + ) + def cuda_time_total(self): # To be deprecated + return self.device_time_total + + @property + def self_device_time_total(self): + if self.is_async or not self.use_device: + return 0 + if self.device_type == DeviceType.CPU: + return self.device_time_total - sum( + child.device_time_total for child in self.cpu_children + ) + else: + assert self.device_type in [ + DeviceType.CUDA, + DeviceType.PrivateUse1, + DeviceType.MTIA, + DeviceType.HPU, + ] + return self.device_time_total + + @property + @deprecated( + "`self_cuda_time_total` is deprecated. 
Use `self_device_time_total` instead.", + category=FutureWarning, + ) + def self_cuda_time_total(self): # To be deprecated + return self.self_device_time_total + + @property + def key(self): + return self.name + + def __repr__(self): + device_name = self.use_device + device_time = self.device_time_str + device_memory_usage = self.device_memory_usage + return ( + f"" + ) + + +class FunctionEventAvg(FormattedTimesMixin): + """Used to average stats over multiple FunctionEvent objects.""" + + def __init__(self) -> None: + self.key: Optional[str] = None + self.count: int = 0 + self.node_id: int = 0 + self.is_async: bool = False + self.is_remote: bool = False + self.use_device: Optional[str] = None + self.cpu_time_total: int = 0 + self.device_time_total: int = 0 + self.self_cpu_time_total: int = 0 + self.self_device_time_total: int = 0 + self.input_shapes: Optional[list[list[int]]] = None + self.overload_name: Optional[str] = None + self.stack: Optional[list] = None + self.scope: Optional[int] = None + self.cpu_memory_usage: int = 0 + self.device_memory_usage: int = 0 + self.self_cpu_memory_usage: int = 0 + self.self_device_memory_usage: int = 0 + self.cpu_children: Optional[list[FunctionEvent]] = None + self.cpu_parent: Optional[FunctionEvent] = None + self.device_type: DeviceType = DeviceType.CPU + self.is_legacy: bool = False + self.flops: int = 0 + + def add(self, other): + if self.key is None: + # First function being recorded as part of FunctionEventAvg, propagate + # fields. + self.key = other.key + self.node_id = other.node_id + self.is_async = other.is_async + self.is_remote = other.is_remote + self.cpu_parent = other.cpu_parent + self.cpu_children = other.cpu_children + + self.overload_name = other.overload_name + self.input_shapes = other.input_shapes + self.stack = other.stack + self.scope = other.scope + self.device_type = other.device_type + self.is_legacy = other.is_legacy + self.use_device = other.use_device + self.is_user_annotation = other.is_user_annotation + + assert isinstance(other, (FunctionEvent, FunctionEventAvg)) + assert other.key == self.key + + self.cpu_time_total += other.cpu_time_total + self.device_time_total += other.device_time_total + self.self_cpu_time_total += other.self_cpu_time_total + self.self_device_time_total += other.self_device_time_total + self.cpu_memory_usage += other.cpu_memory_usage + self.device_memory_usage += other.device_memory_usage + self.self_cpu_memory_usage += other.self_cpu_memory_usage + self.self_device_memory_usage += other.self_device_memory_usage + self.count += other.count + if self.flops is None: + self.flops = other.flops + elif other.flops is not None: + self.flops += other.flops + return self + + def __iadd__(self, other): + return self.add(other) + + def __repr__(self): + device_name = "cuda" if not self.use_device else self.use_device + self_device_time = self.self_device_time_total_str + device_time = self.device_time_str + device_memory = self.device_memory_usage + return ( + f"" + ) + + +class StringTable(defaultdict): + def __missing__(self, key): + # manage cases like 't' (demangled to 'unsigned short') separately, + # for now simply check the length to avoid unexpected results for + # the short sequences + self[key] = torch._C._demangle(key) if len(key) > 1 else key + return self[key] + + +class MemRecordsAcc: + """Acceleration structure for accessing mem_records in interval.""" + + def __init__(self, mem_records): + self._mem_records = mem_records + self._start_nses: list[int] = [] + self._indices: list[int] = [] + if 
len(mem_records) > 0: + tmp = sorted([(r[0].start_ns(), i) for i, r in enumerate(mem_records)]) + self._start_nses, self._indices = zip(*tmp) # type: ignore[assignment] + + def in_interval(self, start_us, end_us): + r""" + Return all records in the given interval + To maintain backward compatibility, convert us to ns in function + """ + start_idx = bisect.bisect_left(self._start_nses, start_us * 1000) + end_idx = bisect.bisect_right(self._start_nses, end_us * 1000) + for i in range(start_idx, end_idx): + yield self._mem_records[self._indices[i]] + + +def _filter_stack_entry(entry): + filtered_entries = [ + ("autograd/__init__", "_make_grads"), + ("autograd/__init__", "backward"), + ("torch/tensor", "backward"), + ("_internal/common_utils", "prof_callable"), + ("_internal/common_utils", "prof_func_call"), + ("_internal/common_utils", "prof_meth_call"), + ] + return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries) + + +MEMORY_EVENT_NAME = "[memory]" +OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]" + + +def _filter_name(name): + # ignoring the following utility ops + filtered_out_names = [ + MEMORY_EVENT_NAME, # used only for the top-level memory events + OUT_OF_MEMORY_EVENT_NAME, + "profiler::_record_function_enter", + "profiler::_record_function_enter_new", + "profiler::_record_function_exit", + "aten::is_leaf", + "aten::output_nr", + "aten::_version", + ] + return name in filtered_out_names + + +# Demangles and optionally rewrites the provided event name, +# with_wildcard - whether to replace certain numbered event names +# with a wildcard name to aggregate them together in the profiler table +# output +def _rewrite_name(name, with_wildcard=False): + string_table = StringTable() + name = string_table[name] + if with_wildcard: + if name.startswith("ProfilerStep#"): + name = "ProfilerStep*" + return name + + +def _build_table( + events, + sort_by=None, + header=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + with_flops=False, + profile_memory=False, + top_level_events_only=False, +): + """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" + if len(events) == 0: + return "" + + has_device_time = any(event.self_device_time_total > 0 for event in events) + has_device_mem = any(event.self_device_memory_usage > 0 for event in events) + use_device = events[0].use_device + # Running on PrivateUse1 device with profiler but not enable + # ProfilerActivity.PrivateUse1 can also catch privateuse1 memory usage. + # Here only need to check has_privateuse1_time if not use_device. 
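# Editorial aside (not part of the upstream source): the guard below is a sanity
# check -- if any event reported device self-time but the profile was collected
# without a `use_device` backend name ("cuda", "xpu", a privateuse1 backend, ...),
# the table cannot label its device columns, so it fails loudly instead of
# printing misleading headers.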
+ if not use_device and has_device_time: + raise RuntimeError("use_device is None, but there is device performance data.") + + has_input_shapes = any( + (event.input_shapes is not None and len(event.input_shapes) > 0) + for event in events + ) + + has_overload_names = any( + (event.overload_name is not None and len(event.overload_name) > 0) + for event in events + ) + + if sort_by is not None: + events = EventList( + sorted( + events, + key=lambda evt: getattr( + evt, + sort_by.replace("cuda", "device") + .replace("xpu", "device") + .replace("privateuse1", "device"), + ), + reverse=True, + ), + use_device=use_device, + profile_memory=profile_memory, + with_flops=with_flops, + ) + + name_column_width = max(len(evt.key) for evt in events) + 4 + if max_name_column_width is not None: + name_column_width = min(name_column_width, max_name_column_width) + + shapes_column_width = max(len(str(evt.input_shapes)) for evt in events) + 4 + if max_shapes_column_width is not None: + shapes_column_width = min(shapes_column_width, max_shapes_column_width) + + DEFAULT_COLUMN_WIDTH = 12 + flops_column_width = DEFAULT_COLUMN_WIDTH + + src_column_width = None + stacks = [ + evt.stack for evt in events if evt.stack is not None and len(evt.stack) > 0 + ] + has_stack = len(stacks) > 0 + if has_stack: + src_column_width = ( + max(max(len(entry) for entry in stack) for stack in stacks) + 4 + ) + if max_src_column_width is not None: + src_column_width = min(src_column_width, max_src_column_width) + + headers = ["Name"] + if has_overload_names: + headers.append("Overload Name") + headers += [ + "Self CPU %", + "Self CPU", + "CPU total %", + "CPU total", + "CPU time avg", + ] + + device_name = use_device.upper() if use_device is not None else "None" + if has_device_time: + headers.extend( + [ + f"Self {device_name}", + f"Self {device_name} %", + f"{device_name} total", + f"{device_name} time avg", + ] + ) + if profile_memory: + headers.extend( + [ + "CPU Mem", + "Self CPU Mem", + ] + ) + if use_device and has_device_mem: + headers.extend( + [ + f"{device_name} Mem", + f"Self {device_name} Mem", + ] + ) + headers.append("# of Calls") + # Only append Node ID if any event has a valid (>= 0) Node ID + append_node_id = any(evt.node_id != -1 for evt in events) + if append_node_id: + headers.append("Node ID") + + # Have to use a list because nonlocal is Py3 only... 
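# Editorial aside (not part of the upstream source): the single-element lists
# created below emulate `nonlocal` rebinding from inside `add_column`, e.g.
#
#     total = [0]
#     def bump(n):
#         total[0] += n   # mutate the cell instead of rebinding the name
#
# With `nonlocal`, plain integers bound in the enclosing scope would work just as well.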
+ SPACING_SIZE = 2 + row_format_lst = [""] + header_sep_lst = [""] + line_length_lst = [-SPACING_SIZE] + + def add_column(padding, text_dir=">"): + row_format_lst[0] += ( + "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE) + ) + header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE) + line_length_lst[0] += padding + SPACING_SIZE + + def auto_scale_flops(flops): + flop_headers = [ + "FLOPs", + "KFLOPs", + "MFLOPs", + "GFLOPs", + "TFLOPs", + "PFLOPs", + ] + assert flops > 0 + log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1))) + assert log_flops >= 0 and log_flops < len(flop_headers) + return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)]) + + add_column(name_column_width) + if has_overload_names: + add_column(name_column_width) + for _ in headers[1 + has_overload_names :]: + add_column(DEFAULT_COLUMN_WIDTH) + + if has_input_shapes: + headers.append("Input Shapes") + add_column(shapes_column_width) + + if has_stack: + headers.append("Source Location") + add_column(src_column_width, text_dir="<") + + if with_flops: + # Auto-scaling of flops header + raw_flops = [evt.flops for evt in events if evt.flops > 0] + if len(raw_flops) != 0: + (flops_scale, flops_header) = auto_scale_flops(min(raw_flops)) + headers.append(f"Total {flops_header}") + add_column(flops_column_width) + else: + with_flops = False # can't find any valid flops + + row_format = row_format_lst[0] + header_sep = header_sep_lst[0] + line_length = line_length_lst[0] + add_column = None # type: ignore[assignment] + + # Have to use a list because nonlocal is Py3 only... + result = [] + + def append(s): + result.append(s) + result.append("\n") # Yes, newline after the end as well + + sum_self_cpu_time_total = 0 + sum_self_device_time_total = 0 + for evt in events: + sum_self_cpu_time_total += evt.self_cpu_time_total + if evt.device_type == DeviceType.CPU and evt.is_legacy: + # in legacy profiler, kernel info is stored in cpu events + sum_self_device_time_total += evt.self_device_time_total + elif ( + evt.device_type + in [ + DeviceType.CUDA, + DeviceType.PrivateUse1, + DeviceType.MTIA, + ] + and not evt.is_user_annotation + ): + # in kineto profiler, there're events with the correct device type (e.g. CUDA) + sum_self_device_time_total += evt.self_device_time_total + + # Actual printing + if header is not None: + append("=" * line_length) + append(header) + if top_level_events_only: + append("=" * line_length) + append("This report only display top-level ops statistics") + append(header_sep) + append(row_format.format(*headers)) + + append(header_sep) + + def trim_path(path, src_column_width): + if len(path) > src_column_width: + offset = len(path) - src_column_width + path = path[offset:] + if len(path) > 3: + path = "..." + path[3:] + return path + + event_limit = 0 + for evt in events: + if event_limit == row_limit: + break + if top_level_events_only and evt.cpu_parent is not None: + continue + else: + event_limit += 1 + name = evt.key + if max_name_column_width is not None and len(name) >= max_name_column_width - 3: + name = name[: (max_name_column_width - 3)] + "..." 
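# Editorial aside (not part of the upstream source): the truncation just above
# keeps long names within the column width; for example, with
# max_name_column_width = 12,
#
#     "aten::convolution_backward"[:9] + "..."   # -> "aten::con..."
#
# so oversized operator names still align with the fixed-width header.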
+ + evt.self_cpu_percent = _format_time_share( + evt.self_cpu_time_total, sum_self_cpu_time_total + ) + evt.total_cpu_percent = ( + _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) + if not evt.is_async + else 0 + ) + + row_values = [name] + if has_overload_names: + overload_name = evt.overload_name + if ( + max_name_column_width is not None + and len(overload_name) >= max_name_column_width - 3 + ): + overload_name = overload_name[: (max_name_column_width - 3)] + "..." + row_values += [overload_name] + row_values += [ + # Self CPU total %, 0 for async events. + evt.self_cpu_percent, + evt.self_cpu_time_total_str, # Self CPU total + # CPU total %, 0 for async events. + evt.total_cpu_percent, + evt.cpu_time_total_str, # CPU total + evt.cpu_time_str, # CPU time avg + ] + if has_device_time: + evt.total_device_percent = _format_time_share( + evt.self_device_time_total, sum_self_device_time_total + ) + row_values.extend( + [ + evt.self_device_time_total_str, + # device time total % + evt.total_device_percent, + evt.device_time_total_str, + evt.device_time_str, # device time avg + ] + ) + if profile_memory: + row_values.extend( + [ + # CPU Mem Total + _format_memory(evt.cpu_memory_usage), + # Self CPU Mem Total + _format_memory(evt.self_cpu_memory_usage), + ] + ) + if use_device and has_device_mem: + row_values.extend( + [ + # Device Mem Total + _format_memory(evt.device_memory_usage), + # Self Device Mem Total + _format_memory(evt.self_device_memory_usage), + ] + ) + row_values.append( + evt.count, # Number of calls + ) + + if append_node_id: + row_values.append(evt.node_id) + if has_input_shapes: + row_values.append(str(evt.input_shapes)[:shapes_column_width]) + if with_flops: + if evt.flops <= 0: + row_values.append("--") + else: + row_values.append(f"{evt.flops * flops_scale:8.3f}") # type: ignore[possibly-undefined] + if has_stack: + src_field = "" + if len(evt.stack) > 0: + src_field = trim_path(evt.stack[0], src_column_width) + row_values.append(src_field) + append(row_format.format(*row_values)) + + if has_stack: + empty_headers = [""] * (len(headers) - 1) + for entry in evt.stack[1:]: + append( + row_format.format( + *(empty_headers + [trim_path(entry, src_column_width)]) + ) + ) + empty_headers.append("") + append(row_format.format(*empty_headers)) + + append(header_sep) + append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}") + if has_device_time: + append( + f"Self {use_device.upper() if use_device is not None else 'None'} " + f"time total: {_format_time(sum_self_device_time_total)}" + ) + return "".join(result) diff --git a/phivenv/Lib/site-packages/torch/autograd/variable.py b/phivenv/Lib/site-packages/torch/autograd/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..a81d0af68fd6d692b7f14dd8d0b46e515197131f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/autograd/variable.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch +from torch._C import _ImperativeEngine as ImperativeEngine + + +__all__ = ["VariableMeta", "Variable"] + + +class VariableMeta(type): + def __instancecheck__(cls, other): + return isinstance(other, torch.Tensor) + + +class Variable(torch._C._LegacyVariableBase, metaclass=VariableMeta): # type: ignore[misc] + _execution_engine = ImperativeEngine() diff --git a/phivenv/Lib/site-packages/torch/backends/__init__.py b/phivenv/Lib/site-packages/torch/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4411a133df702104afdcb0c96dc027795d5a483c --- /dev/null 
+++ b/phivenv/Lib/site-packages/torch/backends/__init__.py @@ -0,0 +1,73 @@ +# mypy: allow-untyped-defs +import types +from contextlib import contextmanager + + +# The idea for this parameter is that we forbid bare assignment +# to torch.backends..enabled and friends when running our +# test suite, where it's very easy to forget to undo the change +# later. +__allow_nonbracketed_mutation_flag = True + + +def disable_global_flags(): + global __allow_nonbracketed_mutation_flag + __allow_nonbracketed_mutation_flag = False + + +def flags_frozen(): + return not __allow_nonbracketed_mutation_flag + + +@contextmanager +def __allow_nonbracketed_mutation(): + global __allow_nonbracketed_mutation_flag + old = __allow_nonbracketed_mutation_flag + __allow_nonbracketed_mutation_flag = True + try: + yield + finally: + __allow_nonbracketed_mutation_flag = old + + +class ContextProp: + def __init__(self, getter, setter): + self.getter = getter + self.setter = setter + + def __get__(self, obj, objtype): + return self.getter() + + def __set__(self, obj, val): + if not flags_frozen(): + self.setter(val) + else: + raise RuntimeError( + f"not allowed to set {obj.__name__} flags " + "after disable_global_flags; please use flags() context manager instead" + ) + + +class PropModule(types.ModuleType): + def __init__(self, m, name): + super().__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) + + +from torch.backends import ( + cpu as cpu, + cuda as cuda, + cudnn as cudnn, + cusparselt as cusparselt, + kleidiai as kleidiai, + mha as mha, + mkl as mkl, + mkldnn as mkldnn, + mps as mps, + nnpack as nnpack, + openmp as openmp, + quantized as quantized, +) diff --git a/phivenv/Lib/site-packages/torch/backends/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6df9c140ad5fae68cf6690b76c88c12238da1c73 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/_coreml/__init__.py b/phivenv/Lib/site-packages/torch/backends/_coreml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d616452cf5be43f4d0710909cd5d66cf77a6ca4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae3f9f420e589d14db7ce6ac5ad8b5fadb8eba2a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/_coreml/preprocess.py b/phivenv/Lib/site-packages/torch/backends/_coreml/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3eff10be640388cdfb3a2509226211f0ea7569 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/_coreml/preprocess.py @@ -0,0 +1,150 @@ +# mypy: 
allow-untyped-defs +import hashlib +import json + +import coremltools as ct # type: ignore[import] +from coremltools.converters.mil.input_types import TensorType # type: ignore[import] +from coremltools.converters.mil.mil import types # type: ignore[import] +from coremltools.models.neural_network import quantization_utils # type: ignore[import] + +import torch + + +CT_METADATA_VERSION = "com.github.apple.coremltools.version" +CT_METADATA_SOURCE = "com.github.apple.coremltools.source" + + +class ScalarType: + Float = 0 + Double = 1 + Int = 2 + Long = 3 + Undefined = 4 + + +# Supported Tensor types in coremltools: +# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28 +torch_to_mil_types = { + ScalarType.Float: types.fp32, + ScalarType.Double: types.fp64, + ScalarType.Int: types.int32, + ScalarType.Long: types.int64, +} + + +class CoreMLComputeUnit: + CPU = "cpuOnly" + CPUAndGPU = "cpuAndGPU" + ALL = "all" + + +class CoreMLQuantizationMode: + LINEAR = "linear" + LINEAR_SYMMETRIC = "linear_symmetric" + NONE = "none" + + +def TensorSpec(shape, dtype=ScalarType.Float): + return (shape, dtype) + + +def CompileSpec( + inputs, + outputs, + backend=CoreMLComputeUnit.CPU, + allow_low_precision=True, + quantization_mode=CoreMLQuantizationMode.NONE, + mlmodel_export_path=None, + convert_to=None, +): + return ( + inputs, + outputs, + backend, + allow_low_precision, + quantization_mode, + mlmodel_export_path, + convert_to, + ) + + +def _check_enumerated_shape(shape): + for s in shape: + if not isinstance(s, (list, tuple)): + return False + return True + + +def _convert_to_mil_type(shape, dtype, name: str): + mil_shape = shape + if _check_enumerated_shape(shape): + mil_shape = ct.EnumeratedShapes(shape) + ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype]) + ml_type.name = name + return ml_type + + +def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tuple]): + spec = compile_spec["forward"] + ( + input_specs, + output_specs, + backend, + allow_low_precision, + quantization_mode, + mlmodel_export_path, + convert_to, + ) = spec + mil_inputs = [] + inputs = [] + for index, input in enumerate(input_specs): + shape, dtype = input + name = "input_" + str(index) + inputs.append([name, str(dtype), str(shape)]) + ml_type = _convert_to_mil_type(shape, dtype, name) + mil_inputs.append(ml_type) + model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None) + mlmodel = ct.convert(model, inputs=mil_inputs, convert_to=convert_to) + + if quantization_mode != CoreMLQuantizationMode.NONE: + quant_model_spec = quantization_utils.quantize_weights( + mlmodel, nbits=8, quantization_mode=quantization_mode + ) + mlmodel = ct.models.MLModel(quant_model_spec) + + spec = mlmodel.get_spec() + assert len(spec.description.output) == len(output_specs) # type: ignore[attr-defined] + outputs = [] + for index, output in enumerate(output_specs): + shape, dtype = output + name = spec.description.output[index].name # type: ignore[attr-defined] + outputs.append([name, str(dtype), str(shape)]) + mlmodel = ct.models.model.MLModel(spec) + print(mlmodel) + + if mlmodel_export_path is not None: + print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}") + mlmodel.save(mlmodel_export_path) + + config = { + "spec_ver": str(spec.specificationVersion), # type: ignore[attr-defined] + "backend": backend, + "allow_low_precision": str(allow_low_precision), + } + metadata = { + "coremltool_ver": 
mlmodel.user_defined_metadata[CT_METADATA_VERSION], + "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE], + } + coreml_compile_spec = { + "inputs": inputs, + "outputs": outputs, + "config": config, + "metadata": metadata, + } + mlmodel = spec.SerializeToString() # type: ignore[attr-defined] + + return { + "model": mlmodel, + "hash": str(hashlib.sha256(mlmodel).hexdigest()), + "extra": json.dumps(coreml_compile_spec), + } diff --git a/phivenv/Lib/site-packages/torch/backends/_nnapi/__init__.py b/phivenv/Lib/site-packages/torch/backends/_nnapi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..580cf4fec2a3c727e67ae6afde31aa1bace11c49 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da988f36c386f26600b58985c01c572bcaa2c82a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fbe342afc89f7227b01fe43b7936691ec09a34e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/_nnapi/prepare.py b/phivenv/Lib/site-packages/torch/backends/_nnapi/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..8303895d4eacdd938ec6dc932adbe5c5ba2ed4a7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/_nnapi/prepare.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch.backends._nnapi.serializer import _NnapiSerializer + + +ANEURALNETWORKS_PREFER_LOW_POWER = 0 +ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1 +ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2 + + +class NnapiModule(torch.nn.Module): + """Torch Module that wraps an NNAPI Compilation. + + This module handles preparing the weights, initializing the + NNAPI TorchBind object, and adjusting the memory formats + of all inputs and outputs. 
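For orientation, the `preprocess` function above is the ahead-of-time step of the CoreML delegate: it converts a scripted module with `coremltools` and returns the serialized spec plus metadata. A hedged usage sketch (the shapes and the `_jit_to_backend` lowering call follow PyTorch's CoreML example flow and are illustrative, not taken from this diff):

```python
import torch
from torch.backends._coreml.preprocess import (
    CompileSpec,
    CoreMLComputeUnit,
    TensorSpec,
)

def make_compile_spec():
    return {
        "forward": CompileSpec(
            inputs=(TensorSpec(shape=[1, 3, 224, 224]),),  # example input shape
            outputs=(TensorSpec(shape=[1, 1000]),),        # example output shape
            backend=CoreMLComputeUnit.ALL,
            allow_low_precision=True,
        ),
    }

# `model` is assumed to be a scripted torch.nn.Module, e.g. torch.jit.script(net);
# the backend machinery invokes preprocess() with this spec during lowering.
# coreml_model = torch._C._jit_to_backend("coreml", model, make_compile_spec())
```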
+ """ + + # _nnapi.Compilation is defined + comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined] + weights: list[torch.Tensor] + out_templates: list[torch.Tensor] + + def __init__( + self, + shape_compute_module: torch.nn.Module, + ser_model: torch.Tensor, + weights: list[torch.Tensor], + inp_mem_fmts: list[int], + out_mem_fmts: list[int], + compilation_preference: int, + relax_f32_to_f16: bool, + ): + super().__init__() + self.shape_compute_module = shape_compute_module + self.ser_model = ser_model + self.weights = weights + self.inp_mem_fmts = inp_mem_fmts + self.out_mem_fmts = out_mem_fmts + self.out_templates = [] + self.comp = None + self.compilation_preference = compilation_preference + self.relax_f32_to_f16 = relax_f32_to_f16 + + @torch.jit.export + def init(self, args: list[torch.Tensor]): + assert self.comp is None + self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator] + self.weights = [w.contiguous() for w in self.weights] + comp = torch.classes._nnapi.Compilation() + comp.init2( + self.ser_model, + self.weights, + self.compilation_preference, + self.relax_f32_to_f16, + ) + + self.comp = comp + + def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]: + if self.comp is None: + self.init(args) + comp = self.comp + assert comp is not None + outs = [torch.empty_like(out) for out in self.out_templates] + + assert len(args) == len(self.inp_mem_fmts) + fixed_args = [] + for idx in range(len(args)): + fmt = self.inp_mem_fmts[idx] + # These constants match the values in DimOrder in serializer.py + # TODO: See if it's possible to use those directly. + if fmt == 0: + fixed_args.append(args[idx].contiguous()) + elif fmt == 1: + fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous()) + else: + raise ValueError("Invalid mem_fmt") + comp.run(fixed_args, outs) + assert len(outs) == len(self.out_mem_fmts) + for idx in range(len(self.out_templates)): + fmt = self.out_mem_fmts[idx] + # These constants match the values in DimOrder in serializer.py + # TODO: See if it's possible to use those directly. + if fmt in (0, 2): + pass + elif fmt == 1: + outs[idx] = outs[idx].permute(0, 3, 1, 2) + else: + raise ValueError("Invalid mem_fmt") + return outs + + +def convert_model_to_nnapi( + model, + inputs, + serializer=None, + return_shapes=None, + use_int16_for_qint16=False, + compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED, + relax_f32_to_f16=False, +): + ( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + retval_count, + ) = process_for_nnapi( + model, inputs, serializer, return_shapes, use_int16_for_qint16 + ) + + nnapi_model = NnapiModule( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + compilation_preference, + relax_f32_to_f16, + ) + + class NnapiInterfaceWrapper(torch.nn.Module): + """NNAPI list-ifying and de-list-ifying wrapper. + + NNAPI always expects a list of inputs and provides a list of outputs. + This module allows us to accept inputs as separate arguments. + It returns results as either a single tensor or tuple, + matching the original module. + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + wrapper_model_py = NnapiInterfaceWrapper(nnapi_model) + wrapper_model = torch.jit.script(wrapper_model_py) + # TODO: Maybe make these names match the original. 
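`convert_model_to_nnapi` above is the public entry point of this module. A hedged end-to-end sketch of the usual flow (the model and shapes are illustrative; the `nnapi_nhwc` attribute is the same marker consulted by `add_tensor_operand_for_input` in serializer.py below):

```python
import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval()

example = torch.zeros(1, 3, 224, 224)
# NNAPI prefers channels-last inputs; nnapi_nhwc tells the serializer to emit
# an NHWC (CHANNELS_LAST) operand for this input.
example = example.contiguous(memory_format=torch.channels_last)
example.nnapi_nhwc = True

with torch.no_grad():
    traced = torch.jit.trace(net, example)

nnapi_model = convert_model_to_nnapi(traced, example)
# nnapi_model is a scripted wrapper module; it can be saved for mobile, e.g.:
# nnapi_model._save_for_lite_interpreter("model.nnapi.ptl")
```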
+ arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs))) + if retval_count < 0: + ret_expr = "retvals[0]" + else: + ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count)) + wrapper_model.define( + f"def forward(self, {arg_list}):\n" + f" retvals = self.mod([{arg_list}])\n" + f" return {ret_expr}\n" + ) + return wrapper_model + + +def process_for_nnapi( + model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False +): + model = torch.jit.freeze(model) + + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + + serializer = serializer or _NnapiSerializer( + config=None, use_int16_for_qint16=use_int16_for_qint16 + ) + ( + ser_model, + used_weights, + inp_mem_fmts, + out_mem_fmts, + shape_compute_lines, + retval_count, + ) = serializer.serialize_model(model, inputs, return_shapes) + ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32) + + # We have to create a new class here every time this function is called + # because module.define adds a method to the *class*, not the instance. + class ShapeComputeModule(torch.nn.Module): + """Code-gen-ed module for tensor shape computation. + + module.prepare will mutate ser_model according to the computed operand + shapes, based on the shapes of args. Returns a list of output templates. + """ + + shape_compute_module = torch.jit.script(ShapeComputeModule()) + real_shape_compute_lines = [ + "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n", + ] + [f" {line}\n" for line in shape_compute_lines] + shape_compute_module.define("".join(real_shape_compute_lines)) + + return ( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + retval_count, + ) diff --git a/phivenv/Lib/site-packages/torch/backends/_nnapi/serializer.py b/phivenv/Lib/site-packages/torch/backends/_nnapi/serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..7e8f4e820ff7bf308db24791ee85f29e18d98747 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/_nnapi/serializer.py @@ -0,0 +1,2228 @@ +# mypy: allow-untyped-defs +import array +import enum +import functools +import logging +import operator +import struct +import sys +from typing import NamedTuple, Optional + +import torch + + +# TODO: Add type annotations +# TODO: Check tensor types for ops + + +LOG = logging.getLogger("nnapi_serialize") + + +class NNAPI_OperandCode: + FLOAT32 = 0 + INT32 = 1 + UINT32 = 2 + TENSOR_FLOAT32 = 3 + TENSOR_INT32 = 4 + TENSOR_QUANT8_ASYMM = 5 + BOOL = 6 + TENSOR_QUANT16_SYMM = 7 + TENSOR_FLOAT16 = 8 + TENSOR_BOOL8 = 9 + FLOAT16 = 10 + TENSOR_QUANT8_SYMM_PER_CHANNEL = 11 + TENSOR_QUANT16_ASYMM = 12 + + +class NNAPI_OperationCode: + ADD = 0 + AVERAGE_POOL_2D = 1 + CONCATENATION = 2 + CONV_2D = 3 + DEPTHWISE_CONV_2D = 4 + DEPTH_TO_SPACE = 5 + DEQUANTIZE = 6 + EMBEDDING_LOOKUP = 7 + FLOOR = 8 + FULLY_CONNECTED = 9 + HASHTABLE_LOOKUP = 10 + L2_NORMALIZATION = 11 + L2_POOL_2D = 12 + LOCAL_RESPONSE_NORMALIZATION = 13 + LOGISTIC = 14 + LSH_PROJECTION = 15 + LSTM = 16 + MAX_POOL_2D = 17 + MUL = 18 + RELU = 19 + RELU1 = 20 + RELU6 = 21 + RESHAPE = 22 + RESIZE_BILINEAR = 23 + RNN = 24 + SOFTMAX = 25 + SPACE_TO_DEPTH = 26 + SVDF = 27 + TANH = 28 + BATCH_TO_SPACE_ND = 29 + DIV = 30 + MEAN = 31 + PAD = 32 + SPACE_TO_BATCH_ND = 33 + SQUEEZE = 34 + STRIDED_SLICE = 35 + SUB = 36 + TRANSPOSE = 37 + ABS = 38 + ARGMAX = 39 + ARGMIN = 40 + AXIS_ALIGNED_BBOX_TRANSFORM = 41 + BIDIRECTIONAL_SEQUENCE_LSTM = 42 + BIDIRECTIONAL_SEQUENCE_RNN = 43 + 
BOX_WITH_NMS_LIMIT = 44 + CAST = 45 + CHANNEL_SHUFFLE = 46 + DETECTION_POSTPROCESSING = 47 + EQUAL = 48 + EXP = 49 + EXPAND_DIMS = 50 + GATHER = 51 + GENERATE_PROPOSALS = 52 + GREATER = 53 + GREATER_EQUAL = 54 + GROUPED_CONV_2D = 55 + HEATMAP_MAX_KEYPOINT = 56 + INSTANCE_NORMALIZATION = 57 + LESS = 58 + LESS_EQUAL = 59 + LOG = 60 + LOGICAL_AND = 61 + LOGICAL_NOT = 62 + LOGICAL_OR = 63 + LOG_SOFTMAX = 64 + MAXIMUM = 65 + MINIMUM = 66 + NEG = 67 + NOT_EQUAL = 68 + PAD_V2 = 69 + POW = 70 + PRELU = 71 + QUANTIZE = 72 + QUANTIZED_16BIT_LSTM = 73 + RANDOM_MULTINOMIAL = 74 + REDUCE_ALL = 75 + REDUCE_ANY = 76 + REDUCE_MAX = 77 + REDUCE_MIN = 78 + REDUCE_PROD = 79 + REDUCE_SUM = 80 + ROI_ALIGN = 81 + ROI_POOLING = 82 + RSQRT = 83 + SELECT = 84 + SIN = 85 + SLICE = 86 + SPLIT = 87 + SQRT = 88 + TILE = 89 + TOPK_V2 = 90 + TRANSPOSE_CONV_2D = 91 + UNIDIRECTIONAL_SEQUENCE_LSTM = 92 + UNIDIRECTIONAL_SEQUENCE_RNN = 93 + RESIZE_NEAREST_NEIGHBOR = 94 + + +class NNAPI_FuseCode: + FUSED_NONE = 0 + FUSED_RELU = 1 + FUSED_RELU1 = 2 + FUSED_RELU6 = 3 + + +class OperandValueSourceType: + IMMEDIATE = 0 + NUMBERED_BUFFER = 2 + NUMBERED_MEMORY = 3 + + +# Scalar types that appear explicitly in models. +# These must be kept in sync with +# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. +# TODO: Expose these directly to Python to avoid maintaining this list. +class TorchScalarTypes(enum.Enum): + QUINT8 = 13 + + +def approx_equal(lhs, rhs, tolerance=1e-6): + return abs(lhs - rhs) <= tolerance * min(lhs, rhs) + + +def tensor_size(op_type, dims): + ITEM_SIZES = { + NNAPI_OperandCode.TENSOR_FLOAT32: 4, + NNAPI_OperandCode.TENSOR_INT32: 4, + NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2, + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2, + } + size = ITEM_SIZES[op_type] + for d in dims: + size *= d + return size + + +def change_element(tup, index, value): + ls = list(tup) + ls[index] = value + return tuple(ls) + + +class ConvPoolArgs2d(NamedTuple): + """Configuration arguments for a convolution.""" + + kernel_h: int + kernel_w: int + stride_h: int + stride_w: int + pad_t: int + pad_b: int + pad_l: int + pad_r: int + dilation_h: int + dilation_w: int + group: int + + +class DimOrder(enum.Enum): + PRESUMED_CONTIGUOUS = 0 + CHANNELS_LAST = 1 + SCALAR_OR_VECTOR = 2 + UNKNOWN_CONSTANT = 999 + + +class Operand(NamedTuple): + """Represenation of an NNAPI operand.""" + + # NNAPI operand type. One of NNAPI_OperandCode. + # TODO: Make this an enum. + op_type: int + + # This is always the PyTorch shape, which is NCHW for feature maps. + # The actual NNAPI operand might have a transposed shape. + # we use 0 for load time dynamic shapes & -1 for runtime dynamic shapes + shape: tuple[int, ...] + + # Specifies how the shape of the operand that we define in NNAPI + # relates to the shape we track above. + # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match + # the shape of the PyTorch tensor. + # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and + # the NNAPI operand will be represented explicitly as NHWC. 
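The small helpers above do the byte-level bookkeeping for operands; a quick worked check of what they compute, with values following directly from the definitions (illustrative, not part of the vendored file):

```python
# tensor_size multiplies the per-element size for the operand code by every dim:
# a float32 tensor of shape (1, 3, 224, 224) needs 4 * 1 * 3 * 224 * 224 bytes.
assert tensor_size(NNAPI_OperandCode.TENSOR_FLOAT32, (1, 3, 224, 224)) == 602112

# change_element copies a tuple with one slot replaced; add_cat uses it later to
# compare shapes while ignoring the concatenation dimension.
assert change_element((1, 3, 224, 224), 1, -1) == (1, -1, 224, 224)
```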
+ dim_order: DimOrder + + # Quantization params + scale: float + zero_point: int + + def use_nchw(self): + if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return True + if self.dim_order is DimOrder.CHANNELS_LAST: + return False + raise Exception("Unknown dim order") # noqa: TRY002 + + +def broadcast_shapes(shape1, shape2): + assert len(shape1) > 0 + assert len(shape2) > 0 + s1 = list(shape1) + s2 = list(shape2) + # TODO: Support non-equal-rank broadcast where semantics match. + # This can be tricky for NHWC tensors because dimension orders + # don't match between PT and NNAPI, even though semantics match. + if len(s1) > len(s2): + # s2 = [1] * (len(s1) - len(s2)) + s2 + raise Exception( # noqa: TRY002 + "Non-equal-rank broadcast is not supported yet." + ) # noqa: TRY002 + if len(s2) > len(s1): + # s3 = [1] * (len(s2) - len(s1)) + s1 + raise Exception( # noqa: TRY002 + "Non-equal-rank broadcast is not supported yet." + ) # noqa: TRY002 + ret = [] + for d1, d2 in zip(s1, s2): + if d1 == 1: + ret.append(d2) + elif d2 == 1: + ret.append(d1) + elif d1 == d2: + ret.append(d1) + else: + raise Exception( # noqa: TRY002 + f"Cannot broadcast shapes: {shape1} and {shape2}" + ) # noqa: TRY002 + return tuple(ret) + + +def get_conv_pool_shape(image_shape, args, out_ch, transpose): + batch, _in_c, in_h, in_w = image_shape + + # TODO: Handle dilation + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("Dilation not supported yet.") # noqa: TRY002 + + if transpose: + out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b + out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_l + else: + out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1 + out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1 + + # Handle variable-sized tensors. + if in_h == 0: + out_h = 0 + if in_w == 0: + out_w = 0 + + out_shape = (batch, out_ch, out_h, out_w) + return out_shape + + +def fix_shape(shape, dim_order): + # Return the actual shape that an operand should have in NNAPI, + # given a PyTorch shape and dimension order. This is where we + # convert from PyTorch's "always NCHW" shape to explicit NHWC. + if dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return shape + if dim_order is DimOrder.CHANNELS_LAST: + return tuple([shape[0]] + list(shape[2:]) + [shape[1]]) + if dim_order is DimOrder.SCALAR_OR_VECTOR: + assert len(shape) == 0 or len(shape) == 1 + return shape + if dim_order is DimOrder.UNKNOWN_CONSTANT: + # XXX think this through + return shape + raise Exception(f"Bad dim_order: {dim_order!r}.") # noqa: TRY002 + + +def reverse_map_dim(dim_order, d): + # Return the original PyTorch dimension position for a given dimension. + # d should be the dimension that NNAPI will see. + # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x + # reverse_map_dim(CHANNELS_LAST, 3) == 1 + if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR): + return d + assert dim_order is DimOrder.CHANNELS_LAST + return [0, 2, 3, 1][d] + + +def flex_name(op_id, dim): + # Return the local variable name for the computed flexible size + # for a given op and dimension. 
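A short worked example of the dimension-order helpers above (assuming the definitions in this file are in scope):

```python
# PyTorch always tracks NCHW shapes; a CHANNELS_LAST operand is exposed to NNAPI as NHWC.
assert fix_shape((1, 32, 28, 28), DimOrder.CHANNELS_LAST) == (1, 28, 28, 32)
assert fix_shape((1, 32, 28, 28), DimOrder.PRESUMED_CONTIGUOUS) == (1, 32, 28, 28)

# reverse_map_dim answers "which PyTorch dim does NNAPI dim d come from?":
# the channel dim of an NHWC operand (NNAPI dim 3) is PyTorch dim 1.
assert reverse_map_dim(DimOrder.CHANNELS_LAST, 3) == 1
assert reverse_map_dim(DimOrder.PRESUMED_CONTIGUOUS, 2) == 2
```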
+ return f"s_{op_id}_{dim}" + + +class _NnapiSerializer: + def __init__(self, config, use_int16_for_qint16=False): + self.operands = [] + self.values = [] + self.operations = [] + self.value_data = [] + self.operation_args = [] + self.inputs = [] + self.outputs = [] + self.flexible_shape_computation_lines = [] + + self.modules = {} + self.constants = {} + self.tensor_sequences = {} + self.jitval_operand_map = {} + self.cached_immediates = {} + self.used_weights = [] + self.weight_offset = 0 + self.use_int16_for_qint16 = use_int16_for_qint16 + + if config is None: + config = {} + + def get_next_operand_id(self): + return len(self.operands) + + # Add a tensor operand corresponding to a JIT Value. + # Returns the NNAPI operand ID. Can be looked up later with + # get_tensor_operand_by_jitval. + def add_tensor_operand(self, jitval, oper): + assert isinstance(oper, Operand) + if jitval in self.jitval_operand_map: + raise Exception(f"Duplicate tensor: {jitval!r}") # noqa: TRY002 + + operand_id = self.get_next_operand_id() + self.operands.append(oper) + self.jitval_operand_map[jitval] = operand_id + return operand_id + + # Add a tensor operand that does not correspond to a JIT Value. + # Useful for cases where multiple NNAPI operands are required + # to implement one JIT IR node. Returns the NNAPI operand ID. + def add_anonymous_tensor_operand(self, oper): + assert isinstance(oper, Operand) + operand_id = self.get_next_operand_id() + self.operands.append(oper) + return operand_id + + def torch_tensor_to_operand(self, tensor, dim_order): + dtype = str(tensor.dtype).replace("torch.", "") + scale = 0.0 + zero_point = 0 + if dtype == "float32": + op_type = NNAPI_OperandCode.TENSOR_FLOAT32 + elif dtype == "int32": + op_type = NNAPI_OperandCode.TENSOR_INT32 + elif dtype == "quint8": + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + elif dtype == "qint32": + op_type = NNAPI_OperandCode.TENSOR_INT32 + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + assert zero_point == 0 + elif dtype == "int16": + if self.use_int16_for_qint16: + nnapi_dtype = getattr(tensor, "nnapi_dtype", None) + op_codes = ( + NNAPI_OperandCode.TENSOR_QUANT16_SYMM, + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, + ) + if nnapi_dtype in op_codes: + op_type = nnapi_dtype + scale = tensor.nnapi_scale + zero_point = tensor.nnapi_zero_point + else: + raise Exception( # noqa: TRY002 + f"`nnapi_type` needs to be one of {op_codes} for `int16`" + ) + else: + raise Exception( # noqa: TRY002 + "`int16` isn't supported. 
If you're trying to represent NNAPI" + " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" + ) + else: + raise Exception( # noqa: TRY002 + f"Can't handle input with dtype '{tensor.dtype}'" + ) # noqa: TRY002 + return Operand( + shape=tuple(tensor.shape), + op_type=op_type, + dim_order=dim_order, + scale=scale, + zero_point=zero_point, + ) + + def add_tensor_operand_for_input(self, arg_idx, jitval, tensor): + dim_order = ( + DimOrder.CHANNELS_LAST + if getattr(tensor, "nnapi_nhwc", False) + else DimOrder.PRESUMED_CONTIGUOUS + ) + toper = self.torch_tensor_to_operand(tensor, dim_order) + operand_id = self.add_tensor_operand(jitval, toper) + self.inputs.append(operand_id) + for dim, size in enumerate(tensor.shape): + if size == 0: + self.compute_operand_shape( + operand_id, dim, f"args[{arg_idx}].shape[{dim}]" + ) + return operand_id + + def add_tensor_operand_for_weight( + self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT + ): + toper = self.torch_tensor_to_operand(tensor, dim_order) + operand_id = len(self.operands) + self.operands.append(toper) + tsize = tensor_size(toper.op_type, toper.shape) + self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) + buf_num = len(self.used_weights) + offset = 0 + self.value_data.append(struct.pack("iii", buf_num, offset, tsize)) + # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor + if dim_order == DimOrder.CHANNELS_LAST: + tensor = tensor.permute(0, 2, 3, 1) + self.used_weights.append(tensor) + return operand_id + + def add_immediate_operand(self, code, value, dims): + assert isinstance(dims, tuple) + cache_key = (code, value) + if cache_key not in self.cached_immediates: + operand_id = len(self.operands) + self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0)) + self.values.append((operand_id, OperandValueSourceType.IMMEDIATE)) + self.value_data.append(value) + self.cached_immediates[cache_key] = operand_id + return self.cached_immediates[cache_key] + + def add_immediate_int_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.INT32, struct.pack("i", value), () + ) + + def add_immediate_float_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.FLOAT32, struct.pack("f", value), () + ) + + def add_immediate_bool_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", () + ) + + def add_immediate_int_vector(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.TENSOR_INT32, + array.array("i", value).tobytes(), + (len(value),), + ) + + def has_operand_for_jitval(self, jitval): + return jitval in self.jitval_operand_map + + def get_tensor_operand_by_jitval(self, jitval): + operand_id = self.jitval_operand_map[jitval] + return (operand_id, self.operands[operand_id]) + + def get_tensor_operand_by_jitval_fixed_size(self, jitval): + op_id, oper = self.get_tensor_operand_by_jitval(jitval) + for s in oper.shape: + if s == 0: + # TODO: Improve this error message, possibly after converting + # many callsites to support flexible size. + raise Exception( # noqa: TRY002 + "Flexible size is not supported for this operand." 
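A small illustration of the immediate-operand caching implemented by `add_immediate_operand` above (illustrative, not part of the vendored file):

```python
ser = _NnapiSerializer(config=None)

# Identical immediates are deduplicated: the cache key is (operand code, packed bytes).
a = ser.add_immediate_int_scalar(1)
b = ser.add_immediate_int_scalar(1)
assert a == b

# A different operand code or value produces a fresh operand id.
c = ser.add_immediate_float_scalar(1.0)
assert c != a
```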
+ ) # noqa: TRY002 + if s < 0: + # runtime flex + LOG.warning("Operand %s has runtime flex shape", oper) + return op_id, oper + + def get_tensor_operand_or_constant( + self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ): + operand_id = self.jitval_operand_map.get(jitval) + if operand_id is None: + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value, dim_order) + return (operand_id, self.operands[operand_id]) + + def get_tensor_operand_for_weight(self, jitval): + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value) + return (operand_id, self.operands[operand_id]) + + def add_operation(self, opcode, inputs, outputs): + self.operations.append((opcode, len(inputs), len(outputs))) + self.operation_args.extend(inputs + outputs) + + def add_tensor_sequence(self, jitval, values): + assert jitval not in self.tensor_sequences + self.tensor_sequences[jitval] = values + + def add_constant_value(self, jitval, ctype, value): + assert jitval not in self.constants + self.constants[jitval] = (ctype, value) + + def get_constant_value(self, jitval, typekind=None): + record = self.constants.get(jitval) + if record is None: + raise Exception( # noqa: TRY002 + f"Could not find constant value for '{jitval!r}'." + ) # noqa: TRY002 + ctype, _ = record + if typekind is not None and ctype.kind() != typekind: + raise Exception( # noqa: TRY002 + f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'" + ) + return record + + def operand_to_template_torchscript(self, op_id, oper, shape=None): + """Return a TorchScript expression to build a template for a given operand.""" + if shape is None: + shape = oper.shape + else: + assert len(shape) == len(oper.shape) + + shape_parts = ["("] + for d, s in enumerate(shape): + if s > 0: + # Fixed shape dimension: just add the value. + shape_parts.append(str(s)) + elif s == 0: + # Load time flexible shape dimension: it should have been computed in a variable. + shape_parts.append(flex_name(op_id, d)) + elif s == -1: + # Runtime flexible shape + shape_parts.append("0") + else: + raise Exception( # noqa: TRY002 + "Unknown dim value, dimensions should be >= -1" + ) # noqa: TRY002 + shape_parts.append(",") + shape_parts.append(")") + shape_code = "".join(shape_parts) + if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: + return f"torch.zeros({shape_code}, dtype=torch.float32)" + elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32: + return f"torch.zeros({shape_code}, dtype=torch.int32)" + elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + return ( + f"torch.quantize_per_tensor(" + f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)" + f".expand({shape_code}).contiguous()" + ) + elif oper.op_type in ( + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM, + ): + if self.use_int16_for_qint16: + return f"torch.zeros({shape_code}, dtype=torch.int16)" + else: + raise Exception( # noqa: TRY002 + "`int16` isn't supported. 
If you're trying to represent NNAPI" + " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" + ) + + raise Exception( # noqa: TRY002 + f"Unsupported output operand type: {oper.op_type}" + ) # noqa: TRY002 + + def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim): + self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim)) + + def compute_operand_shape(self, op_id, dim, expr): + self.flexible_shape_computation_lines.append( + f"{flex_name(op_id, dim)} = {expr}" + ) + + def transpose_to_nhwc(self, in_id, oper): + if oper.shape[2:] != (1, 1): + raise Exception( # noqa: TRY002 + "Automatic transpose only supported for H,W == 1,1" + ) # noqa: TRY002 + + out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1]) + + outputs = [None] * 1 + outputs[0] = self.add_anonymous_tensor_operand(out_oper) + + self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs) + + return outputs[0], out_oper + + # Transpose inputs as necessary to allow broadcasting. + def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper): + if in0_oper.dim_order == in1_oper.dim_order: + return in0_id, in0_oper, in1_id, in1_oper + + # Assume NHWC is preferred if there is a mismatch. + orders = (in0_oper.dim_order, in1_oper.dim_order) + if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST): + return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper) + if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS): + return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) + + raise Exception( # noqa: TRY002 + f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}" + ) + + def get_size_arg(self, jitval): + ctype, value = self.get_constant_value(jitval) + if ctype.kind() == "ListType": + assert ctype.getElementType().kind() == "IntType" + return value + raise Exception( # noqa: TRY002 + f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'" + ) # noqa: TRY002 + + def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): + pc = [i.item() for i in packed_config] + assert pc[0] == 2 + strides = [pc[1], pc[2]] + paddings = [pc[3], pc[4]] + dilations = [pc[5], pc[6]] + output_padding = [pc[7], pc[8]] + group_num = pc[9] + + assert len(pc) == 11 + assert output_padding == [0, 0] + + return self.get_conv_pool_args_2d_common( + kernel_size, strides, paddings, dilations, group_num + ) + + def get_conv_pool_args_2d_from_jit( + self, kernel_size, stride, padding, dilation=None, group=None + ): + strides = self.get_size_arg(stride) + paddings = self.get_size_arg(padding) + if dilation is None: + dilations = [1, 1] + else: + dilations = self.get_size_arg(dilation) + if group is not None: + _, group_num = self.get_constant_value(group, "IntType") + else: + group_num = None + return self.get_conv_pool_args_2d_common( + kernel_size, strides, paddings, dilations, group_num + ) + + def get_conv_pool_args_2d_common( + self, kernel_size, strides, paddings, dilations, group_num + ): + kernels = list(kernel_size) + + assert len(kernels) == 2 + assert len(strides) == 2 + assert len(paddings) == 2 + assert len(dilations) == 2 + + # NNAPI uses 4 values for padding. 
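To make the padding expansion concrete: PyTorch supplies one symmetric padding per spatial axis, while NNAPI wants all four edges, so `get_conv_pool_args_2d_common` duplicates each value. An illustrative result (not part of the vendored file):

```python
# get_conv_pool_args_2d_common([3, 3], [1, 1], [1, 1], [1, 1], 1) yields
# real_paddings = [ph, ph, pw, pw], i.e. the NamedTuple below:
expected = ConvPoolArgs2d(
    kernel_h=3, kernel_w=3,
    stride_h=1, stride_w=1,
    pad_t=1, pad_b=1, pad_l=1, pad_r=1,
    dilation_h=1, dilation_w=1,
    group=1,
)
```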
+ ph, pw = paddings + real_paddings = [ph, ph, pw, pw] + + return ConvPoolArgs2d( + *(kernels + strides + real_paddings + dilations + [group_num]) + ) + + def serialize_model(self, model, inputs, return_shapes=None): + self.add_immediate_bool_scalar(False) + self.add_immediate_bool_scalar(True) + + inp_dim_orders = [] + out_dim_orders = [] + + self_jitval = next(model.graph.inputs()) + self.add_constant_value(self_jitval, self_jitval.type(), model) + + for arg_idx, (input_value, input_tensor) in enumerate( + zip(list(model.graph.inputs())[1:], inputs) + ): + op_id = self.add_tensor_operand_for_input( + arg_idx, input_value, input_tensor + ) + inp_dim_orders.append(self.operands[op_id].dim_order.value) + + for idx, node in enumerate(model.graph.nodes()): + LOG.debug("Processing node #%d: %r", idx, node) + self.add_node(node) + + retn = model.graph.return_node() + assert retn.inputsSize() == 1 + assert retn.outputsSize() == 0 + retn_input = retn.inputsAt(0) + template_return_lines = ["return ["] + if retn_input.type().kind() == "TensorType": + return_values = [retn_input] + retval_count = -1 + elif retn_input.type().kind() == "TupleType": + return_values = self.tensor_sequences[retn_input] + retval_count = len(return_values) + else: + raise Exception( # noqa: TRY002 + f"Unsupported return type: {retn_input.type()}" + ) # noqa: TRY002 + + if return_shapes is not None: + assert len(return_shapes) == len(return_values) + for i, v in enumerate(return_values): + op_id = self.jitval_operand_map[v] + self.outputs.append(op_id) + out_dim_orders.append(self.operands[op_id].dim_order.value) + shape = return_shapes[i] if return_shapes else None + template_return_lines.append( + self.operand_to_template_torchscript(op_id, self.operands[op_id], shape) + + "," + ) + template_return_lines.append("]") + + model = [] + + version = 1 + header = struct.pack( + "iiiiii", + version, + len(self.operands), + len(self.values), + len(self.operations), + len(self.inputs), + len(self.outputs), + ) + model.append(header) + + serialized_values, serialized_value_data = self.serialize_values() + + model.extend( + struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands + ) + model.extend(serialized_values) + model.extend(struct.pack("iii", *x) for x in self.operations) + + # Compact the model so we can get its length so far. + model = [b"".join(model)] + model_offset = len(model[0]) + # Model offset is the index into the model (in 32-bit words, not bytes) + # of the next dimension we're about to serialize. If it's 0, + # generate code to mutate it before passing to NNAPI. 
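The serialized model blob begins with a fixed-size header; a small sketch of reading it back, with the field order taken from the `struct.pack("iiiiii", ...)` call above (the helper name is illustrative):

```python
import struct

def read_nnapi_header(blob: bytes) -> dict:
    # Mirrors struct.pack("iiiiii", version, n_operands, n_values,
    # n_operations, n_inputs, n_outputs) in serialize_model.
    names = ("version", "operands", "values", "operations", "inputs", "outputs")
    return dict(zip(names, struct.unpack_from("iiiiii", blob, 0)))

# e.g. with the int32 tensor produced by process_for_nnapi in prepare.py:
# header = read_nnapi_header(ser_model_tensor.numpy().tobytes())
```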
+ assert model_offset % 4 == 0 + model_offset = int(model_offset / 4) + + for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands): + shape = fix_shape(dims, dim_order) + for d, s in enumerate(shape): + if s == 0: + pt_d = reverse_map_dim(dim_order, d) + self.flexible_shape_computation_lines.append( + f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}" + ) + model_offset += 1 + + # convert runtime flex shape from -1 to 0 + shape = tuple(d if d != -1 else 0 for d in shape) + model.append(self.serialize_ints(shape)) + + model.extend(serialized_value_data) + model.append(self.serialize_ints(self.operation_args)) + model.append(self.serialize_ints(self.inputs)) + model.append(self.serialize_ints(self.outputs)) + + self.flexible_shape_computation_lines.extend(template_return_lines) + + return ( + array.array("i", b"".join(model)), + self.used_weights, + inp_dim_orders, + out_dim_orders, + self.flexible_shape_computation_lines, + retval_count, + ) + + def serialize_values(self): + serialized_values = [] + serialized_value_data = [] + assert len(self.values) == len(self.value_data) + for (op_index, source_type), data in zip(self.values, self.value_data): + source_length = len(data) + + # Pad with 0 bytes out to a multiple of 4 for alignment. + physical_length = ((source_length - 1) | 0x3) + 1 + padded_data = data + (b"\0" * (physical_length - source_length)) + + serialized_values.append( + struct.pack("iii", op_index, source_type, source_length) + ) + serialized_value_data.append(padded_data) + + return serialized_values, serialized_value_data + + @staticmethod + def serialize_ints(ints): + return array.array("i", ints).tobytes() + + ADDER_MAP = { + "prim::GetAttr": lambda self, node: self.add_getattr(node), + "prim::Constant": lambda self, node: self.add_constant_node(node), + "prim::ListConstruct": lambda self, node: self.add_list_construct(node), + "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node), + "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node), + "aten::to": lambda self, node: self.add_to(node), + "aten::detach": lambda self, node: self._identity(node), + "aten::reshape": lambda self, node: self.add_reshape(node), + "aten::flatten": lambda self, node: self.add_flatten(node), + "aten::slice": lambda self, node: self.add_slice(node), + "aten::size": lambda self, node: self.add_size(node), + "aten::cat": lambda self, node: self.add_cat(node), + "aten::mean": lambda self, node: self.add_mean(node), + "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node), + "aten::dequantize": lambda self, node: self.add_dequantize(node), + "aten::add": lambda self, node: self.add_add_sub_op( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE + ), + "aten::sub": lambda self, node: self.add_add_sub_op( + node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE + ), + "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( + node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE + ), + "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( + node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE + ), + "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op( + node, NNAPI_OperationCode.RELU + ), + "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op( + node, NNAPI_OperationCode.LOGISTIC + ), + "aten::softmax": lambda self, node: self.add_softmax(node), + "aten::hardtanh": lambda self, node: self.add_hardtanh(node), + "aten::avg_pool2d": lambda self, node: 
self.add_avg_pool2d(node), + "aten::max_pool2d": lambda self, node: self.add_pool2d_node( + node, NNAPI_OperationCode.MAX_POOL_2D + ), + "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d( + node + ), + "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d( + node + ), + "aten::prelu": lambda self, node: self.add_prelu_op(node), + "aten::addmm": lambda self, node: self.add_addmm(node), + "aten::linear": lambda self, node: self.add_linear(node), + "aten::_convolution": lambda self, node: self.add_conv_underscore(node), + "aten::conv2d": lambda self, node: self.add_conv2d(node), + "aten::log_softmax": lambda self, node: self.add_log_softmax(node), + "quantized::linear": lambda self, node: self.add_qlinear(node), + "quantized::conv2d": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_NONE + ), + "quantized::conv2d_relu": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_RELU + ), + "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_NONE, transpose=True + ), + "quantized::add": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE + ), + "quantized::add_relu": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU + ), + "quantized::mul": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE + ), + } + + def add_node(self, node): + adder = self.ADDER_MAP.get(node.kind()) + if not adder: + raise Exception( # noqa: TRY002 + f"Unsupported node kind ({node.kind()!r}) in node {node!r}" + ) # noqa: TRY002 + adder(self, node) + + def _identity(self, node): + in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + jitval = node.outputsAt(0) + self.jitval_operand_map[jitval] = in_id + + def add_getattr(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + obj_ctype, obj = self.get_constant_value(node.inputsAt(0)) + assert str(obj_ctype).startswith("__torch__.") + name = node.s("name") + value = getattr(obj, name) + output = node.outputsAt(0) + ctype = output.type() + self.add_constant_value(output, ctype, value) + + def add_constant_node(self, node): + assert node.inputsSize() == 0 + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + value = output.toIValue() + self.add_constant_value(output, ctype, value) + + def add_list_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + const_vals: Optional[list] = [] + tensors: Optional[list] = [] + for inp in node.inputs(): + if const_vals is not None and inp in self.constants: + _, val = self.get_constant_value(inp) + const_vals.append(val) + else: + const_vals = None + if tensors is not None and inp.type().kind() == "TensorType": + tensors.append(inp) + else: + tensors = None + + if const_vals is not None: + # NOTE: Now that TorchScript supports list constants, + # this code path might not be used anymore. + self.add_constant_value(output, ctype, const_vals) + if tensors is not None: + self.add_tensor_sequence(output, tensors) + if const_vals is None and tensors is None: + raise Exception( # noqa: TRY002 + f"Unable to handle ListConstruct node. Neither all constants nor all tensors. 
{node!r}" + ) + + def add_tuple_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + values = list(node.inputs()) + self.add_tensor_sequence(output, values) + + def add_unsqueeze(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + + _, dim = self.get_constant_value(node.inputsAt(1), "IntType") + assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS + + real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1 + out_shape_list = list(in_oper.shape) + out_shape_list.insert(real_dim, 1) + out_shape = tuple(out_shape_list) + out_oper = in_oper._replace(shape=out_shape) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_scalar(dim) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs) + + def add_to(self, node): + # Handle to("cpu") / to("gpu") case + self._identity(node) + + def add_reshape(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + + shape_ctype, shape = self.get_constant_value(node.inputsAt(1)) + assert shape_ctype.kind() == "ListType" + assert shape_ctype.getElementType().kind() == "IntType" + is_trivial_reshape = len(shape) == 2 and shape[1] == -1 + + if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape: + raise Exception( # noqa: TRY002 + "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]." + ) + + # Bit of a hack here. Use a real tensor to infer the output shape. + out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape + out_oper = in_oper._replace( + shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(shape) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) + + def add_flatten(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + _start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType") + _end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType") + + # channels last with channels == 1 or (height & width both 1) + is_trivial_flatten = len(in_oper.shape) == 4 and ( + in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1) + ) + if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten: + raise Exception( # noqa: TRY002 + "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1" + ) + + if start_dim < 0: + start_dim += len(in_oper.shape) + if end_dim < 0: + end_dim += len(in_oper.shape) + + out_shape = ( + in_oper.shape[:start_dim] + + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),) + + in_oper.shape[end_dim + 1 :] + ) + + if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]): + raise Exception( # noqa: TRY002 + "Flattening flexible dims is not supported yet" + ) # noqa: TRY002 + non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :] + if non_flattened_dims.count(0) > 1: + raise Exception("Only 1 dim can be flexible") # noqa: TRY002 + + out_oper = 
in_oper._replace( + shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ) + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + + for idx, dim in enumerate(out_shape): + if dim == 0: + self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0)) + + inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape) + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(inputs_1) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) + + def add_slice(self, node): + assert node.inputsSize() == 5 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + _, dim_value = self.get_constant_value(node.inputsAt(1)) + _, start_value = self.get_constant_value(node.inputsAt(2)) + _, stop_value = self.get_constant_value(node.inputsAt(3)) + _, step_value = self.get_constant_value(node.inputsAt(4)) + + if start_value is None: + start_value = 0 + if stop_value is None: + stop_value = sys.maxsize + + if start_value < 0: + start_value += in_oper.shape[dim_value] + elif start_value == sys.maxsize: + start_value = 0 + + if start_value == 0 and stop_value == sys.maxsize: + self._identity(node) + return + + if in_oper.shape[dim_value] == 0: + raise Exception("Unable to slice with flexible shape") # noqa: TRY002 + + if stop_value < 0: + stop_value += in_oper.shape[dim_value] + elif stop_value == sys.maxsize: + stop_value = in_oper.shape[dim_value] + + if start_value >= stop_value: + raise Exception( # noqa: TRY002 + "Slice start value should be less than stop value" + ) # noqa: TRY002 + + out_len = (stop_value - start_value) // step_value + out_shape = tuple( + out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape) + ) + out_id = self.add_tensor_operand( + node.outputsAt(0), in_oper._replace(shape=out_shape) + ) + + # flex inputs + end_mask = 0 + for idx, dim in enumerate(out_shape): + if dim == 0: + self.forward_operand_shape(out_id, idx, in_id, idx) + end_mask |= 1 << idx + + inputs = [None] * 7 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector( + [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))] + ) + inputs[2] = self.add_immediate_int_vector( + [ + stop_value if i == dim_value else dim + for i, dim in enumerate(in_oper.shape) + ] + ) + inputs[3] = self.add_immediate_int_vector( + [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))] + ) + inputs[4] = self.add_immediate_int_scalar(0) # begin mask + inputs[5] = self.add_immediate_int_scalar(end_mask) + inputs[6] = self.add_immediate_int_scalar(0) # shrink axis mas + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs) + + def add_size(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + _, value = self.constants[node.inputsAt(1)] + res = in_oper.shape[value] + output = node.outputsAt(0) + self.add_constant_value(output, output.type(), res) + + def add_cat(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + tensors = self.tensor_sequences[node.inputsAt(0)] + _, dim = self.get_constant_value(node.inputsAt(1), "IntType") + + assert len(tensors) > 0 + in_ids = [] + out_oper = None + out_dim_size = 0 + for inp in tensors: + in_id, in_oper = self.get_tensor_operand_by_jitval(inp) + if out_oper is None: + out_shape = 
change_element(in_oper.shape, dim, -1) + out_oper = in_oper._replace(shape=out_shape) + assert in_oper.op_type == out_oper.op_type + assert in_oper.dim_order == out_oper.dim_order + assert change_element(in_oper.shape, dim, -1) == change_element( + out_oper.shape, dim, -1 + ) + # TODO: Possibly check scale and zero point. + in_ids.append(in_id) + # TODO: Possibly support variable-sized inputs. + out_dim_size += in_oper.shape[dim] + + assert out_oper is not None + out_oper = out_oper._replace( + shape=change_element(out_oper.shape, dim, out_dim_size) + ) + + if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined] + assert len(out_oper.shape) == 4 + nnapi_dim = [0, 3, 1, 2][dim] + else: + nnapi_dim = dim + + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + for idx, d in enumerate(out_oper.shape): + if d == 0: + if idx == dim: + shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids) + self.compute_operand_shape(out_id, idx, shape) + else: + self.forward_operand_shape(out_id, idx, in_ids[0], idx) + + inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)] + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs) + + def add_mean(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + dim_ctype, dim = self.get_constant_value(node.inputsAt(1)) + assert dim_ctype.kind() == "ListType" + assert dim_ctype.getElementType().kind() == "IntType" + _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType") + # Expect None for dtype + self.get_constant_value(node.inputsAt(3), "NoneType") + + if in_oper.dim_order == DimOrder.CHANNELS_LAST: + assert len(in_oper.shape) == 4 + nnapi_dim = [[0, 3, 1, 2][d] for d in dim] + else: + nnapi_dim = dim + + collapsed_dims = set() + for d in dim: + if d < 0: + d += len(in_oper.shape) + collapsed_dims.add(d) + + if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim: + assert collapsed_dims.issuperset({2, 3}) + out_dim_order = DimOrder.PRESUMED_CONTIGUOUS + else: + out_dim_order = in_oper.dim_order + + out_shape = [] + for i, s in enumerate(in_oper.shape): + if i not in collapsed_dims: + out_shape.append(s) + elif keep_dim: + out_shape.append(1) + + out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order) + + inputs = [None] * 3 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(nnapi_dim) + inputs[2] = self.add_immediate_int_scalar(keep_dim) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs) + + def add_quantize(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + if in_oper.dim_order != DimOrder.CHANNELS_LAST: + raise Exception( # noqa: TRY002 + "Most hardware backends prefer NHWC quantized tensors. " + "Try setting `t.nnapi_nhwc = True` on your tensor inputs. " + ) + _, scale = self.get_constant_value(node.inputsAt(1), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") + _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") + if scalar_type != TorchScalarTypes.QUINT8.value: + raise Exception( # noqa: TRY002 + "PyTorch NNAPI export only supports quantized tensors " + "with the quint8 dtype." 
+ ) + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + + out_oper = in_oper._replace( + op_type=op_type, + scale=scale, + zero_point=zero_point, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs) + + def add_dequantize(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + out_oper = in_oper._replace( + op_type=NNAPI_OperandCode.TENSOR_FLOAT32, + scale=0.0, + zero_point=0, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs) + + def add_pointwise_simple_unary_op(self, node, opcode): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + out_oper = in_oper + if opcode == NNAPI_OperationCode.LOGISTIC: + # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale + # must be 1.f / 256 and the zeroPoint must be 0. + # https://fburl.com/h52stoog + if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256) + + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + + for idx, dim in enumerate(in_oper.shape): + if dim == 0: + self.forward_operand_shape(out_id, idx, in_id, idx) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(opcode, inputs, outputs) + + def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): # noqa: D401 + """Helper for pointwise binary broadcast ops with superfluous extra args.""" + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + if self.has_operand_for_jitval(node.inputsAt(0)): + in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + in1_id, in1_oper = self.get_tensor_operand_or_constant( + node.inputsAt(1), in0_oper.dim_order + ) + elif self.has_operand_for_jitval(node.inputsAt(1)): + in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) + in0_id, in0_oper = self.get_tensor_operand_or_constant( + node.inputsAt(0), in1_oper.dim_order + ) + else: + raise Exception( # noqa: TRY002 + f"Can't do a NNAPI binary op: {opcode} on two constants" + ) # noqa: TRY002 + + assert in0_oper.op_type == in1_oper.op_type + in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( + in0_id, in0_oper, in1_id, in1_oper + ) + # NOTE: PyTorch and NNAPI have the same broadcast semantics. 
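The broadcast note above is the key assumption behind `_do_add_binary`; a couple of concrete cases of the `broadcast_shapes` helper defined earlier in this file (illustrative):

```python
# Equal-rank broadcasting only: each dimension pair must match or contain a 1.
assert broadcast_shapes((1, 3, 1, 1), (1, 3, 224, 224)) == (1, 3, 224, 224)
assert broadcast_shapes((4, 1), (1, 5)) == (4, 5)

# Rank mismatch is rejected rather than implicitly left-padded with 1s, because
# mixed NCHW/NHWC dim orders would make that ambiguous here:
# broadcast_shapes((3, 224, 224), (1, 3, 224, 224))  # raises Exception
```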
+ out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) + out_oper = in0_oper._replace(shape=out_shape) + if qparams is not None: + scale, zp = qparams + out_oper = out_oper._replace(scale=scale, zero_point=zp) + + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)): + if d0 == 1 and d1 == 0: + self.forward_operand_shape(out_id, idx, in1_id, idx) + elif d0 == 0 and d1 == 1: + self.forward_operand_shape(out_id, idx, in0_id, idx) + elif d0 == 0 and d1 == 0: + self.flexible_shape_computation_lines.append( + f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}" + ) + self.forward_operand_shape(out_id, idx, in0_id, idx) + + inputs = [None] * 3 + inputs[0] = in0_id + inputs[1] = in1_id + inputs[2] = self.add_immediate_int_scalar(fuse_code) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(opcode, inputs, outputs) + + def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code): + assert node.inputsSize() == 2 + self._do_add_binary(node, opcode, fuse_code) + + def add_add_sub_op(self, node, opcode, fuse_code): + assert node.inputsSize() == 3 + + _, alpha = self.get_constant_value(node.inputsAt(2), "IntType") + if alpha != 1: + raise Exception( # noqa: TRY002 + "NNAPI does not support add/sub with alpha." + ) # noqa: TRY002 + + self._do_add_binary(node, opcode, fuse_code) + + def add_qadd(self, node, opcode, fuse_code): + assert node.inputsSize() == 4 + + _, scale = self.get_constant_value(node.inputsAt(2), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType") + + self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point)) + + def add_softmax(self, node): + assert node.inputsSize() == 3 + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType") + + out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) + for dim, size in enumerate(in_oper.shape): + if size == 0: + self.forward_operand_shape(out_id, dim, in_id, dim) + + inputs = [None] * 3 + inputs[0] = in_id + inputs[1] = self.add_immediate_float_scalar( + 1.0 + ) # positive scaling factor of exponent, beta + inputs[2] = self.add_immediate_int_scalar(softmax_dim) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs) + + def add_hardtanh(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType") + _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType") + + op_map = { + (-1, 1): NNAPI_OperationCode.RELU1, + (0, 6): NNAPI_OperationCode.RELU6, # noqa: E201 + } + + opcode = op_map.get((min_val, max_val)) + if opcode is None: + raise Exception( # noqa: TRY002 + "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)." 
+ ) # noqa: TRY002 + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(opcode, inputs, outputs) + + def add_prelu_op(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1)) + assert len(w_oper.shape) == 1 + assert w_oper.shape[0] > 0 + if w_oper.shape[0] > 1: + if in_oper.use_nchw(): + # TODO: Support this by adding trailing 1 dims. + raise Exception( # noqa: TRY002 + "Per-channel PReLU only supports channels_last right now." + ) + + out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) + for dim, size in enumerate(in_oper.shape): + if size > 0: + pass + elif dim <= 1: + raise Exception( # noqa: TRY002 + "PReLU requires fixed size for dim 0 and dim 1." + ) # noqa: TRY002 + else: + self.forward_operand_shape(out_id, dim, in_id, dim) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = w_id + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs) + + def add_pool2d_node(self, node, opcode): + assert node.inputsSize() == 6 + assert node.outputsSize() == 1 + image, kernel, stride, padding, dilation, _ceil_mode = node.inputs() + + stride = stride or kernel + + # TODO: Validate ceil_mode semantics. + + args = self.get_conv_pool_args_2d_from_jit( + self.get_size_arg(kernel), stride, padding, dilation + ) + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("NNAPI does not support dilated pooling.") # noqa: TRY002 + + image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image) + assert len(image_oper.shape) == 4 + + out_shape = get_conv_pool_shape( + image_oper.shape, args, image_oper.shape[1], False + ) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(args.pad_l) + inputs[2] = self.add_immediate_int_scalar(args.pad_r) + inputs[3] = self.add_immediate_int_scalar(args.pad_t) + inputs[4] = self.add_immediate_int_scalar(args.pad_b) + inputs[5] = self.add_immediate_int_scalar(args.stride_w) + inputs[6] = self.add_immediate_int_scalar(args.stride_h) + inputs[7] = self.add_immediate_int_scalar(args.kernel_w) + inputs[8] = self.add_immediate_int_scalar(args.kernel_h) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + + self.add_operation(opcode, inputs, outputs) + + def add_avg_pool2d(self, node): + assert node.inputsSize() == 7 + assert node.outputsSize() == 1 + ( + image, + kernel, + stride, + padding, + _ceil_mode, + count_include_pad, + divisor_override, + ) = node.inputs() + + _, count_include_pad_value = self.get_constant_value(count_include_pad) + _, divisor_override_value = self.get_constant_value(divisor_override) + if not count_include_pad_value or divisor_override_value: + raise Exception( # noqa: TRY002 + "NNAPI doesn't support count_include_pad=False or divisor_override" + ) + + args = self.get_conv_pool_args_2d_from_jit( + self.get_size_arg(kernel), stride, padding + ) + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) 
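+        # Editorial note (not part of the upstream file): AVERAGE_POOL_2D is emitted below
+        # in its explicit-padding form, so inputs[1..4] carry the left/right/top/bottom
+        # padding, inputs[5..6] the stride (w, h), inputs[7..8] the kernel size (w, h),
+        # followed by the fuse code and the NCHW layout flag.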
+ assert len(image_oper.shape) == 4 + + out_shape = get_conv_pool_shape( + image_oper.shape, args, image_oper.shape[1], False + ) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(args.pad_l) + inputs[2] = self.add_immediate_int_scalar(args.pad_r) + inputs[3] = self.add_immediate_int_scalar(args.pad_t) + inputs[4] = self.add_immediate_int_scalar(args.pad_b) + inputs[5] = self.add_immediate_int_scalar(args.stride_w) + inputs[6] = self.add_immediate_int_scalar(args.stride_h) + inputs[7] = self.add_immediate_int_scalar(args.kernel_w) + inputs[8] = self.add_immediate_int_scalar(args.kernel_h) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + out_id = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + self._handle_conv_pool_flexible_input(out_id, image, args, False) + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) + + def add_adaptive_avg_pool2d(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size( + node.inputsAt(0) + ) + assert len(image_oper.shape) == 4 + + size_ctype, size_arg = self.get_constant_value(node.inputsAt(1)) + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + if size_arg != [1, 1]: + raise Exception( # noqa: TRY002 + "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)." + ) + + out_shape = image_oper.shape[0:2] + tuple(size_arg) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(0) + inputs[2] = self.add_immediate_int_scalar(0) + inputs[3] = self.add_immediate_int_scalar(0) + inputs[4] = self.add_immediate_int_scalar(0) + inputs[5] = self.add_immediate_int_scalar(1) + inputs[6] = self.add_immediate_int_scalar(1) + inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3]) + inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2]) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + + self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) + + def add_upsample_nearest2d(self, node): + assert node.inputsSize() == 3 or node.inputsSize() == 4 + assert node.outputsSize() == 1 + if node.inputsSize() == 3: + image, size_jit, scale_jit = node.inputs() + else: + image, size_jit, scale_h_jit, scale_w_jit = node.inputs() + size_ctype, size_arg = self.get_constant_value(size_jit) + + if node.inputsSize() == 3: + scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined] + else: + scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined] + scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined] + + # The only way for the 4-argument overload of upsample_nearest2d to + # have been added to the graph without error is if the scale_h and + # scale_w arguments are None + assert scale_h_ctype.kind() == "NoneType" + assert scale_w_ctype.kind() == "NoneType" + + scale_ctype = scale_h_ctype + scale_arg = scale_h_arg + + image_id, image_oper = 
self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType": + raise Exception("Size and scale cannot both be non-None.") # noqa: TRY002 + elif size_ctype.kind() != "NoneType": + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + assert scale_ctype.kind() == "NoneType" + assert scale_arg is None + assert isinstance(size_arg, list) + assert size_arg + assert all(isinstance(val, int) for val in size_arg) + if len(size_arg) == 1: + size_arg = size_arg * 2 + assert len(size_arg) == 2 + out_h = size_arg[0] + out_w = size_arg[1] + arg_h = self.add_immediate_int_scalar(out_h) + arg_w = self.add_immediate_int_scalar(out_w) + elif scale_ctype.kind() != "NoneType": + assert scale_ctype.kind() == "ListType" + assert scale_ctype.getElementType().kind() == "FloatType" + assert size_ctype.kind() == "NoneType" + assert size_arg is None + assert isinstance(scale_arg, list) + assert scale_arg + assert all(isinstance(val, float) for val in scale_arg) + if len(scale_arg) == 1: + scale_arg = scale_arg * 2 + assert len(scale_arg) == 2 + out_h = int(scale_arg[0] * image_oper.shape[2]) + out_w = int(scale_arg[1] * image_oper.shape[3]) + arg_h = self.add_immediate_float_scalar(scale_arg[0]) + arg_w = self.add_immediate_float_scalar(scale_arg[1]) + else: + raise Exception("Size and scale cannot both be None.") # noqa: TRY002 + + out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) + use_nchw = image_oper.use_nchw() + out_id = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + + if image_oper.shape[0] == 0 or image_oper.shape[1] == 0: + raise Exception("Flexible batch or channels not supported") # noqa: TRY002 + + # Handle variable input size + for dim in (2, 3): # h, w indices + if image_oper.shape[dim] == 0: + if size_ctype.kind() != "NoneType": + self.compute_operand_shape(out_id, dim, size_arg[dim - 2]) + elif scale_ctype.kind() != "NoneType": + self.compute_operand_shape( + out_id, + dim, + f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})", + ) + else: + raise Exception( # noqa: TRY002 + "Size and scale cannot both be None." + ) # noqa: TRY002 + + inputs = [None] * 4 + inputs[0] = image_id + inputs[1] = arg_w + inputs[2] = arg_h + inputs[3] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs) + + def add_addmm(self, node): + assert node.inputsSize() == 5 + assert node.outputsSize() == 1 + jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs() + + for jitval in (jit_beta, jit_alpha): + scale_ctype, scale_value = self.get_constant_value(jitval) + assert scale_ctype.kind() in ("IntType", "FloatType") + if scale_value != 1: + raise Exception( # noqa: TRY002 + "NNAPI Fully-Connected does not support alpha and beta." 
+ ) + + self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias) + + def add_linear(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + jit_input, jit_weight, jit_bias = node.inputs() + + self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias) + + def add_addmm_or_linear( + self, node, transpose_weight, jit_input, jit_weight, jit_bias + ): + input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input) + bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias) + + assert len(input_oper.shape) == 2 + assert len(bias_oper.shape) == 1 + + # TODO: Transform at load time to share weights with CPU model. + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + assert len(weight_tensor.shape) == 2 + if transpose_weight: + nnapi_weight_tensor = weight_tensor.t().contiguous() + else: + nnapi_weight_tensor = weight_tensor.contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + out_shape = (input_oper.shape[0], weight_oper.shape[0]) + out_id = self.add_tensor_operand( + node.outputsAt(0), input_oper._replace(shape=out_shape) + ) + + if input_oper.shape[0] == 0: + self.forward_operand_shape(out_id, 0, input_id, 0) + + inputs = [None] * 4 + inputs[0] = input_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) + + def add_qlinear(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + ( + jit_input, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) + # TODO: Support automatic reshape + assert len(input_oper.shape) == 2 + + _, out_scale = self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "LinearPackedParamsBase" + raw_weight, raw_bias = packed_weight.__getstate__()[0] + assert raw_bias is not None + + assert len(raw_weight.shape) == 2 + assert len(raw_bias.shape) == 1 + assert raw_bias.shape[0] == raw_weight.shape[0] + assert raw_weight.shape[1] == input_oper.shape[1] + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128, + ) + weight_scale = unsigned_weight.q_scale() + bias_scale = input_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = input_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( # noqa: TRY002 + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. " + ) + + # TODO: Transform at load time to share weights with CPU model. 
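+        # Editorial sketch (not part of the upstream file): the requantization multiplier
+        # checked above is input_scale * weight_scale / out_scale. For example, an input
+        # scale of 0.02, a weight scale of 0.05 and an output scale of 0.1 give
+        # 0.02 * 0.05 / 0.1 = 0.01, which satisfies the < 1 requirement that most NNAPI
+        # hardware backends impose.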
+ nnapi_weight_tensor = unsigned_weight.contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + out_shape = (input_oper.shape[0], weight_oper.shape[0]) + out_oper = input_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + + inputs = [None] * 4 + inputs[0] = input_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) + + def get_optional_bias(self, jit_bias, weight_tensor, transpose=False): + ctype, _value = self.get_constant_value(jit_bias) + if ctype.kind() == "NoneType": + bias_idx = 1 if transpose else 0 + nnapi_bias_tensor = torch.zeros( + weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype + ) + bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor) + bias_oper = self.operands[bias_id] + return bias_id, bias_oper + else: + return self.get_tensor_operand_for_weight(jit_bias) + + def add_conv2d(self, node): + assert node.inputsSize() == 7 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_groups, + ) = node.inputs() + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor) + args = self.get_conv_pool_args_2d_from_jit( + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups + ) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + False, # transpose + NNAPI_FuseCode.FUSED_NONE, + ) + + def add_conv_underscore(self, node): + assert node.inputsSize() == 13 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_transpose, + _, + jit_groups, + _, + _, + _, + _, + ) = node.inputs() + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + _, transpose = self.get_constant_value(jit_transpose) + bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose) + args = self.get_conv_pool_args_2d_from_jit( + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups + ) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + NNAPI_FuseCode.FUSED_NONE, + ) + + def add_log_softmax(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + jit_input, jit_dim, _jit_half_to_float = node.inputs() + input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) + _, dim = self.get_constant_value(jit_dim, "IntType") + + out_shape = input_oper.shape + + inputs = [None] * 3 + inputs[0] = input_id + # specifying 1 as the scaling factor for the exponent, beta + inputs[1] = self.add_immediate_float_scalar(1) + inputs[2] = self.add_immediate_int_scalar(dim) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), input_oper._replace(shape=out_shape) + ) + self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs) + + def add_qconv2d(self, node, fuse_code, transpose=False): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + _, out_scale = 
self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "Conv2dPackedParamsBase" + ( + pack_version, + tensors, + opt_tensors, + ) = packed_weight.__getstate__()[0] + assert pack_version == "2" + packed_config, raw_weight = tensors + (raw_bias,) = opt_tensors + assert raw_bias is not None + args = self.get_conv_pool_args_2d_from_pack( + raw_weight.shape[2:4], packed_config + ) + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128, + ) + weight_scale = unsigned_weight.q_scale() + _, image_oper = self.get_tensor_operand_by_jitval(jit_image) + bias_scale = image_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = image_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( # noqa: TRY002 + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. " + ) + + return self.add_conv2d_common( + node.outputsAt(0), + out_scale, + out_zero_point, + jit_image, + unsigned_weight, + bias_id, + args, + transpose, + fuse_code, + ) + + def add_conv2d_common( + self, + jit_out, + out_scale, + out_zero_point, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + fuse_code, + ): + image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) + in_c = image_oper.shape[1] + + if args.group == 1: + # Full convolution + depthwise = False + if transpose: + weight_permutation = (1, 2, 3, 0) + else: + weight_permutation = (0, 2, 3, 1) + elif args.group == in_c: + # Depthwise convolution + depthwise = True + weight_permutation = (1, 2, 3, 0) + else: + raise Exception("Group convolution not supported yet.") # noqa: TRY002 + + # TODO: Transform at load time to share weights with CPU model. 
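+        # Editorial sketch (not part of the upstream file): PyTorch stores conv weights as
+        # (out_c, in_c, kH, kW). The permutation chosen above maps a full, non-transposed
+        # convolution to NNAPI's (out_c, kH, kW, in_c) layout, e.g. an (8, 3, 5, 5) weight
+        # permuted with (0, 2, 3, 1) becomes (8, 5, 5, 3); depthwise and transposed cases
+        # use (1, 2, 3, 0) instead.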
+ nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + bias_oper = self.operands[bias_id] + + if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32 + assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale) + assert bias_oper.zero_point == 0 + else: + raise Exception( # noqa: TRY002 + f"Unsupported input type for conv2d: {image_oper.op_type}" + ) # noqa: TRY002 + + assert len(image_oper.shape) == 4 + assert len(weight_oper.shape) == 4 + assert len(bias_oper.shape) == 1 + + if depthwise: + # Depthwise convolution + one, _kern_h, _kern_w, out_c = weight_oper.shape + assert one == 1 + assert out_c % in_c == 0 + channel_multiplier = out_c // in_c + assert channel_multiplier == 1 # Don't support multiplier + assert out_c == in_c + else: + # Full convolution + out_c, _kern_h, _kern_w, kern_d = weight_oper.shape + assert kern_d == in_c + + assert out_c == bias_oper.shape[0] + + use_nchw = image_oper.use_nchw() + + if depthwise: + num_args = 12 + opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D + else: + num_args = 11 + if transpose: + opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D + else: + opcode = NNAPI_OperationCode.CONV_2D + + inputs = [None] * num_args + inputs[0] = image_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(args.pad_l) + inputs[4] = self.add_immediate_int_scalar(args.pad_r) + inputs[5] = self.add_immediate_int_scalar(args.pad_t) + inputs[6] = self.add_immediate_int_scalar(args.pad_b) + inputs[7] = self.add_immediate_int_scalar(args.stride_w) + inputs[8] = self.add_immediate_int_scalar(args.stride_h) + if depthwise: + inputs[9] = self.add_immediate_int_scalar(1) + inputs[10] = self.add_immediate_int_scalar(fuse_code) + inputs[11] = self.add_immediate_bool_scalar(use_nchw) + else: + inputs[9] = self.add_immediate_int_scalar(fuse_code) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose) + out_oper = image_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + out_id = self.add_tensor_operand(jit_out, out_oper) + self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose) + + outputs[0] = out_id + self.add_operation(opcode, inputs, outputs) + + def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose): + image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) + batch, in_ch, in_h, in_w = image_oper.shape + + if batch == 0: + self.forward_operand_shape(out_id, 0, image_id, 0) + if in_ch == 0: + raise Exception("Input channels can't be flexible") # noqa: TRY002 + # H & W + if transpose: + if in_h == 0: + self.compute_operand_shape( + out_id, + 2, + f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}", + ) + if in_w == 0: + self.compute_operand_shape( + out_id, + 3, + f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}", + ) + else: + if in_h == 0: + self.compute_operand_shape( + out_id, + 
2, + f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1", + ) + if in_w == 0: + self.compute_operand_shape( + out_id, + 3, + f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1", + ) + + +def serialize_model( + module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False +): + """Convert to NNAPI and serialize torchscript module. + + Parameters: + module: Torchscript module to convert + inputs: Tensors used to specify input details for NNAPI + config (optional): Optional config to attach to module + return_shapes (optional): Specify shape of outputs if + your module uses runtime flexible shapes to set output + buffer size for NNAPI + use_int16_for_qint16 (optional): Use Pytorch int16 to represent NNAPI qint16 values + """ + return _NnapiSerializer(config, use_int16_for_qint16).serialize_model( + module, inputs, return_shapes + ) diff --git a/phivenv/Lib/site-packages/torch/backends/cpu/__init__.py b/phivenv/Lib/site-packages/torch/backends/cpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fedf31830caffed0b2b1e16de42598e1d3e0f3bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/cpu/__init__.py @@ -0,0 +1,21 @@ +import torch + + +__all__ = [ + "get_cpu_capability", +] + + +def get_cpu_capability() -> str: + r"""Return cpu capability as a string value. + + Possible values: + - "DEFAULT" + - "VSX" + - "Z VECTOR" + - "NO AVX" + - "AVX2" + - "AVX512" + - "SVE256" + """ + return torch._C._get_cpu_capability() diff --git a/phivenv/Lib/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01bf0ffbebe9696457e7197070729f6cc0668275 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/cuda/__init__.py b/phivenv/Lib/site-packages/torch/backends/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..499e86fa794f5ba8db75410b1c2a05d78353aa8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/cuda/__init__.py @@ -0,0 +1,536 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Union +from typing_extensions import deprecated + +import torch + + +__all__ = [ + "is_built", + "cuFFTPlanCacheAttrContextProp", + "cuFFTPlanCache", + "cuFFTPlanCacheManager", + "cuBLASModule", + "preferred_linalg_library", + "preferred_blas_library", + "preferred_rocm_fa_library", + "cufft_plan_cache", + "matmul", + "SDPAParams", + "enable_cudnn_sdp", + "cudnn_sdp_enabled", + "enable_flash_sdp", + "flash_sdp_enabled", + "enable_mem_efficient_sdp", + "mem_efficient_sdp_enabled", + "math_sdp_enabled", + "enable_math_sdp", + "allow_fp16_bf16_reduction_math_sdp", + "fp16_bf16_reduction_math_sdp_allowed", + "is_flash_attention_available", + "can_use_flash_attention", + "can_use_efficient_attention", + "can_use_cudnn_attention", + "sdp_kernel", +] + + +def is_built(): + r""" + Return whether PyTorch is built with CUDA support. + + Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch + binary were run on a machine with working CUDA drivers and devices, we would be able to use it. 
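+
+    A minimal illustrative check (editorial sketch, not from the upstream docstring)::
+
+        if torch.backends.cuda.is_built() and torch.cuda.is_available():
+            device = torch.device("cuda")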
+ """ + return torch._C._has_cuda + + +class cuFFTPlanCacheAttrContextProp: + # Like regular ContextProp, but uses the `.device_index` attribute from the + # calling object as the first argument to the getter and setter. + def __init__(self, getter, setter): + self.getter = getter + self.setter = setter + + def __get__(self, obj, objtype): + return self.getter(obj.device_index) + + def __set__(self, obj, val): + if isinstance(self.setter, str): + raise RuntimeError(self.setter) + self.setter(obj.device_index, val) + + +class cuFFTPlanCache: + r""" + Represent a specific plan cache for a specific `device_index`. + + The attributes `size` and `max_size`, and method `clear`, can fetch and/ or + change properties of the C++ cuFFT plan cache. + """ + + def __init__(self, device_index): + self.device_index = device_index + + size = cuFFTPlanCacheAttrContextProp( + torch._cufft_get_plan_cache_size, + ".size is a read-only property showing the number of plans currently in the " + "cache. To change the cache capacity, set cufft_plan_cache.max_size.", + ) + + max_size = cuFFTPlanCacheAttrContextProp( + torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size + ) + + def clear(self): + return torch._cufft_clear_plan_cache(self.device_index) + + +class cuFFTPlanCacheManager: + r""" + Represent all cuFFT plan caches, return the cuFFTPlanCache for a given device when indexed. + + Finally, this object, when used directly as a `cuFFTPlanCache` object (e.g., + setting the `.max_size`) attribute, the current device's cuFFT plan cache is + used. + """ + + __initialized = False + + def __init__(self): + self.caches = [] + self.__initialized = True + + def __getitem__(self, device): + index = torch.cuda._utils._get_device_index(device) + if index < 0 or index >= torch.cuda.device_count(): + raise RuntimeError( + f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got " + f"device with index {index}" + ) + if len(self.caches) == 0: + self.caches.extend( + cuFFTPlanCache(index) for index in range(torch.cuda.device_count()) + ) + return self.caches[index] + + def __getattr__(self, name): + return getattr(self[torch.cuda.current_device()], name) + + def __setattr__(self, name, value): + if self.__initialized: + return setattr(self[torch.cuda.current_device()], name, value) + else: + return super().__setattr__(name, value) + + +class cuBLASModule: + def __getattr__(self, name): + if name == "allow_tf32": + return torch._C._get_cublas_allow_tf32() + elif name == "allow_fp16_reduced_precision_reduction": + return torch._C._get_cublas_allow_fp16_reduced_precision_reduction() + elif name == "allow_bf16_reduced_precision_reduction": + return torch._C._get_cublas_allow_bf16_reduced_precision_reduction() + elif name == "allow_fp16_accumulation": + return torch._C._get_cublas_allow_fp16_accumulation() + raise AttributeError("Unknown attribute " + name) + + def __setattr__(self, name, value): + if name == "allow_tf32": + return torch._C._set_cublas_allow_tf32(value) + elif name == "allow_fp16_reduced_precision_reduction": + return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value) + elif name == "allow_bf16_reduced_precision_reduction": + return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value) + elif name == "allow_fp16_accumulation": + return torch._C._set_cublas_allow_fp16_accumulation(value) + raise AttributeError("Unknown attribute " + name) + + +_LinalgBackends = { + "default": torch._C._LinalgBackend.Default, + "cusolver": 
torch._C._LinalgBackend.Cusolver, + "magma": torch._C._LinalgBackend.Magma, +} +_LinalgBackends_str = ", ".join(_LinalgBackends.keys()) + + +def preferred_linalg_library( + backend: Union[None, str, torch._C._LinalgBackend] = None +) -> torch._C._LinalgBackend: + r""" + Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations. + + .. warning:: This flag is experimental and subject to change. + + When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, + and if both are available it decides which to use with a heuristic. + This flag (a :class:`str`) allows overriding those heuristics. + + * If `"cusolver"` is set then cuSOLVER will be used wherever possible. + * If `"magma"` is set then MAGMA will be used wherever possible. + * If `"default"` (the default) is set then heuristics will be used to pick between + cuSOLVER and MAGMA if both are available. + * When no input is given, this function returns the currently preferred library. + * User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER + globally. + This flag only sets the initial value of the preferred library and the preferred library + may still be overridden by this function call later in your script. + + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's heuristic library selection is incorrect + for your application's inputs. + + Currently supported linalg operators: + + * :func:`torch.linalg.inv` + * :func:`torch.linalg.inv_ex` + * :func:`torch.linalg.cholesky` + * :func:`torch.linalg.cholesky_ex` + * :func:`torch.cholesky_solve` + * :func:`torch.cholesky_inverse` + * :func:`torch.linalg.lu_factor` + * :func:`torch.linalg.lu` + * :func:`torch.linalg.lu_solve` + * :func:`torch.linalg.qr` + * :func:`torch.linalg.eigh` + * :func:`torch.linalg.eighvals` + * :func:`torch.linalg.svd` + * :func:`torch.linalg.svdvals` + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _LinalgBackends: + raise RuntimeError( + "Unknown input value. " f"Choose from: {_LinalgBackends_str}." + ) + torch._C._set_linalg_preferred_backend(_LinalgBackends[backend]) + elif isinstance(backend, torch._C._LinalgBackend): + torch._C._set_linalg_preferred_backend(backend) + else: + raise RuntimeError("Unknown input value type.") + + return torch._C._get_linalg_preferred_backend() + + +_BlasBackends = { + "default": torch._C._BlasBackend.Default, + "cublas": torch._C._BlasBackend.Cublas, + "hipblas": torch._C._BlasBackend.Cublas, # alias + "cublaslt": torch._C._BlasBackend.Cublaslt, + "hipblaslt": torch._C._BlasBackend.Cublaslt, # alias + "ck": torch._C._BlasBackend.Ck, +} +_BlasBackends_str = ", ".join(_BlasBackends.keys()) + + +def preferred_blas_library( + backend: Union[None, str, torch._C._BlasBackend] = None +) -> torch._C._BlasBackend: + r""" + Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. + + .. warning:: This flag is experimental and subject to change. + + When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available. + For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance. + This flag (a :class:`str`) allows overriding which BLAS library to use. 
+ + * If `"cublas"` is set then cuBLAS will be used wherever possible. + * If `"cublaslt"` is set then cuBLASLt will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. + * If `"default"` (the default) is set then heuristics will be used to pick between the other options. + * When no input is given, this function returns the currently preferred library. + * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt + globally. + This flag only sets the initial value of the preferred library and the preferred library + may still be overridden by this function call later in your script. + + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's library selection is incorrect + for your application's inputs. + + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _BlasBackends: + raise RuntimeError( + "Unknown input value. " f"Choose from: {_BlasBackends_str}." + ) + torch._C._set_blas_preferred_backend(_BlasBackends[backend]) + elif isinstance(backend, torch._C._BlasBackend): + torch._C._set_blas_preferred_backend(backend) + else: + raise RuntimeError("Unknown input value type.") + + return torch._C._get_blas_preferred_backend() + + +_ROCmFABackends = { + "default": torch._C._ROCmFABackend.Default, + "aotriton": torch._C._ROCmFABackend.AOTriton, + "ck": torch._C._ROCmFABackend.Ck, +} +_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys()) + + +from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend + + +def preferred_rocm_fa_library( + backend: Union[None, str, torch._C._ROCmFABackend] = None +) -> torch._C._ROCmFABackend: + r""" + [ROCm-only] + Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK. + + .. warning:: This flag is experimental and subject to change. + + When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend. + This flag (a :class:`str`) allows users to override this backend to use composable_kernel. + + * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton. + * If `"aotriton"` is set then AOTriton will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. + * When no input is given, this function returns the currently preferred library. + * User may use the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK + globally. + + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's library selection is incorrect + for your application's inputs. + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _ROCmFABackends: + raise RuntimeError( + "Unknown input value. " f"Choose from: {_ROCmFABackends_str}." + ) + torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend]) + elif isinstance(backend, torch._C._ROCmFABackend): + torch._C._set_rocm_fa_preferred_backend(backend) + else: + raise ValueError("Unknown input value. 
" f"Choose from: {_ROCmFABackends_str}.") + + return torch._C._get_rocm_fa_preferred_backend() + + +# Set the __module__ attribute +SDPAParams.__module__ = "torch.backends.cuda" +SDPAParams.__name__ = "SDPAParams" + + +def flash_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether flash scaled dot product attention is enabled or not. + """ + return torch._C._get_flash_sdp_enabled() + + +def enable_flash_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables flash scaled dot product attention. + """ + torch._C._set_sdp_use_flash(enabled) + + +def mem_efficient_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether memory efficient scaled dot product attention is enabled or not. + """ + return torch._C._get_mem_efficient_sdp_enabled() + + +def enable_mem_efficient_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables memory efficient scaled dot product attention. + """ + torch._C._set_sdp_use_mem_efficient(enabled) + + +def math_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether math scaled dot product attention is enabled or not. + """ + return torch._C._get_math_sdp_enabled() + + +def enable_math_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables math scaled dot product attention. + """ + torch._C._set_sdp_use_math(enabled) + + +def allow_fp16_bf16_reduction_math_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables fp16/bf16 reduction in math scaled dot product attention. + """ + torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled) + + +def fp16_bf16_reduction_math_sdp_allowed(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not. + """ + return torch._C._get_math_sdp_allow_fp16_bf16_reduction() + + +def is_flash_attention_available() -> bool: + r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention. + + Returns: + True if FlashAttention is built and available; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._is_flash_attention_available() + + +def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if FlashAttention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn debug information as to why FlashAttention could not be run. + Defaults to False. + + Returns: + True if FlashAttention can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_flash_attention(params, debug) + + +def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if efficient_attention can be utilized in scaled_dot_product_attention. 
+ + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn with information as to why efficient_attention could not be run. + Defaults to False. + + Returns: + True if efficient_attention can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_mem_efficient_attention(params, debug) + + +def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn with information as to why cuDNN attention could not be run. + Defaults to False. + + Returns: + True if cuDNN can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_cudnn_attention(params, debug) + + +def cudnn_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether cuDNN scaled dot product attention is enabled or not. + """ + return torch._C._get_cudnn_sdp_enabled() + + +def enable_cudnn_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables cuDNN scaled dot product attention. + """ + torch._C._set_sdp_use_cudnn(enabled) + + +@contextlib.contextmanager +@deprecated( + ( + "`torch.backends.cuda.sdp_kernel()` is deprecated. " + "In the future, this context manager will be removed. " + "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, " + "with updated signature." + ), + category=FutureWarning, +) +def sdp_kernel( + enable_flash: bool = True, + enable_math: bool = True, + enable_mem_efficient: bool = True, + enable_cudnn: bool = True, +): + r""" + .. warning:: This flag is beta and subject to change. + + This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention. + Upon exiting the context manager, the previous state of the flags will be restored. 
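+
+    Illustrative usage (editorial sketch; ``q``, ``k`` and ``v`` are assumed CUDA tensors)::
+
+        with torch.backends.cuda.sdp_kernel(enable_flash=False,
+                                            enable_mem_efficient=False,
+                                            enable_cudnn=False):
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)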
+ """ + from torch.nn.attention import sdpa_kernel + + backend_list = [] + if enable_flash: + backend_list.append(SDPBackend.FLASH_ATTENTION) + if enable_mem_efficient: + backend_list.append(SDPBackend.EFFICIENT_ATTENTION) + if enable_math: + backend_list.append(SDPBackend.MATH) + if enable_cudnn: + backend_list.append(SDPBackend.CUDNN_ATTENTION) + + with sdpa_kernel(backend_list) as context: + try: + yield context + finally: + pass + + +cufft_plan_cache = cuFFTPlanCacheManager() +matmul = cuBLASModule() diff --git a/phivenv/Lib/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d99e6a027a71bc3ae52c8210ee6c39c2bad597d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/cudnn/__init__.py b/phivenv/Lib/site-packages/torch/backends/cudnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d55e2828acca5be3c4eaad1dfb2cd333975b1016 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/cudnn/__init__.py @@ -0,0 +1,208 @@ +# mypy: allow-untyped-defs +import os +import sys +import warnings +from contextlib import contextmanager +from typing import Optional + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +try: + from torch._C import _cudnn +except ImportError: + _cudnn = None # type: ignore[assignment] + +# Write: +# +# torch.backends.cudnn.enabled = False +# +# to globally disable CuDNN/MIOpen + +__cudnn_version: Optional[int] = None + +if _cudnn is not None: + + def _init(): + global __cudnn_version + if __cudnn_version is None: + __cudnn_version = _cudnn.getVersionInt() + runtime_version = _cudnn.getRuntimeVersion() + compile_version = _cudnn.getCompileVersion() + runtime_major, runtime_minor, _ = runtime_version + compile_major, compile_minor, _ = compile_version + # Different major versions are always incompatible + # Starting with cuDNN 7, minor versions are backwards-compatible + # Not sure about MIOpen (ROCm), so always do a strict check + if runtime_major != compile_major: + cudnn_compatible = False + elif runtime_major < 7 or not _cudnn.is_cuda: + cudnn_compatible = runtime_minor == compile_minor + else: + cudnn_compatible = runtime_minor >= compile_minor + if not cudnn_compatible: + if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1": + return True + base_error_msg = ( + f"cuDNN version incompatibility: " + f"PyTorch was compiled against {compile_version} " + f"but found runtime version {runtime_version}. " + f"PyTorch already comes bundled with cuDNN. " + f"One option to resolving this error is to ensure PyTorch " + f"can find the bundled cuDNN. " + ) + + if "LD_LIBRARY_PATH" in os.environ: + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if any( + substring in ld_library_path for substring in ["cuda", "cudnn"] + ): + raise RuntimeError( + f"{base_error_msg}" + f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. " + f"Please either remove it from the path or install cudnn {compile_version}" + ) + else: + raise RuntimeError( + f"{base_error_msg}" + f"one possibility is that there is a " + f"conflicting cuDNN in LD_LIBRARY_PATH." 
+ ) + else: + raise RuntimeError(base_error_msg) + + return True + +else: + + def _init(): + return False + + +def version(): + """Return the version of cuDNN.""" + if not _init(): + return None + return __cudnn_version + + +CUDNN_TENSOR_DTYPES = { + torch.half, + torch.float, + torch.double, +} + + +def is_available(): + r"""Return a bool indicating if CUDNN is currently available.""" + return torch._C._has_cudnn + + +def is_acceptable(tensor): + if not torch._C._get_cudnn_enabled(): + return False + if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES: + return False + if not is_available(): + warnings.warn( + "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild " + "PyTorch making sure the library is visible to the build system." + ) + return False + if not _init(): + warnings.warn( + "cuDNN/MIOpen library not found. Check your {libpath}".format( + libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get( + sys.platform, "LD_LIBRARY_PATH" + ) + ) + ) + return False + return True + + +def set_flags( + _enabled=None, + _benchmark=None, + _benchmark_limit=None, + _deterministic=None, + _allow_tf32=None, +): + orig_flags = ( + torch._C._get_cudnn_enabled(), + torch._C._get_cudnn_benchmark(), + None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), + torch._C._get_cudnn_deterministic(), + torch._C._get_cudnn_allow_tf32(), + ) + if _enabled is not None: + torch._C._set_cudnn_enabled(_enabled) + if _benchmark is not None: + torch._C._set_cudnn_benchmark(_benchmark) + if _benchmark_limit is not None and is_available(): + torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit) + if _deterministic is not None: + torch._C._set_cudnn_deterministic(_deterministic) + if _allow_tf32 is not None: + torch._C._set_cudnn_allow_tf32(_allow_tf32) + return orig_flags + + +@contextmanager +def flags( + enabled=False, + benchmark=False, + benchmark_limit=10, + deterministic=False, + allow_tf32=True, +): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags( + enabled, benchmark, benchmark_limit, deterministic, allow_tf32 + ) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +# The magic here is to allow us to intercept code like this: +# +# torch.backends..enabled = True + + +class CudnnModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + + enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) + deterministic = ContextProp( + torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic + ) + benchmark = ContextProp( + torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark + ) + benchmark_limit = None + if is_available(): + benchmark_limit = ContextProp( + torch._C._cuda_get_cudnn_benchmark_limit, + torch._C._cuda_set_cudnn_benchmark_limit, + ) + allow_tf32 = ContextProp( + torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32 + ) + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__) + +# Add type annotation for the replaced module +enabled: bool +deterministic: bool +benchmark: bool +allow_tf32: bool +benchmark_limit: int diff --git a/phivenv/Lib/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..31b9c0edc479b73cfec514e55eff840b9aa154cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eacc514b282ad2a0d78648a2abfc64f2d6ea1a84 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/cudnn/rnn.py b/phivenv/Lib/site-packages/torch/backends/cudnn/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..d78538f0cec0d76eee4cd4180b2812193f0908dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/cudnn/rnn.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +import torch.cuda + + +try: + from torch._C import _cudnn +except ImportError: + # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(), + # so it's safe to not emit any checks here. + _cudnn = None # type: ignore[assignment] + + +def get_cudnn_mode(mode): + if mode == "RNN_RELU": + return int(_cudnn.RNNMode.rnn_relu) + elif mode == "RNN_TANH": + return int(_cudnn.RNNMode.rnn_tanh) + elif mode == "LSTM": + return int(_cudnn.RNNMode.lstm) + elif mode == "GRU": + return int(_cudnn.RNNMode.gru) + else: + raise Exception(f"Unknown mode: {mode}") # noqa: TRY002 + + +# NB: We don't actually need this class anymore (in fact, we could serialize the +# dropout state for even better reproducibility), but it is kept for backwards +# compatibility for old models. +class Unserializable: + def __init__(self, inner): + self.inner = inner + + def get(self): + return self.inner + + def __getstate__(self): + # Note: can't return {}, because python2 won't call __setstate__ + # if the value evaluates to False + return "" + + def __setstate__(self, state): + self.inner = None + + +def init_dropout_state(dropout, train, dropout_seed, dropout_state): + dropout_desc_name = "desc_" + str(torch.cuda.current_device()) + dropout_p = dropout if train else 0 + if (dropout_desc_name not in dropout_state) or ( + dropout_state[dropout_desc_name].get() is None + ): + if dropout_p == 0: + dropout_state[dropout_desc_name] = Unserializable(None) + else: + dropout_state[dropout_desc_name] = Unserializable( + torch._cudnn_init_dropout_state( # type: ignore[call-arg] + dropout_p, + train, + dropout_seed, + self_ty=torch.uint8, + device=torch.device("cuda"), + ) + ) + dropout_ts = dropout_state[dropout_desc_name].get() + return dropout_ts diff --git a/phivenv/Lib/site-packages/torch/backends/cusparselt/__init__.py b/phivenv/Lib/site-packages/torch/backends/cusparselt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4bd2d436a2c2c64dbc8b81d6e7d9bc53002931 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/cusparselt/__init__.py @@ -0,0 +1,57 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch + + +__all__ = [ + "version", + "is_available", + "get_max_alg_id", +] + +try: + from torch._C import _cusparselt +except ImportError: + _cusparselt = None # type: ignore[assignment] + +__cusparselt_version: Optional[int] = None +__MAX_ALG_ID: Optional[int] = None + +if _cusparselt is not None: + + def _init(): + global __cusparselt_version + global __MAX_ALG_ID + if __cusparselt_version is None: + 
__cusparselt_version = _cusparselt.getVersionInt() + if __cusparselt_version == 400: + __MAX_ALG_ID = 4 + elif __cusparselt_version == 502: + __MAX_ALG_ID = 5 + elif __cusparselt_version == 602: + __MAX_ALG_ID = 37 + return True + +else: + + def _init(): + return False + + +def version() -> Optional[int]: + """Return the version of cuSPARSELt""" + if not _init(): + return None + return __cusparselt_version + + +def is_available() -> bool: + r"""Return a bool indicating if cuSPARSELt is currently available.""" + return torch._C._has_cusparselt + + +def get_max_alg_id() -> Optional[int]: + if not _init(): + return None + return __MAX_ALG_ID diff --git a/phivenv/Lib/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e73f984adc9548507b26c85a12006be8b5703d66 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/kleidiai/__init__.py b/phivenv/Lib/site-packages/torch/backends/kleidiai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba82b8f269eafc30d8964b68e84e28cb957523f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/kleidiai/__init__.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with KleidiAI support.""" + return torch._C._has_kleidiai diff --git a/phivenv/Lib/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a426e1c7208ae60cad939d425db824d0bcb6a5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/mha/__init__.py b/phivenv/Lib/site-packages/torch/backends/mha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54693e7dd54720c74a1157cc413bbcd6beab6592 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/mha/__init__.py @@ -0,0 +1,25 @@ +# Config options to enable/disable C++ kernel for nn.functional.MHA +# and nn.TransformerEncoder +import torch + + +_is_fastpath_enabled: bool = True + + +def get_fastpath_enabled() -> bool: + """Returns whether fast path for TransformerEncoder and MultiHeadAttention + is enabled, or ``True`` if jit is scripting. + + .. note:: + The fastpath might not be run even if ``get_fastpath_enabled`` returns + ``True`` unless all conditions on inputs are met. 
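+
+    Minimal illustrative usage (editorial sketch, not from the upstream docstring)::
+
+        torch.backends.mha.set_fastpath_enabled(False)  # force the generic path
+        assert torch.backends.mha.get_fastpath_enabled() is False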
+ """ + if not torch.jit.is_scripting(): + return _is_fastpath_enabled + return True + + +def set_fastpath_enabled(value: bool) -> None: + """Sets whether fast path is enabled""" + global _is_fastpath_enabled + _is_fastpath_enabled = value diff --git a/phivenv/Lib/site-packages/torch/backends/mha/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/mha/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d177f044743082dccfdb3d88c03bfcab063a72a7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/mha/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/mkl/__init__.py b/phivenv/Lib/site-packages/torch/backends/mkl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76316ec38e3e169ea5a734c08b6d87d499f74798 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/mkl/__init__.py @@ -0,0 +1,57 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with MKL support.""" + return torch._C.has_mkl + + +VERBOSE_OFF = 0 +VERBOSE_ON = 1 + + +class verbose: + """ + On-demand oneMKL verbosing functionality. + + To make it easier to debug performance issues, oneMKL can dump verbose + messages containing execution information like duration while executing + the kernel. The verbosing functionality can be invoked via an environment + variable named `MKL_VERBOSE`. However, this methodology dumps messages in + all steps. Those are a large amount of verbose messages. Moreover, for + investigating the performance issues, generally taking verbose messages + for one single iteration is enough. This on-demand verbosing functionality + makes it possible to control scope for verbose message dumping. In the + following example, verbose messages will be dumped out for the second + inference only. + + .. highlight:: python + .. code-block:: python + + import torch + model(data) + with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON): + model(data) + + Args: + level: Verbose level + - ``VERBOSE_OFF``: Disable verbosing + - ``VERBOSE_ON``: Enable verbosing + """ + + def __init__(self, enable): + self.enable = enable + + def __enter__(self): + if self.enable == VERBOSE_OFF: + return + st = torch._C._verbose.mkl_set_verbose(self.enable) + assert ( + st + ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." 
+ return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch._C._verbose.mkl_set_verbose(VERBOSE_OFF) + return False diff --git a/phivenv/Lib/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34842c76cd773f16c164f53cb1d9935eec810fa6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/mkldnn/__init__.py b/phivenv/Lib/site-packages/torch/backends/mkldnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1092fce004af3958aea272d4aad3b6ac31bb877 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/mkldnn/__init__.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +import sys +from contextlib import contextmanager +from typing import TYPE_CHECKING + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +def is_available(): + r"""Return whether PyTorch is built with MKL-DNN support.""" + return torch._C._has_mkldnn + + +VERBOSE_OFF = 0 +VERBOSE_ON = 1 +VERBOSE_ON_CREATION = 2 + + +class verbose: + """ + On-demand oneDNN (former MKL-DNN) verbosing functionality. + + To make it easier to debug performance issues, oneDNN can dump verbose + messages containing information like kernel size, input data size and + execution duration while executing the kernel. The verbosing functionality + can be invoked via an environment variable named `DNNL_VERBOSE`. However, + this methodology dumps messages in all steps. Those are a large amount of + verbose messages. Moreover, for investigating the performance issues, + generally taking verbose messages for one single iteration is enough. + This on-demand verbosing functionality makes it possible to control scope + for verbose message dumping. In the following example, verbose messages + will be dumped out for the second inference only. + + .. highlight:: python + .. code-block:: python + + import torch + model(data) + with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON): + model(data) + + Args: + level: Verbose level + - ``VERBOSE_OFF``: Disable verbosing + - ``VERBOSE_ON``: Enable verbosing + - ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation + """ + + def __init__(self, level): + self.level = level + + def __enter__(self): + if self.level == VERBOSE_OFF: + return + st = torch._C._verbose.mkldnn_set_verbose(self.level) + assert ( + st + ), "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." 
+ return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF) + return False + + +def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None): + orig_flags = ( + torch._C._get_mkldnn_enabled(), + torch._C._get_mkldnn_deterministic(), + torch._C._get_onednn_allow_tf32(), + ) + if _enabled is not None: + torch._C._set_mkldnn_enabled(_enabled) + if _deterministic is not None: + torch._C._set_mkldnn_deterministic(_deterministic) + if _allow_tf32 is not None: + torch._C._set_onednn_allow_tf32(_allow_tf32) + return orig_flags + + +@contextmanager +def flags(enabled=False, deterministic=False, allow_tf32=True): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled, deterministic, allow_tf32) + try: + yield + finally: + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +class MkldnnModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + + def is_available(self): + return is_available() + + enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled) + deterministic = ContextProp( + torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic + ) + allow_tf32 = ContextProp( + torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32 + ) + + +if TYPE_CHECKING: + enabled: ContextProp + deterministic: ContextProp + allow_tf32: ContextProp + +sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__) diff --git a/phivenv/Lib/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aae2a1cef7e0039b32c369f3d430732c5dcc00e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/mps/__init__.py b/phivenv/Lib/site-packages/torch/backends/mps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93d0e4525f69092209afb60ad1b511e5fbc590f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/mps/__init__.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs +from functools import lru_cache as _lru_cache +from typing import Optional, TYPE_CHECKING + +import torch +from torch.library import Library as _Library + + +__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"] + + +def is_built() -> bool: + r"""Return whether PyTorch is built with MPS support. + + Note that this doesn't necessarily mean MPS is available; just that + if this PyTorch binary were run a machine with working MPS drivers + and devices, we would be able to use it. 
+ """ + return torch._C._has_mps + + +@_lru_cache +def is_available() -> bool: + r"""Return a bool indicating if MPS is currently available.""" + return torch._C._mps_is_available() + + +@_lru_cache +def is_macos_or_newer(major: int, minor: int) -> bool: + r"""Return a bool indicating whether MPS is running on given MacOS or newer.""" + return torch._C._mps_is_on_macos_or_newer(major, minor) + + +@_lru_cache +def is_macos13_or_newer(minor: int = 0) -> bool: + r"""Return a bool indicating whether MPS is running on MacOS 13 or newer.""" + return torch._C._mps_is_on_macos_or_newer(13, minor) + + +_lib: Optional[_Library] = None + + +def _init(): + r"""Register prims as implementation of var_mean and group_norm.""" + global _lib + + if _lib is not None or not is_built(): + return + + from torch._decomp.decompositions import native_group_norm_backward + from torch._refs import native_group_norm + + _lib = _Library("aten", "IMPL") # noqa: TOR901 + _lib.impl("native_group_norm", native_group_norm, "MPS") + _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS") diff --git a/phivenv/Lib/site-packages/torch/backends/mps/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/mps/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddb805da0614f76b3945c6dad59a5bd4fdf34f9e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/mps/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/nnpack/__init__.py b/phivenv/Lib/site-packages/torch/backends/nnpack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9f01a16cd4d873100e3a8d234d9e2752305a8e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/nnpack/__init__.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +__all__ = ["is_available", "flags", "set_flags"] + + +def is_available(): + r"""Return whether PyTorch is built with NNPACK support.""" + return torch._nnpack_available() + + +def set_flags(_enabled): + r"""Set if nnpack is enabled globally""" + orig_flags = (torch._C._get_nnpack_enabled(),) + torch._C._set_nnpack_enabled(_enabled) + return orig_flags + + +@contextmanager +def flags(enabled=False): + r"""Context manager for setting if nnpack is enabled globally""" + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled) + try: + yield + finally: + with __allow_nonbracketed_mutation(): + set_flags(orig_flags[0]) diff --git a/phivenv/Lib/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8abf92bbf3f83ca2353aa7040cbf06fe51e97726 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/openmp/__init__.py b/phivenv/Lib/site-packages/torch/backends/openmp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62b181dff7241084824598098f9430b243f99ad5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/openmp/__init__.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with OpenMP support.""" + return torch._C.has_openmp 
diff --git a/phivenv/Lib/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf9fb2075dabe321313ed41ebd8360465227c581 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/opt_einsum/__init__.py b/phivenv/Lib/site-packages/torch/backends/opt_einsum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7a9fb9b3e2cd9bddcfe1390cfa182c9952d561 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/opt_einsum/__init__.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +import sys +import warnings +from contextlib import contextmanager +from functools import lru_cache as _lru_cache +from typing import Any + +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +try: + import opt_einsum as _opt_einsum # type: ignore[import] +except ImportError: + _opt_einsum = None + + +@_lru_cache +def is_available() -> bool: + r"""Return a bool indicating if opt_einsum is currently available. + + You must install opt-einsum in order for torch to automatically optimize einsum. To + make opt-einsum available, you can install it along with torch: ``pip install torch[opt-einsum]`` + or by itself: ``pip install opt-einsum``. If the package is installed, torch will import + it automatically and use it accordingly. Use this function to check whether opt-einsum + was installed and properly imported by torch. + """ + return _opt_einsum is not None + + +def get_opt_einsum() -> Any: + r"""Return the opt_einsum package if opt_einsum is currently available, else None.""" + return _opt_einsum + + +def _set_enabled(_enabled: bool) -> None: + if not is_available() and _enabled: + raise ValueError( + f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap " + "the benefits of calculating an optimal path for einsum. torch.einsum will " + "fall back to contracting from left to right. To enable this optimal path " + "calculation, please install opt-einsum." + ) + global enabled + enabled = _enabled + + +def _get_enabled() -> bool: + return enabled + + +def _set_strategy(_strategy: str) -> None: + if not is_available(): + raise ValueError( + f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. " + "torch.einsum will bypass path calculation and simply contract from left to right. " + "Please install opt_einsum or unset `strategy`." + ) + if not enabled: + raise ValueError( + f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. " + "torch.einsum will bypass path calculation and simply contract from left to right. " + "Please set `enabled` to `True` as well or unset `strategy`." 
+ ) + if _strategy not in ["auto", "greedy", "optimal"]: + raise ValueError( + f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}" + ) + global strategy + strategy = _strategy + + +def _get_strategy() -> str: + return strategy + + +def set_flags(_enabled=None, _strategy=None): + orig_flags = (enabled, None if not is_available() else strategy) + if _enabled is not None: + _set_enabled(_enabled) + if _strategy is not None: + _set_strategy(_strategy) + return orig_flags + + +@contextmanager +def flags(enabled=None, strategy=None): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled, strategy) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +# The magic here is to allow us to intercept code like this: +# +# torch.backends.opt_einsum.enabled = True + + +class OptEinsumModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + + global enabled + enabled = ContextProp(_get_enabled, _set_enabled) + global strategy + strategy = None + if is_available(): + strategy = ContextProp(_get_strategy, _set_strategy) + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__) + +enabled = True if is_available() else False +strategy = "auto" if is_available() else None diff --git a/phivenv/Lib/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9afe50c388e8b81959fe9340b82cd777aedb57d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/quantized/__init__.py b/phivenv/Lib/site-packages/torch/backends/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..464ba861c73d2247277e5c081c87fc6b1dbcd77c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/quantized/__init__.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +import sys +import types + +import torch + + +# This function should correspond to the enums present in c10/core/QEngine.h +def _get_qengine_id(qengine: str) -> int: + if qengine == "none" or qengine == "" or qengine is None: + ret = 0 + elif qengine == "fbgemm": + ret = 1 + elif qengine == "qnnpack": + ret = 2 + elif qengine == "onednn": + ret = 3 + elif qengine == "x86": + ret = 4 + else: + ret = -1 + raise RuntimeError(f"{qengine} is not a valid value for quantized engine") + return ret + + +# This function should correspond to the enums present in c10/core/QEngine.h +def _get_qengine_str(qengine: int) -> str: + all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"} + return all_engines.get(qengine, "*undefined") + + +class _QEngineProp: + def __get__(self, obj, objtype) -> str: + return _get_qengine_str(torch._C._get_qengine()) + + def __set__(self, obj, val: str) -> None: + torch._C._set_qengine(_get_qengine_id(val)) + + +class _SupportedQEnginesProp: + def __get__(self, obj, objtype) -> list[str]: + qengines = torch._C._supported_qengines() + return [_get_qengine_str(qe) for qe in qengines] + + def __set__(self, obj, val) -> None: + raise RuntimeError("Assignment not supported") + + +class QuantizedEngine(types.ModuleType): + 
def __init__(self, m, name): + super().__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) + + engine = _QEngineProp() + supported_engines = _SupportedQEnginesProp() + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__) +engine: str +supported_engines: list[str] diff --git a/phivenv/Lib/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9531f9b5ece764f96438eb9a9d200a17c6e1454c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/xeon/__init__.py b/phivenv/Lib/site-packages/torch/backends/xeon/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67314aaf2f8d84618459490c394771c1db3493bb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cd3806364d6acff100bbf8c4291d6e25882d0a9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/backends/xeon/run_cpu.py b/phivenv/Lib/site-packages/torch/backends/xeon/run_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..22c7ec32d82c6379826c02b580706b03774cea35 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/xeon/run_cpu.py @@ -0,0 +1,942 @@ +# mypy: allow-untyped-defs +""" +This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations. + +Single instance inference, multi-instance inference are enabled. + +Note: term "instance" here doesn't refer to a cloud instance. This script is executed as a single process. It invokes +multiple "instances" which are formed from multiple threads for each. "instance" is kind of group of threads in this +context. + +Illustrated as below: + +:: + + +-----------------------------+----------------------+-------+ + | process | thread | core | + +=============================+======================+=======+ + | torch.backends.xeon.run_cpu | instance 0: thread 0 | 0 | + | | thread 1 | 1 | + | +----------------------+-------+ + | | instance 1: thread 0 | 2 | + | | thread 1 | 3 | + | +----------------------+-------+ + | | ... | ... | + | +----------------------+-------+ + | | instance N: thread 0 | M | + | | thread 1 | M+1 | + +-----------------------------+----------------------+-------+ + +To get the peak performance on Intel(R) Xeon(R) Scalable Processors, the script optimizes the configuration of thread and memory +management. 
For thread management, the script configures thread affinity and the preload of Intel OMP library. +For memory management, it configures NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc). + +Environment variables that will be set by this script: + ++------------------+-------------------------------------------------------------------------------------------------+ +| Environ Variable | Value | ++==================+=================================================================================================+ +| LD_PRELOAD | Depending on knobs you set, /libiomp5.so, /libjemalloc.so, /libtcmalloc.so might | +| | be appended to LD_PRELOAD. | ++------------------+-------------------------------------------------------------------------------------------------+ +| KMP_AFFINITY | If libiomp5.so is preloaded, KMP_AFFINITY could be set to "granularity=fine,compact,1,0". | ++------------------+-------------------------------------------------------------------------------------------------+ +| KMP_BLOCKTIME | If libiomp5.so is preloaded, KMP_BLOCKTIME is set to "1". | ++------------------+-------------------------------------------------------------------------------------------------+ +| OMP_NUM_THREADS | value of ncores_per_instance | ++------------------+-------------------------------------------------------------------------------------------------+ +| MALLOC_CONF | If libjemalloc.so is preloaded, MALLOC_CONF will be set to | +| | "oversize_threshold:1,background_thread:true,metadata_thp:auto". | ++------------------+-------------------------------------------------------------------------------------------------+ + +*Note*: This script respects environment variables set preliminarily. I.e. If you set the environment variables +mentioned above before running the script, the script will not overwrite the values in the script. + +How to use this module: +~~~~~~~~~~~~~~~~~~~~~~~ + +Single instance inference +------------------------- + +1. Run single-instance inference on a single node with all CPU nodes. + +:: + + python -m torch.backends.xeon.run_cpu --throughput-mode script.py args + +2. Run single-instance inference on a single CPU node. + +:: + + python -m torch.backends.xeon.run_cpu --node-id 1 script.py args + +Multi-instance inference +------------------------ + +1. Multi-instance + By default this tool runs one process per node. If you want to set the instance numbers and core per instance, + --ninstances and --ncores-per-instance should be set. + +:: + + python -m torch.backends.xeon.run_cpu -- python_script args + + eg: on an Intel(R) Xeon(R) Scalable Processor with 14 instance, 4 cores per instance + +:: + + python -m torch.backends.xeon.run_cpu --ninstances 14 --ncores-per-instance 4 python_script args + +2. Run single-instance inference among multiple instances. + By default, runs all ninstances. If you want to independently run a single instance among ninstances, specify rank. 
+ + eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 0-27) + +:: + + python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 0 python_script args + + eg: run 1st instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 28-55) + +:: + + python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 1 python_script args + + eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance, 2 cores per instance, + first four cores (i.e., numactl -C 0-1) + +:: + + python -m torch.backends.xeon.run_cpu --core-list "0, 1, 2, 3" --ninstances 2 --ncores-per-instance 2 + --rank 0 python_script args + +3. To look up what optional arguments this module offers: + +:: + + python -m torch.backends.xeon.run_cpu --help + +Memory allocator +---------------- + +"--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allcator. + +""" + +import glob +import logging +import os +import platform +import re +import subprocess +import sys +from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER +from os.path import expanduser + +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs as _DefaultLogsSpecs, + start_processes, + Std, +) + + +format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.INFO, format=format_str) +logger = logging.getLogger(__name__) + + +class _CPUinfo: + """Get CPU information, such as cores list and NUMA information.""" + + def __init__(self, test_input=""): + self.cpuinfo = [] + if platform.system() in ["Windows", "Darwin"]: + raise RuntimeError(f"{platform.system()} is not supported!!!") + elif platform.system() == "Linux": + # Sample output of: `lscpu --parse=CPU,Core,Socket,Node` + # + # # The following is the parsable format, which can be fed to other + # # programs. Each different item in every column has an unique ID + # # starting from zero. + # # CPU,Core,Socket,Node + # 0,0,0,0 + # 1,1,0,0 + # ... 
+ if test_input == "": + lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"] + lscpu_info = subprocess.check_output( + lscpu_cmd, universal_newlines=True + ).split("\n") + else: + lscpu_info = test_input.split("\n") + + # Get information about cpu, core, socket and node + for line in lscpu_info: + pattern = r"^([\d]+,[\d]+,[\d]+,[\d]?)" + regex_out = re.search(pattern, line) + if regex_out: + self.cpuinfo.append(regex_out.group(1).strip().split(",")) + + # physical cores := core column in lscpu output + # logical cores := cPU column in lscpu output + self.node_nums = int(max(line[3] for line in self.cpuinfo)) + 1 + self.node_physical_cores: list[list[int]] = [] # node_id is index + self.node_logical_cores: list[list[int]] = [] # node_id is index + self.physical_core_node_map = {} # physical core to numa node id + self.logical_core_node_map = {} # logical core to numa node id + + for node_id in range(self.node_nums): + cur_node_physical_core = [] + cur_node_logical_core = [] + for cpuinfo in self.cpuinfo: + nid = cpuinfo[3] if cpuinfo[3] != "" else "0" + if node_id == int(nid): + if int(cpuinfo[1]) not in cur_node_physical_core: + cur_node_physical_core.append(int(cpuinfo[1])) + self.physical_core_node_map[int(cpuinfo[1])] = int(node_id) + cur_node_logical_core.append(int(cpuinfo[0])) + self.logical_core_node_map[int(cpuinfo[0])] = int(node_id) + self.node_physical_cores.append(cur_node_physical_core) + self.node_logical_cores.append(cur_node_logical_core) + + def _physical_core_nums(self): + return len(self.node_physical_cores) * len(self.node_physical_cores[0]) + + def _logical_core_nums(self): + return len(self.node_logical_cores) * len(self.node_logical_cores[0]) + + def get_node_physical_cores(self, node_id): + if node_id < 0 or node_id > self.node_nums - 1: + raise ValueError( + f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}" + ) + return self.node_physical_cores[node_id] + + def get_node_logical_cores(self, node_id): + if node_id < 0 or node_id > self.node_nums - 1: + raise ValueError( + f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}" + ) + return self.node_logical_cores[node_id] + + def get_all_physical_cores(self): + all_cores = [] + for cores in self.node_physical_cores: + all_cores.extend(cores) + return all_cores + + def get_all_logical_cores(self): + all_cores = [] + for cores in self.node_logical_cores: + all_cores.extend(cores) + return all_cores + + def numa_aware_check(self, core_list): + """ + Check whether all cores in core_list are in the same NUMA node. + + Cross NUMA will reduce performance. + We strongly advice to not use cores on different nodes. + """ + cores_numa_map = self.logical_core_node_map + numa_ids = [] + for core in core_list: + numa_id = cores_numa_map[core] + if numa_id not in numa_ids: + numa_ids.append(numa_id) + if len(numa_ids) > 1: + logger.warning( + "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \ +this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\ +instance. 
Alternatively, please use --skip-cross-node-cores knob.", + str(core_list), + str(numa_ids), + ) + if len(numa_ids) == 0: + raise RuntimeError( + "invalid number of NUMA nodes; please make sure numa_ids >= 1" + ) + return numa_ids + + +class _Launcher: + r"""Class for launcher.""" + + msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ +or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \ +{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set." + + def __init__(self) -> None: + self.cpuinfo = _CPUinfo() + + def add_lib_preload(self, lib_type): + """Enable TCMalloc/JeMalloc/intel OpenMP.""" + library_paths = [] + if "CONDA_PREFIX" in os.environ: + library_paths.append(f"{os.environ['CONDA_PREFIX']}/lib") + if "VIRTUAL_ENV" in os.environ: + library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib") + + library_paths += [ + f"{expanduser('~')}/.local/lib", + "/usr/local/lib", + "/usr/local/lib64", + "/usr/lib", + "/usr/lib64", + ] + + lib_find = False + lib_set = False + for item in os.getenv("LD_PRELOAD", "").split(":"): + if item.endswith(f"lib{lib_type}.so"): + lib_set = True + break + if not lib_set: + for lib_path in library_paths: + library_file = os.path.join(lib_path, f"lib{lib_type}.so") + matches = glob.glob(library_file) + if len(matches) > 0: + ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")] + os.environ["LD_PRELOAD"] = os.pathsep.join( + [p.strip(os.pathsep) for p in ld_preloads if p] + ) + lib_find = True + break + return lib_set or lib_find + + def is_numactl_available(self): + numactl_available = False + try: + cmd = ["numactl", "-C", "0", "-m", "0", "hostname"] + r = subprocess.run( + cmd, + env=os.environ, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) + if r.returncode == 0: + numactl_available = True + except Exception: + pass + return numactl_available + + def set_memory_allocator( + self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False + ): + """ + Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc. + + By default, PTMalloc will be used for PyTorch, but TCMalloc and JeMalloc can get better + memory reuse and reduce page fault to improve performance. + """ + if enable_tcmalloc and enable_jemalloc: + raise RuntimeError( + "Unable to enable TCMalloc and JEMalloc at the same time." 
+ ) + + if enable_tcmalloc: + find_tc = self.add_lib_preload(lib_type="tcmalloc") + if not find_tc: + msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}' + logger.warning(msg.format("TCmalloc", "tcmalloc")) # noqa: G001 + else: + logger.info("Use TCMalloc memory allocator") + + elif enable_jemalloc: + find_je = self.add_lib_preload(lib_type="jemalloc") + if not find_je: + msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}' + logger.warning(msg.format("Jemalloc", "jemalloc")) # noqa: G001 + else: + logger.info("Use JeMalloc memory allocator") + self.set_env( + "MALLOC_CONF", + "oversize_threshold:1,background_thread:true,metadata_thp:auto", + ) + + elif use_default_allocator: + pass + + else: + find_tc = self.add_lib_preload(lib_type="tcmalloc") + if find_tc: + logger.info("Use TCMalloc memory allocator") + return + find_je = self.add_lib_preload(lib_type="jemalloc") + if find_je: + logger.info("Use JeMalloc memory allocator") + return + logger.warning( + """Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib + or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or + %s/.local/lib/ so the LD_PRELOAD environment variable will not be set. + This may drop the performance""", + expanduser("~"), + ) + + def log_env_var(self, env_var_name=""): + if env_var_name in os.environ: + logger.info("%s=%s", env_var_name, os.environ[env_var_name]) + + def set_env(self, env_name, env_value): + if not env_value: + logger.warning("%s is None", env_name) + if env_name not in os.environ: + os.environ[env_name] = env_value + elif os.environ[env_name] != env_value: + logger.warning( + "Overriding value with the one set in environment variable: %s. \ +Value applied: %s. Value ignored: %s", + env_name, + os.environ[env_name], + env_value, + ) + self.log_env_var(env_name) + + # set_kmp_affinity is used to control whether to set KMP_AFFINITY or not. + # In scenario that use all cores on all nodes, including logical cores, setting KMP_AFFINITY disables logical cores. + # In this case, KMP_AFFINITY should not be set. + def set_multi_thread_and_allocator( + self, + ncores_per_instance, + disable_iomp=False, + set_kmp_affinity=True, + enable_tcmalloc=True, + enable_jemalloc=False, + use_default_allocator=False, + ): + """ + Set multi-thread configuration and enable Intel openMP and TCMalloc/JeMalloc. + + By default, GNU openMP and PTMalloc are used in PyTorch. but Intel openMP and TCMalloc/JeMalloc are better alternatives + to get performance benefit. 
+ """ + self.set_memory_allocator( + enable_tcmalloc, enable_jemalloc, use_default_allocator + ) + self.set_env("OMP_NUM_THREADS", str(ncores_per_instance)) + if not disable_iomp: + find_iomp = self.add_lib_preload(lib_type="iomp5") + if not find_iomp: + msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}' + logger.warning(msg.format("iomp", "iomp5")) # noqa: G001 + else: + logger.info("Using Intel OpenMP") + if set_kmp_affinity: + self.set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") + self.set_env("KMP_BLOCKTIME", "1") + self.log_env_var("LD_PRELOAD") + + r""" + Launcher for single instance and multi-instance + """ + + def launch(self, args): + cores = [] + set_kmp_affinity = True + enable_taskset = False + if args.core_list: # user specify what cores will be used by params + cores = [int(x) for x in args.core_list.split(",")] + if args.ncores_per_instance == -1: + raise RuntimeError( + 'please specify the "--ncores-per-instance" if you have pass the --core-list params' + ) + elif ( + args.ninstances > 1 + and args.ncores_per_instance * args.ninstances < len(cores) + ): + logger.warning( + "only first %s cores will be used, \ +but you specify %s cores in core_list", + args.ncores_per_instance * args.ninstances, + len(cores), + ) + else: + args.ninstances = len(cores) // args.ncores_per_instance + + else: + if args.use_logical_core: + if args.node_id != -1: + cores = self.cpuinfo.get_node_logical_cores(args.node_id) + else: + cores = self.cpuinfo.get_all_logical_cores() + # When using all cores on all nodes, including logical cores, + # setting KMP_AFFINITY disables logical cores. Thus, KMP_AFFINITY should not be set. + set_kmp_affinity = False + else: + if args.node_id != -1: + cores = self.cpuinfo.get_node_physical_cores(args.node_id) + else: + cores = self.cpuinfo.get_all_physical_cores() + if ( + not args.multi_instance + and args.ninstances == -1 + and args.ncores_per_instance == -1 + ): + args.ninstances = 1 + args.ncores_per_instance = len(cores) + elif ( + args.multi_instance + and args.ninstances == -1 + and args.ncores_per_instance == -1 + ): + args.throughput_mode = True + elif args.ncores_per_instance == -1 and args.ninstances != -1: + if args.ninstances > len(cores): + raise RuntimeError( + f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \ +please make sure ninstances <= total_cores)" + ) + else: + args.ncores_per_instance = len(cores) // args.ninstances + elif args.ncores_per_instance != -1 and args.ninstances == -1: + if not args.skip_cross_node_cores: + args.ninstances = len(cores) // args.ncores_per_instance + else: + ncore_per_node = len(self.cpuinfo.node_physical_cores[0]) + num_leftover_cores = ncore_per_node % args.ncores_per_instance + if args.ncores_per_instance > ncore_per_node: + # too many ncores_per_instance to skip cross-node cores + logger.warning( + "there are %s core(s) per socket, but you specify %s ncores_per_instance and \ +skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \ +socket", + ncore_per_node, + args.ncores_per_instance, + ) + sys.exit(-1) + elif num_leftover_cores == 0: + # aren't any cross-node cores + logger.info( + "--skip-cross-node-cores is set, but there are no cross-node cores." + ) + args.ninstances = len(cores) // args.ncores_per_instance + else: + # skip cross-node cores + if args.ninstances != -1: + logger.warning( + "--skip-cross-node-cores is exclusive to --ninstances. --ninstances \ +won't take effect even if it is set explicitly." 
+ ) + + i = 1 + leftover_cores = set() + while ncore_per_node * i <= len(cores): + leftover_cores.update( + cores[ + ncore_per_node * i + - num_leftover_cores : ncore_per_node * i + ] + ) + i += 1 + cores = list(set(cores) - leftover_cores) + assert len(cores) % args.ncores_per_instance == 0 + args.ninstances = len(cores) // args.ncores_per_instance + else: + if args.ninstances * args.ncores_per_instance > len(cores): + raise RuntimeError( + "Please make sure ninstances * ncores_per_instance <= total_cores" + ) + if args.latency_mode: + logger.warning( + "--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ +--use-logical-core. They won't take effect even they are set explicitly." + ) + args.ncores_per_instance = 4 + cores = self.cpuinfo.get_all_physical_cores() + args.ninstances = len(cores) // args.ncores_per_instance + + if args.throughput_mode: + logger.warning( + "--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ +--use-logical-core. They won't take effect even they are set explicitly." + ) + args.ninstances = self.cpuinfo.node_nums + cores = self.cpuinfo.get_all_physical_cores() + args.ncores_per_instance = len(cores) // args.ninstances + + if args.ninstances > 1 and args.rank != -1: + logger.info( + "assigning %s cores for instance %s", + args.ncores_per_instance, + args.rank, + ) + + if not args.disable_numactl: + numactl_available = self.is_numactl_available() + if not numactl_available: + if not args.disable_taskset: + logger.warning( + "Core binding with numactl is not available. Disabling numactl and using taskset instead. \ + This may affect performance in multi-socket system; please use numactl if memory binding is needed." + ) + args.disable_numactl = True + enable_taskset = True + else: + logger.warning( + "Core binding with numactl is not available, and --disable_taskset is set. \ + Please unset --disable_taskset to use taskset instead of numactl." 
+ ) + sys.exit(-1) + + if not args.disable_taskset: + enable_taskset = True + + self.set_multi_thread_and_allocator( + args.ncores_per_instance, + args.disable_iomp, + set_kmp_affinity, + args.enable_tcmalloc, + args.enable_jemalloc, + args.use_default_allocator, + ) + entrypoint = "" + launch_args = {} + launch_envs: dict[int, dict] = {} + launch_tee = {} + # check whether is launched from torchrun with --nproc-per-node + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + for i in range(args.ninstances): + cmd = [] + cur_process_cores = "" + if not args.disable_numactl or enable_taskset: + if not args.disable_numactl: + cmd = ["numactl"] + elif enable_taskset: + cmd = ["taskset"] + cores = sorted(cores) + if ( + args.rank == -1 + ): # sequentially assign ncores_per_instance to ninstances + core_list = cores[ + i + * args.ncores_per_instance : (i + 1) + * args.ncores_per_instance + ] + else: # assign ncores_per_instance from rank + core_list = cores[ + args.rank + * args.ncores_per_instance : (args.rank + 1) + * args.ncores_per_instance + ] + + core_ranges: list[dict] = [] + if local_size > 1: + total_num_cores = len(core_list) + cores_per_rank = total_num_cores // local_size + assert ( + cores_per_rank >= 1 + ), "At least one core needs to be assigned to each rank" + core_list = core_list[ + cores_per_rank * local_rank : cores_per_rank * (local_rank + 1) + ] + for core in core_list: + if len(core_ranges) == 0: + range_elem = {"start": core, "end": core} + core_ranges.append(range_elem) + else: + if core - core_ranges[-1]["end"] == 1: + core_ranges[-1]["end"] = core + else: + range_elem = {"start": core, "end": core} + core_ranges.append(range_elem) + for r in core_ranges: + cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']}," + cur_process_cores = cur_process_cores[:-1] + if not args.disable_numactl: + numa_params = f"-C {cur_process_cores} " + numa_ids = ",".join( + [ + str(numa_id) + for numa_id in self.cpuinfo.numa_aware_check(core_list) + ] + ) + numa_params += f"-m {numa_ids}" + cmd.extend(numa_params.split()) + elif enable_taskset: + taskset_params = f"-c {cur_process_cores} " + cmd.extend(taskset_params.split()) + with_python = not args.no_python + if with_python: + cmd.append(sys.executable) + cmd.append("-u") + if args.module: + cmd.append("-m") + cmd.append(args.program) + cmd.extend(args.program_args) + cmd_s = " ".join(cmd) + logger.info(cmd_s) + if entrypoint == "": + entrypoint = cmd[0] + del cmd[0] + launch_args[i] = tuple(cmd) + launch_envs[i] = {} + launch_tee[i] = Std.ALL + + if args.rank != -1: # launches single instance, rank, only + break + + ctx = start_processes( + name=args.log_file_prefix, + entrypoint=entrypoint, + args=launch_args, + envs=launch_envs, + logs_specs=_DefaultLogsSpecs(log_dir=args.log_path, tee=launch_tee), + ) + ctx.wait() + + +def _add_memory_allocator_params(parser): + group = parser.add_argument_group("Memory Allocator Parameters") + # allocator control + group.add_argument( + "--enable-tcmalloc", + "--enable_tcmalloc", + action="store_true", + default=False, + help="Enable tcmalloc allocator", + ) + group.add_argument( + "--enable-jemalloc", + "--enable_jemalloc", + action="store_true", + default=False, + help="Enable jemalloc allocator", + ) + group.add_argument( + "--use-default-allocator", + "--use_default_allocator", + action="store_true", + default=False, + help="Use default memory allocator", + ) + + +def _add_multi_instance_params(parser): + group = 
parser.add_argument_group("Multi-instance Parameters") + # multi-instance control + group.add_argument( + "--ncores-per-instance", + "--ncores_per_instance", + metavar="\b", + default=-1, + type=int, + help="Cores per instance", + ) + group.add_argument( + "--ninstances", + metavar="\b", + default=-1, + type=int, + help="For multi-instance, you should give the cores number you used for per instance.", + ) + group.add_argument( + "--skip-cross-node-cores", + "--skip_cross_node_cores", + action="store_true", + default=False, + help="If specified --ncores-per-instance, skips cross-node cores.", + ) + group.add_argument( + "--rank", + metavar="\b", + default="-1", + type=int, + help="Specify instance index to assign ncores_per_instance for rank; \ +otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \ +https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md", + ) + group.add_argument( + "--latency-mode", + "--latency_mode", + action="store_true", + default=False, + help="By default 4 core per instance and use all physical cores", + ) + group.add_argument( + "--throughput-mode", + "--throughput_mode", + action="store_true", + default=False, + help="By default one instance per node and use all physical cores", + ) + group.add_argument( + "--node-id", + "--node_id", + metavar="\b", + default=-1, + type=int, + help="node id for multi-instance, by default all nodes will be used", + ) + group.add_argument( + "--use-logical-core", + "--use_logical_core", + action="store_true", + default=False, + help="Whether only use physical cores", + ) + group.add_argument( + "--disable-numactl", + "--disable_numactl", + action="store_true", + default=False, + help="Disable numactl", + ) + group.add_argument( + "--disable-taskset", + "--disable_taskset", + action="store_true", + default=False, + help="Disable taskset", + ) + group.add_argument( + "--core-list", + "--core_list", + metavar="\b", + default=None, + type=str, + help='Specify the core list as "core_id, core_id, ....", otherwise, all the cores will be used.', + ) + group.add_argument( + "--log-path", + "--log_path", + metavar="\b", + default="", + type=str, + help="The log file directory. Default path is " + ", which means disable logging to files.", + ) + group.add_argument( + "--log-file-prefix", + "--log_file_prefix", + metavar="\b", + default="run", + type=str, + help="log file prefix", + ) + + +def _add_kmp_iomp_params(parser): + group = parser.add_argument_group("IOMP Parameters") + group.add_argument( + "--disable-iomp", + "--disable_iomp", + action="store_true", + default=False, + help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD", + ) + + +def create_args(parser=None): + """ + Parse the command line options. + + @retval ArgumentParser + """ + parser.add_argument( + "--multi-instance", + "--multi_instance", + action="store_true", + default=False, + help="Enable multi-instance, by default one instance per node", + ) + + parser.add_argument( + "-m", + "--module", + default=False, + action="store_true", + help="Changes each process to interpret the launch script " + "as a python module, executing with the same behavior as" + '"python -m".', + ) + + parser.add_argument( + "--no-python", + "--no_python", + default=False, + action="store_true", + help='Do not prepend the --program script with "python" - just exec ' + "it directly. 
Useful when the script is not a Python script.", + ) + + _add_memory_allocator_params(parser) + _add_kmp_iomp_params(parser) + + _add_multi_instance_params(parser) + # positional + parser.add_argument( + "program", + type=str, + help="The full path to the program/script to be launched. " + "followed by all the arguments for the script", + ) + + # rest from the training program + parser.add_argument("program_args", nargs=REMAINDER) + + +def main(args): + env_before = set(os.environ.keys()) + if platform.system() in ["Windows", "Darwin"]: + raise RuntimeError(f"{platform.system()} is not supported!!!") + + if args.log_path: + os.makedirs(args.log_path, exist_ok=True) + else: + args.log_path = os.devnull + + if args.latency_mode and args.throughput_mode: + raise RuntimeError( + "Either args.latency_mode or args.throughput_mode should be set" + ) + + if not args.no_python and not args.program.endswith(".py"): + raise RuntimeError( + 'For non Python script, you should use "--no-python" parameter.' + ) + + # Verify LD_PRELOAD + if "LD_PRELOAD" in os.environ: + lst_valid = [] + tmp_ldpreload = os.environ["LD_PRELOAD"] + for item in tmp_ldpreload.split(":"): + matches = glob.glob(item) + if len(matches) > 0: + lst_valid.append(item) + else: + logger.warning("%s doesn't exist. Removing it from LD_PRELOAD.", item) + if len(lst_valid) > 0: + os.environ["LD_PRELOAD"] = ":".join(lst_valid) + else: + os.environ["LD_PRELOAD"] = "" + + launcher = _Launcher() + launcher.launch(args) + for x in sorted(set(os.environ.keys()) - env_before): + logger.debug("%s=%s", x, os.environ[x]) + + +if __name__ == "__main__": + parser = ArgumentParser( + description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable " + "Processors with optimal configurations. Single instance inference, " + "multi-instance inference are enable. To get the peak performance on Intel(R) " + "Xeon(R) Scalable Processors, the script optimizes the configuration " + "of thread and memory management. For thread management, the script configures thread " + "affinity and the preload of Intel OMP library. For memory management, it configures " + "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) " + "\n################################# Basic usage ############################# \n" + "\n 1. single instance\n" + "\n >>> python -m torch.backends.xeon.run_cpu python_script args \n" + "\n2. 
multi-instance \n" + "\n >>> python -m torch.backends.xeon.run_cpu --ninstances xxx " + "--ncores-per-instance xx python_script args\n" + "\n############################################################################# \n", + formatter_class=RawTextHelpFormatter, + ) + create_args(parser) + args = parser.parse_args() + main(args) diff --git a/phivenv/Lib/site-packages/torch/backends/xnnpack/__init__.py b/phivenv/Lib/site-packages/torch/backends/xnnpack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f40f78be9d9f6683c0cf9daad5d6d3ede98271 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/backends/xnnpack/__init__.py @@ -0,0 +1,29 @@ +# mypy: allow-untyped-defs +import sys +import types + +import torch + + +class _XNNPACKEnabled: + def __get__(self, obj, objtype): + return torch._C._is_xnnpack_enabled() + + def __set__(self, obj, val): + raise RuntimeError("Assignment not supported") + + +class XNNPACKEngine(types.ModuleType): + def __init__(self, m, name): + super().__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) + + enabled = _XNNPACKEnabled() + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = XNNPACKEngine(sys.modules[__name__], __name__) diff --git a/phivenv/Lib/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..475ea2bab4a2cee697770cd650853f3b45a3fed4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/compiler/__init__.py b/phivenv/Lib/site-packages/torch/compiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6deba227837dfcc8211d4b9a4a70040a6fce464 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/compiler/__init__.py @@ -0,0 +1,632 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec + +import torch + +from . import config + + +if TYPE_CHECKING: + from ._cache import CacheInfo + + +__all__ = [ + "compile", + "config", + "assume_constant_result", + "reset", + "allow_in_graph", + "substitute_in_graph", + "list_backends", + "disable", + "set_stance", + "set_enable_guard_collectives", + "cudagraph_mark_step_begin", + "wrap_numpy", + "is_compiling", + "is_dynamo_compiling", + "is_exporting", + "save_cache_artifacts", + "load_cache_artifacts", + "skip_guard_on_inbuilt_nn_modules_unsafe", + "skip_guard_on_all_nn_modules_unsafe", + "keep_tensor_guards_unsafe", + "skip_guard_on_globals_unsafe", + "nested_compile_region", +] + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def compile(*args, **kwargs): + """ + See :func:`torch.compile` for details on the arguments for this function. + """ + return torch.compile(*args, **kwargs) + + +def reset() -> None: + """ + This function clears all compilation caches and restores the system to its initial state. 
+ It is recommended to call this function, especially after using operations like `torch.compile(...)` + to ensure a clean state before another unrelated compilation + """ + import torch._dynamo + + torch._dynamo.reset() + + +def allow_in_graph(fn): + """ + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. + + If you are using :func:`torch.compile` (with backend="inductor" (the default)), or + :func:`torch.export.export`, and trying to black-box a Python function throughout + all tracing, do not use this API. + Instead, please create a custom operator (see `PyTorch Custom Operators Landing Page + `_) + + .. warning:: + + If you're a typical torch.compile user (e.g. you're applying torch.compile to + a model to make it run faster), you probably don't want to use this function. + :func:`allow_in_graph` is a footgun because it skips the compiler frontend + (Dynamo) that is responsible for doing safety checks (graph breaks, handling + closures, etc). Incorrect usage will lead to difficult-to-debug silent + incorrectness issues. + + Given a Python function with no allow_in_graph decorator, regular execution + of torch.compile traces through the function. :func:`allow_in_graph` changes + it so that the frontend does not trace inside the function, but the compiler + backend still traces through it. Compare this to custom operators, which + treats a function as a black box throughout the torch.compile stack. The following + table compares these mechanisms. + + +------------------------+-----------------------+--------------------------------+ + | Mechanism | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) | + +========================+=======================+================================+ + | no decorator | trace inside | trace inside | + +------------------------+-----------------------+--------------------------------+ + | allow_in_graph | opaque callable | trace inside | + +------------------------+-----------------------+--------------------------------+ + | custom op | opaque callable | opaque callable | + +------------------------+-----------------------+--------------------------------+ + + One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler + frontend: if you know the function works w.r.t. to the downstream components of the + compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from + symbolically introspecting the function properly (or if your code is in C/C++ and + therefore cannot be introspected with Dynamo), then one can decorate said function + with :func:`allow_in_graph` to bypass Dynamo. + + We require that ``fn`` adhere to the following restrictions. Failure to adhere + results in undefined behavior: + + - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include: + Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?] + Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device + - The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet) + - all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn`` + (as opposed to being captured variables). + + Args: + fn: A callable representing the function to be included in the graph. + If ``fn`` is a list or tuple of callables it recursively applies + :func:`allow_in_graph()` to each function and returns a new list or + tuple containing the modified functions. 
+ + Example:: + + torch.compiler.allow_in_graph(my_custom_function) + + @torch.compile(...) + def fn(x): + x = torch.add(x, 1) + x = my_custom_function(x) + x = torch.add(x, 1) + return x + + fn(...) + + Will capture a single graph containing ``my_custom_function()``. + + """ + import torch._dynamo + + return torch._dynamo.allow_in_graph(fn) + + +def substitute_in_graph( + original_fn: Callable[_P, _R], + *, + can_constant_fold_through: bool = False, + skip_signature_check: bool = False, +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + """ + Register a polyfill handler for a function, usually a C function from the C extension, to be + used in place of the original function when inlining the original function in the graph. + + .. note:: + + The polyfill handler is only used when inlining the original function. It is not used when + the original function is called directly. In the eager mode, the decorated function calls + the performant C function rather than the polyfill handler. + + The polyfill handler is a function that will be called in place of the original function when + inlining the original function. The polyfill handler should have the same signature and the same + behavior as the original function. + + Args: + original_fn (callable): The original function, usually a C function, to register a polyfill + handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. + skip_signature_check (bool, optional): Whether to skip the signature check between the + original function and the polyfill handler. Defaults to ``False``. + + Returns: + A decorator that registers the polyfill handler for the original function. + + Example:: + + >>> import operator + >>> operator.indexOf([1, 2, 3, 4, 5], 3) + 2 + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + ... # xdoctest: +SKIP("Long tracebacks") + Traceback (most recent call last): + ... + torch._dynamo.exc.Unsupported: ... + + >>> @torch.compiler.substitute_in_graph(operator.indexOf) + ... def indexOf(a, b, /): + ... for i, item in enumerate(a): + ... if item is b or item == b: + ... return i + ... raise ValueError("sequence.index(x): x not in sequence") + >>> + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + 2 + """ + import torch._dynamo + + return torch._dynamo.substitute_in_graph( + original_fn, + can_constant_fold_through=can_constant_fold_through, + skip_signature_check=skip_signature_check, + ) + + +def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: + """ + Return valid strings that can be passed to `torch.compile(..., backend="name")`. + + Args: + exclude_tags(optional): A tuple of strings representing tags to exclude. + """ + import torch._dynamo + + return torch._dynamo.list_backends(exclude_tags) + + +def assume_constant_result(fn): + """ + This function is used to mark a function `fn` as having a constant result. + This allows the compiler to optimize away your function. + Returns The same function `fn` + + Args: + fn: The function to be marked as having a constant result. + + .. 
warning:: + `assume_constant_result` can if invalid cause safety and soundness issues, :func:`torch.compile` + will not attempt to validate whether the constant assumption is true or not + + """ + import torch._dynamo + + return torch._dynamo.assume_constant_result(fn) + + +def disable(fn=None, recursive=True, *, reason=None): + """ + This function provides a decorator to disable compilation on a function. + It also provides the option of recursively disabling called functions. + + Args: + fn (optional): The function to disable + recursive (optional): A boolean value indicating whether the disabling should be recursive. + reason (optional): A string value indicating the reason for disabling the function. + """ + import torch._dynamo + + return torch._dynamo.disable(fn, recursive, reason=reason) + + +def set_stance( + stance: str = "default", *, skip_guard_eval_unsafe=False, force_backend=None +): + """ + Set the current stance of the compiler. + Can be used as a function, context manager, or decorator. + Do not use this function inside a `torch.compile` region - an error will be raised otherwise. + + .. code-block:: python + + @torch.compile + def foo(x): + ... + + @torch.compiler.set_stance("force_eager") + def bar(): + # will not be compiled + foo(...) + + bar() + + with torch.compiler.set_stance("force_eager"): + # will also not be compiled + foo(...) + + torch.compiler.set_stance("force_eager") + # will also not be compiled + foo(...) + torch.compiler.set_stance("default") + + # will be compiled + foo(...) + + Args: + stance: The stance to set the compiler to. Valid values are: + + - "default": The default stance, used for normal compilation. + - "force_eager": Ignore all `torch.compile` directives. + - "eager_on_recompile": Run code eagerly when a recompile is necessary. + If there is cached compiled code valid for the input, it will still be used. + - "fail_on_recompile": Raise an error when recompiling a function. + - "eager_then_compile": Run the first invocation in eager mode, then compile on + subsequent calls. This is beneficial for dynamic shapes as it allows inferring + dynamism from the first two invocations instead of wasting a static compile on + the first invocation. + - "aot_eager_then_compile": Run the first invocation with AOT eager to get memory + benefits from activation checkpointing, then compile on subsequent calls. Like + eager_then_compile, this improves handling of dynamic shapes by avoiding an + initial static compile. + + + skip_guard_eval_unsafe: A flag to run only differentiating guards. + CAUTION - This flag is unsafe and should only be used if your setup + meets the following conditions. + + torch.compile uses a guard system to support recompilations and + choose which compiled artifact to run at runtime. These guards, + though efficient, add some overhead, which may impact performance in + scenarios where you need to optimize for minimal guard processing + time. This API enables you to disable guard evaluation, assuming + that you have warmed up the compiled model with a sufficient variety + of inputs. This assumption means that, after the warmup phase, no + further recompilations will be necessary. If this assumption fails, + there is a risk of silently producing incorrect results (hence the + term "unsafe" in the API name). + + force_backend: If `stance` is "default", this argument can be used to force `torch.compile` + to use a specific backend. Otherwise, an error is raised. 
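+
+    A minimal sketch of forcing a specific backend (the backend name below is
+    illustrative only):
+
+    .. code-block:: python
+
+        with torch.compiler.set_stance("default", force_backend="eager"):
+            foo(...)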
+ """ + import torch._dynamo + + return torch._dynamo.set_stance( + stance, + skip_guard_eval_unsafe=skip_guard_eval_unsafe, + force_backend=force_backend, + ) + + +# forbid in graph +set_stance._dynamo_forbidden = True # type: ignore[attr-defined] + + +def set_enable_guard_collectives(enabled: bool): + """ + Enables use of collectives *during* guard evaluation to synchronize behavior + across ranks. This is expensive: we have to issue a collective every time + we enter a compiled code region, even if no rank actually would need to + compile. This can help prevent NCCL hangs by ensuring that we never have a + situation where one rank starts recompiling while other ranks don't compile; + it is especially useful in conjunction with enable_compiler_collectives + where such a situation would immediately cause a hang (as it is necessary + for all ranks to compile at the same time to run compiler collectives). Like + compiler collectives, you can only run this on SPMD programs; you will hang + otherwise. Note that a guard collective is only issued if there is any + compiled code to guard on; if this the first time we encounter a frame or + the frame is skipped, we don't issue collectives. + + Returns the previous setting of enabled. + """ + from torch._C._dynamo.eval_frame import set_guard_complete_hook # noqa: F401 + from torch._dynamo.eval_frame import guard_collectives_hook + + if enabled: + return set_guard_complete_hook(guard_collectives_hook) is not None + else: + return set_guard_complete_hook(None) is not None + + +set_enable_guard_collectives._dynamo_forbidden = True # type: ignore[attr-defined] + + +def cudagraph_mark_step_begin(): + """ + Indicates that a new iteration of inference or training is about to begin. + + CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of + torch.compile, so long as there is not a pending backward that has not been called. + + If that heuristic is wrong, such as in the following example, manually mark it with this api. + + .. code-block:: python + + @torch.compile(mode="reduce-overhead") + def rand_foo(): + return torch.rand([4], device="cuda") + + for _ in range(5): + torch.compiler.cudagraph_mark_step_begin() + rand_foo() + rand_foo() + + For more details, see `torch.compiler_cudagraph_trees `__ + """ + from torch._inductor import cudagraph_trees + + cudagraph_trees.mark_step_begin() + + +def wrap_numpy(fn): + r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function + from ``torch.Tensor``s to ``torch.Tensor``s. + + It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows to + compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code + on CUDA or compute its gradients. + + .. note:: + + This decorator does not work without :func:`torch.compile`. 
+ + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # Compile a NumPy function as a Tensor -> Tensor function + >>> @torch.compile(fullgraph=True) + >>> @torch.compiler.wrap_numpy + >>> def fn(a: np.ndarray): + >>> return np.sum(a * a) + >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients + >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True) + >>> out = fn(x) + >>> out.backward() + >>> print(x.grad) + tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0') + """ + from torch._dynamo.external_utils import wrap_numpy as wrap + + return wrap(fn) + + +_is_compiling_flag: bool = False +_is_exporting_flag: bool = False + + +def is_compiling() -> bool: + """ + Indicates whether a graph is executed/traced as part of torch.compile() or torch.export(). + + Note that there are 2 other related flags that should deprecated eventually: + * torch._dynamo.external_utils.is_compiling() + * torch._utils.is_compiling() + + Example:: + + >>> def forward(self, x): + >>> if not torch.compiler.is_compiling(): + >>> pass # ...logic that is not needed in a compiled/traced graph... + >>> + >>> # ...rest of the function... + """ + if torch.jit.is_scripting(): + return False + else: + return _is_compiling_flag + + +def is_dynamo_compiling() -> bool: + """ + Indicates whether a graph is traced via TorchDynamo. + + It's stricter than is_compiling() flag, as it would only be set to True when + TorchDynamo is used. + + Example:: + + >>> def forward(self, x): + >>> if not torch.compiler.is_dynamo_compiling(): + >>> pass # ...logic that is not needed in a TorchDynamo-traced graph... + >>> + >>> # ...rest of the function... + """ + return False + + +def is_exporting() -> bool: + """ + Indicated whether we're under exporting. + + It's stricter than is_compiling() flag, as it would only be set to True when + torch.export is used. + + Example:: + + >>> def forward(self, x): + >>> if not torch.compiler.is_exporting(): + >>> pass # ...logic that is not needed in export... + >>> + >>> # ...rest of the function... + """ + return _is_exporting_flag + + +def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]: + """ + Serializes all the cache artifacts that were created during the compilation + + Example: + + - Execute torch.compile + - Call torch.compiler.save_cache_artifacts() + """ + from ._cache import CacheArtifactManager, CacheInfo + + return CacheArtifactManager.serialize() + + +def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]: + """ + Hot loads cache artifacts that were previously serialized via + save_cache_artifacts + + Example: + + # From a previous invocation + artifacts = torch.compiler.save_cache_artifacts() + + torch.compiler.load_cache_artifacts(artifacts[0]) + """ + from ._cache import CacheArtifactManager, CacheInfo + + artifacts = CacheArtifactManager.deserialize(serialized_artifacts) + if artifacts is not None: + return CacheArtifactManager.populate_caches(artifacts) + return None + + +def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries): + """ + A common function to skip guards on the inbuilt nn modules like + torch.nn.Linear. This is unsafe to use by default. But for majority of + torch.compile users, the model code does not modify the inbuilt nn module + attributes. They can benefit from reduction in guard latency overhead using + this API. 
+ + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe}, + >> ) + """ + return [ + not entry.orig_guard.source.is_unspecialized_builtin_nn_module() + for entry in guard_entries + ] + + +def skip_guard_on_all_nn_modules_unsafe(guard_entries): + """ + A common function to skip guards on all nn modules, both user defined as + well inbuilt nn modules (like torch.nn.Linear). This is unsafe to use by + default. But for majority of torch.compile users, the model code does not + modify the nn module attributes. They can benefit from reduction in guard + latency overhead using this API. + + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe}, + >> ) + """ + + return [ + not entry.orig_guard.source.is_unspecialized_nn_module() + for entry in guard_entries + ] + + +def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False): + """ + A common function to keep tensor guards on all tensors. This is unsafe to + use by default. But if you don't expect any changes in the model code, you + can just keep the tensor guards. + + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.keep_tensor_guards}, + >> ) + """ + + keep_flags = [] + for entry in guard_entries: + if entry.guard_type == "TENSOR_MATCH": + if not isinstance(entry.value, torch.nn.Parameter): + keep_flags.append(True) + elif keep_parameters: + keep_flags.append(True) + else: + keep_flags.append(False) + else: + keep_flags.append(False) + return keep_flags + + +def skip_guard_on_globals_unsafe(guard_entries): + """ + A common function to skip guards on all globals. This is unsafe to use by + default. But if you don't expect any changes in the globals, you can just + keep the tensor guards. + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_globals}, + >> ) + """ + + return [not entry.is_global for entry in guard_entries] + + +def nested_compile_region(fn=None): + """ + Tells **``torch.compile``** that the marked set of operations forms a nested + compile region (which is often repeated in the full model) whose code can be + compiled once and safely reused. ``nested_compile_region`` can also be used + as a decorator. + + During **``torch.compile``** tracing, the compiler applies *hierarchical + compilation* with ``nested_compile_region``: it emits optimized code for the + marked region the first time it is encountered and re-emits (or “stamps + out”) the previously compiled code on every subsequent invocation. This can + substantially reduce overall compile time for deeply-stacked, + structurally-identical components such as the transformer layers of a + large-language-model (LLM). + + Outside a ``torch.compile`` context—i.e., in standard eager execution—the + call is a no-op, so existing workflows remain unaffected. + + Note that ``nested_compile_region`` **does not** promise that a region will + be compiled exactly once. If the compiler detects that new input conditions + (shape, dtype, device, stride, globals etc.) make the cached version invalid + to reuse, it will transparently re-compile the region. Using it is + therefore *safe*: correctness is always preserved, and you pay the extra + compilation cost only when required. 
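+
+    A minimal usage sketch (the functions below are illustrative only)::
+
+        @torch.compiler.nested_compile_region
+        def layer(x):
+            return torch.relu(x) + 1
+
+        @torch.compile
+        def model(x):
+            # ``layer`` is traced once and its compiled region is reused here
+            return layer(layer(x))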
+ """ + + from torch._higher_order_ops.invoke_subgraph import ( + mark_compile_region as _mark_compile_region, + ) + + return _mark_compile_region(fn) diff --git a/phivenv/Lib/site-packages/torch/compiler/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/compiler/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eea1124990b55157527df7becdac35709a3821a5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/compiler/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/compiler/__pycache__/_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/compiler/__pycache__/_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efff570fa8a109e978d802d0f41339f848fe518a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/compiler/__pycache__/_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/compiler/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/compiler/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..341e79dbcafff474c0de87cbc5faac96407360c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/compiler/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/compiler/_cache.py b/phivenv/Lib/site-packages/torch/compiler/_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceaf457b594424087bd99fe1f54bd9b15be5ea8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/compiler/_cache.py @@ -0,0 +1,322 @@ +import copy +import dataclasses +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Generator +from contextlib import contextmanager +from itertools import chain +from typing import Any, Optional + +from torch.utils._appending_byte_serializer import ( + AppendingByteSerializer, + BytesReader, + BytesWriter, +) +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class CacheArtifact(ABC): + """ + Data for each cache artifact that will be serialized and deserialized + """ + + key: str + content: bytes = dataclasses.field(repr=False) # Do not display potential binary + + @staticmethod + def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None: + writer.write_str(cls.key) + writer.write_bytes(cls.content) + + @staticmethod + def deserialize(artifact_type: str, reader: BytesReader) -> "CacheArtifact": + key = reader.read_str() + content = reader.read_bytes() + return CacheArtifactFactory.create(artifact_type, key, content) + + @staticmethod + def encode(content: Any) -> bytes: + assert isinstance(content, bytes), f"Expected bytes, got {type(content)}" + return content + + @abstractmethod + def populate_cache(self) -> None: + pass + + def precompile_compatible(self) -> bool: + return False + + @staticmethod + def type() -> str: + """ + Returns the type of the artifact. Must be unique across all CacheArtifact classes. + + CacheArtifactFactory.register will add property method to CacheInfo based on this (def {type}_artifacts) + that returns all artifacts for specific cache. 
+ """ + raise RuntimeError("CacheArtifact is an abstract class, please use a subclass") + + +class CacheArtifactFactory: + """ + Factory for creating CacheArtifact objects based on their type + """ + + _artifact_types: dict[str, type[CacheArtifact]] = {} + + @classmethod + def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]: + artifact_type_key = artifact_cls.type() + assert ( + artifact_cls.type() not in cls._artifact_types + ), f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory" + cls._artifact_types[artifact_type_key] = artifact_cls + setattr( + CacheInfo, + f"{artifact_type_key}_artifacts", + property(lambda self: self.artifacts[artifact_type_key]), + ) + return artifact_cls + + @classmethod + def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]: + assert ( + artifact_type_key in cls._artifact_types + ), f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory" + return cls._artifact_types[artifact_type_key] + + @classmethod + def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact: + artifact_cls = cls._get_artifact_type(artifact_type_key) + return artifact_cls(key, content) + + @classmethod + def encode_create( + cls, artifact_type_key: str, key: str, content: Any + ) -> CacheArtifact: + artifact_cls = cls._get_artifact_type(artifact_type_key) + return artifact_cls(key, artifact_cls.encode(content)) + + +@dataclasses.dataclass +class CacheInfo: + """ + Return value of serialization and deserialization for the purpose of + instrumentation + """ + + artifacts: defaultdict[str, list[str]] = dataclasses.field( + default_factory=lambda: defaultdict(list) + ) + + # Methods set by CacheArtifactFactory.register based on CacheArtifact.type() + @property + def inductor_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def autotune_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... 
+ + def add(self, artifact: CacheArtifact) -> None: + self.artifacts[artifact.type()].append(artifact.key) + + def clear(self) -> None: + self.artifacts.clear() + + def empty(self) -> bool: + return not self.artifacts + + +def _serialize_single_cache( + writer: BytesWriter, cls: "tuple[str, list[CacheArtifact]]" +) -> None: + writer.write_str(cls[0]) + writer.write_uint64(len(cls[1])) + for artifact in cls[1]: + CacheArtifact.serialize(writer, artifact) + + +def _deserialize_single_cache( + reader: BytesReader, +) -> "tuple[str, list[CacheArtifact]]": + artifacts = [] + artifact_type_key = reader.read_str() + num_artifacts = reader.read_uint64() + for _ in range(num_artifacts): + artifacts.append(CacheArtifact.deserialize(artifact_type_key, reader)) + + return artifact_type_key, artifacts + + +CacheArtifactsResult = dict[str, list[CacheArtifact]] + + +class CacheArtifactManager: + """ + Lightweight manager class for collecting and processing cache artifacts for + hot loading + + Intended Lifecycle: + - Execute code via torch.compile, this will call + CacheArtifactManager.record_artifact on each cache artifact + - Call CacheArtifactManager.serialize to convert all the cache artifacts + to portable format + - Call CacheArtifactManager.deserialize to hot load the cache artifacts on + a potentially different process + + NOTE: There's no FB/FC guarentees, results of cache artifacts will not be + used unless code version matches. + """ + + # Protected by the compile_lock + _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) + # Keep a seperate seen artifacts list to make avoid unnecessary duplicates + # This list will not be cleared between serialize() calls + _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet() + # When serialize() is called, artifacts are transferred from _cache_artifacts to + # internal data structure of the _serializer + # This allows us to only pay the cost of serialization if serialize() is called + _serializer: AppendingByteSerializer[ + tuple[str, list[CacheArtifact]] + ] = AppendingByteSerializer(serialize_fn=_serialize_single_cache) + _cache_info: CacheInfo = CacheInfo() + + @classmethod + def clear(cls) -> None: + cls._new_cache_artifacts.clear() + cls._seen_artifacts.clear() + cls._serializer.clear() + cls._cache_info.clear() + + @classmethod + @contextmanager + def with_fresh_cache(cls) -> Generator[None, None, None]: + original_new_cache_artifacts = cls._new_cache_artifacts + original_seen_artifacts = cls._seen_artifacts + original_serializer = cls._serializer + original_cache_info = cls._cache_info + + cls._new_cache_artifacts = defaultdict(list) + cls._seen_artifacts = OrderedSet() + cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache) + cls._cache_info = cls._cache_info.__class__() + try: + yield + finally: + cls._new_cache_artifacts = original_new_cache_artifacts + cls._seen_artifacts = original_seen_artifacts + cls._serializer = original_serializer + cls._cache_info = original_cache_info + + @classmethod + def record_artifact( + cls, + artifact_type: str, + key: str, + content: Any, + ) -> None: + """ + Called from each caching operation to record the artifact in this + "mega" list + """ + artifact = CacheArtifactFactory.encode_create(artifact_type, key, content) + if artifact in cls._seen_artifacts: + return + log.debug("Recording %s", str(artifact)) + cls._new_cache_artifacts[artifact_type].append(artifact) + cls._seen_artifacts.add(artifact) + + @classmethod + def need_serialize(cls) -> bool: + """ + Have we seen 
new artifacts since last serialize call? + """ + return len(cls._new_cache_artifacts) != 0 + + @classmethod + def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: + """ + Converts the "mega" list into portable format + """ + for artifact in chain(*cls._new_cache_artifacts.values()): + log.debug("saving: %s", artifact) + cls._cache_info.add(artifact) + + if cls._cache_info.empty(): + # If there are not artifacts, dont just return bytes with + # version. + return None + + try: + # We deep copy cls._cache_info since later compilations + # can keep adding to cache_info + info = copy.deepcopy(cls._cache_info) + cls._serializer.extend(cls._new_cache_artifacts.items()) + artifact_bytes = cls._serializer.to_bytes() + cls._new_cache_artifacts.clear() + return artifact_bytes, info + except Exception: + log.warning("Failed to pickle cache artifacts", exc_info=True) + return None + + @staticmethod + def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]: + """ + Converts the portable format back into CacheArtifacts + """ + try: + CacheArtifactManager._ensure_cache_artifacts_registered() + artifacts = dict( + AppendingByteSerializer.to_list( + serialized_artifacts, + deserialize_fn=_deserialize_single_cache, + ) + ) + except Exception: + log.warning("Failed to un-pickle cache artifacts", exc_info=True) + return None + + return artifacts + + @staticmethod + def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: + info = CacheInfo() + for artifact in chain(*artifacts.values()): + log.debug("writing: %s", artifact) + info.add(artifact) + artifact.populate_cache() + + return info + + @classmethod + def _ensure_cache_artifacts_registered(cls) -> None: + """When deserializing caches in fresh process, we need to ensure that all + cache artifacts are registered in the cache registry. This is done by + simply importing all the cache artifacts already wrapped with register call. + """ + from torch._dynamo.pgo import PGOCacheArtifact # noqa: F401 + from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401 + AOTAutogradCacheArtifact, + ) + from torch._inductor.codecache import InductorCacheArtifact # noqa: F401 + from torch._inductor.runtime.autotune_cache import ( # noqa: F401 + AutotuneCacheArtifact, + ) diff --git a/phivenv/Lib/site-packages/torch/compiler/config.py b/phivenv/Lib/site-packages/torch/compiler/config.py new file mode 100644 index 0000000000000000000000000000000000000000..27afa68ae0ca3cc3b7b1a8602919fa169a6d0db2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/compiler/config.py @@ -0,0 +1,91 @@ +""" +This is the top-level configuration module for the compiler, containing +cross-cutting configuration options that affect all parts of the compiler +stack. + +You may also be interested in the per-component configuration modules, which +contain configuration options that affect only a specific part of the compiler: + +* :mod:`torch._dynamo.config` +* :mod:`torch._inductor.config` +* :mod:`torch._functorch.config` +* :mod:`torch.fx.experimental.config` +""" + +import sys +from typing import Optional + +from torch.utils._config_module import Config, install_config_module + + +__all__ = [ + "job_id", +] + + +# NB: Docblocks go UNDER variable definitions! Use spacing to make the +# grouping clear. + +# FB-internal note: you do NOT have to specify this explicitly specify this if +# you run on MAST, we will automatically default this to +# mast:MAST_JOB_NAME:MAST_JOB_VERSION. 
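+#
+# For example (the job name below is hypothetical), one might set it when
+# launching:
+#
+#   TORCH_COMPILE_JOB_ID=resnet50_run1 python train.py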
+job_id: Optional[str] = Config( + env_name_default=["TORCH_COMPILE_JOB_ID", "TORCH_COMPILE_STICKY_PGO_KEY"], + default=None, +) +""" +Semantically, this should be an identifier that uniquely identifies, e.g., a +training job. You might have multiple attempts of the same job, e.g., if it was +preempted or needed to be restarted, but each attempt should be running +substantially the same workload with the same distributed topology. You can +set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`. + +Operationally, this controls the effect of profile-guided optimization related +persistent state. PGO state can affect how we perform compilation across +multiple invocations of PyTorch, e.g., the first time you run your program we +may compile twice as we discover what inputs are dynamic, and then PGO will +save this state so subsequent invocations only need to compile once, because +they remember it is dynamic. This profile information, however, is sensitive +to what workload you are running, so we require you to tell us that two jobs +are *related* (i.e., are the same workload) before we are willing to reuse +this information. Notably, PGO does nothing (even if explicitly enabled) +unless a valid ``job_id`` is available. In some situations, PyTorch can +configured to automatically compute a ``job_id`` based on the environment it +is running in. + +Profiles are always collected on a per rank basis, so different ranks may have +different profiles. If you know your workload is truly SPMD, you can run with +:data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get +consistent profiles across all ranks. +""" + + +cache_key_tag: str = Config(env_name_default="TORCH_COMPILE_CACHE_KEY_TAG", default="") +""" +Tag to be included in the cache key generation for all torch compile caching. +A common use case for such a tag is to break caches. +""" + +dynamic_sources: str = Config( + env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default="" +) +""" +Comma delimited list of sources that should be marked as dynamic. Primarily useful for large +models with graph breaks where you need intermediate tensors and ints to be marked dynamic. + +This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes +and force_parameter_static_shapes. +""" + +unbacked_sources: str = Config( + env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default="" +) +""" +Comma delimited list of sources that should be marked as unbacked. Primarily useful for large +models with graph breaks where you need intermediate tensors marked unbacked. + +This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes +and force_parameter_static_shapes. 
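+
+For example (the source names below are illustrative only)::
+
+    TORCH_COMPILE_UNBACKED_SOURCES="L['x'],L['lengths']"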
+""" + +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/contrib/__init__.py b/phivenv/Lib/site-packages/torch/contrib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/contrib/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/contrib/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47c3ace3573694f17a4d677ba9c48f20c9360864 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/contrib/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..add469b91295c5c0602a16a8e919e25c767d9bf6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/contrib/_tensorboard_vis.py b/phivenv/Lib/site-packages/torch/contrib/_tensorboard_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..77babfd8eac47f8c41572a5b6493b96f554bc1f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/contrib/_tensorboard_vis.py @@ -0,0 +1,152 @@ +# mypy: allow-untyped-defs +import time +from collections import defaultdict +from functools import partial + +import torch + + +# Unfortunately it doesn't seem as if there was any way to get TensorBoard to do +# anything without having TF installed, and so this file has a hard dependency on it +# as well. It really is a debugging tool, so it doesn't matter. +try: + from tensorflow.core.framework import graph_pb2 + from tensorflow.core.util import event_pb2 + from tensorflow.python.summary.writer.writer import FileWriter +except ImportError: + raise ImportError( + "TensorBoard visualization of GraphExecutors requires having " + "TensorFlow installed" + ) from None + + +def dump_tensorboard_summary(graph_executor, logdir): + with FileWriter(logdir) as w: + pb_graph = visualize(graph_executor) + evt = event_pb2.Event( + wall_time=time.time(), graph_def=pb_graph.SerializeToString() + ) + w.add_event(evt) + + +def visualize(graph, name_prefix="", pb_graph=None, executors_it=None): + """Visualizes an independent graph, or a graph executor.""" + value_map = {} + pb_graph = pb_graph or graph_pb2.GraphDef() + + if isinstance(graph, torch._C.GraphExecutorState): + visualize_graph_executor( + graph, name_prefix, pb_graph, partial(visualize, pb_graph=pb_graph) + ) + return pb_graph + + # Set up an input node + pb_graph.node.add(op="input", name=name_prefix + "input") + for i, value in enumerate(graph.param_node().outputs()): + value_map[value.unique()] = name_prefix + "input:" + str(i) + + visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it) + + # Gather all outputs + return_node = pb_graph.node.add(op="output", name=name_prefix + "output") + for value in graph.return_node().inputs(): + return_node.input.append(value_map[value.unique()]) + + return pb_graph + + +def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph): + """Append the state of a given GraphExecutor to the graph protobuf. + + Args: + state (GraphExecutor or GraphExecutorState): GraphExecutor to display. + name_prefix (str): Name prefix of the containing subgraph. 
+ pb_graph (GraphDef): graph to append to. + inline_graph (Callable): a function that handles setting up a value_map, + so that some graphs in here can be inlined. This is necessary, because + this will simply be `visualize` for the top-level GraphExecutor, + or `inline_graph` for all nested ones. + + The signature should look like (Graph, name_prefix) -> (). + It will be called exactly once. + + The strategy is to embed all different configurations as independent subgraphs, + while inlining the original graph as the one that actually produces the values. + """ + if state.autograd_fallback_graph is not None: + visualize( + graph=state.autograd_fallback_graph, + name_prefix=name_prefix + "autograd_fallback/", + pb_graph=pb_graph, + executors_it=iter(state.autograd_fallback.executors()), + ) + + for i, (arg_spec, plan) in enumerate(state.execution_plans.items()): + subgraph_name = name_prefix + f"plan{i}/" + + # Create a disconnected node that will keep information regarding the input + # types of this trace. This is unfortunately a bit too verbose to be included + # in the subgraph name. + input_kinds = pb_graph.node.add(op="INPUT_KIND", name=subgraph_name) + input_kinds.attr["inputs"].s = repr(arg_spec).encode("ascii") + + visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors())) + + # Show gradient as an independent subgraph of this plan + if plan.grad_executor is not None: + grad_subgraph_name = subgraph_name + "grad/" + visualize(plan.grad_executor, grad_subgraph_name, pb_graph) + + return inline_graph(state.graph, name_prefix + "original/") + + +def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None): + """Recursive part of visualize (basically skips setting up the input and output nodes).""" + + def inline_graph(subgraph, name, node): + rec_value_map = { + inp.unique(): value_map[val.unique()] + for inp, val in zip(subgraph.inputs(), node.inputs()) + } + visualize_rec( + graph=subgraph, value_map=rec_value_map, name_prefix=name, pb_graph=pb_graph + ) + for out, val in zip(subgraph.outputs(), node.outputs()): + value_map[val.unique()] = rec_value_map[out.unique()] + + op_id_counter: defaultdict[str, int] = defaultdict(int) + + def name_for(node): + kind = node.kind()[node.kind().index("::") + 2 :] + op_id_counter[kind] += 1 + return kind, name_prefix + kind + "_" + str(op_id_counter[kind]) + + def add_fusion_group(node): + op, name = name_for(node) + inline_graph(node.g("Subgraph"), name + "/", node) + + def add_graph_executor(node): + op, name = name_for(node) + if executors_it is None: + add_node(node) + else: + ge = next(executors_it) + visualize_graph_executor( + ge, name + "/", pb_graph, partial(inline_graph, node=node) + ) + + def add_node(node): + if node.kind() == "prim::FusionGroup": + return add_fusion_group(node) + elif node.kind() == "prim::GraphExecutor": + return add_graph_executor(node) + op, name = name_for(node) + pb_node = pb_graph.node.add(op=op, name=name) + for value in node.inputs(): + pb_node.input.append(value_map[value.unique()]) + # TODO: handle attrs + for i, value in enumerate(node.outputs()): + value_map[value.unique()] = name + ":" + str(i) + + for node in graph.nodes(): + add_node(node) diff --git a/phivenv/Lib/site-packages/torch/cpu/__init__.py b/phivenv/Lib/site-packages/torch/cpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc61fcc48a52410b9cadab16fd54e58e20f761f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cpu/__init__.py @@ -0,0 +1,204 @@ +# mypy: allow-untyped-defs 
+r""" +This package implements abstractions found in ``torch.cuda`` +to facilitate writing device-agnostic code. +""" + +from contextlib import AbstractContextManager +from typing import Any, Optional, Union + +import torch + +from .. import device as _device +from . import amp + + +__all__ = [ + "is_available", + "is_initialized", + "synchronize", + "current_device", + "current_stream", + "stream", + "set_device", + "device_count", + "Stream", + "StreamContext", + "Event", +] + +_device_t = Union[_device, str, int, None] + + +def _is_avx2_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX2.""" + return torch._C._cpu._is_avx2_supported() + + +def _is_avx512_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX512.""" + return torch._C._cpu._is_avx512_supported() + + +def _is_avx512_bf16_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX512_BF16.""" + return torch._C._cpu._is_avx512_bf16_supported() + + +def _is_vnni_supported() -> bool: + r"""Returns a bool indicating if CPU supports VNNI.""" + # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later. + return torch._C._cpu._is_avx512_vnni_supported() + + +def _is_amx_tile_supported() -> bool: + r"""Returns a bool indicating if CPU supports AMX_TILE.""" + return torch._C._cpu._is_amx_tile_supported() + + +def _is_amx_fp16_supported() -> bool: + r"""Returns a bool indicating if CPU supports AMX FP16.""" + return torch._C._cpu._is_amx_fp16_supported() + + +def _init_amx() -> bool: + r"""Initializes AMX instructions.""" + return torch._C._cpu._init_amx() + + +def is_available() -> bool: + r"""Returns a bool indicating if CPU is currently available. + + N.B. This function only exists to facilitate device-agnostic code + + """ + return True + + +def synchronize(device: _device_t = None) -> None: + r"""Waits for all kernels in all streams on the CPU device to complete. + + Args: + device (torch.device or int, optional): ignored, there's only one CPU device. + + N.B. This function only exists to facilitate device-agnostic code. + """ + + +class Stream: + """ + N.B. This class only exists to facilitate device-agnostic code + """ + + def __init__(self, priority: int = -1) -> None: + pass + + def wait_stream(self, stream) -> None: + pass + + def record_event(self) -> None: + pass + + def wait_event(self, event) -> None: + pass + + +class Event: + def query(self) -> bool: + return True + + def record(self, stream=None) -> None: + pass + + def synchronize(self) -> None: + pass + + def wait(self, stream=None) -> None: + pass + + +_default_cpu_stream = Stream() +_current_stream = _default_cpu_stream + + +def current_stream(device: _device_t = None) -> Stream: + r"""Returns the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): Ignored. + + N.B. This function only exists to facilitate device-agnostic code + + """ + return _current_stream + + +class StreamContext(AbstractContextManager): + r"""Context-manager that selects a given stream. + + N.B. 
This class only exists to facilitate device-agnostic code + + """ + + cur_stream: Optional[Stream] + + def __init__(self, stream): + self.stream = stream + self.prev_stream = _default_cpu_stream + + def __enter__(self): + cur_stream = self.stream + if cur_stream is None: + return + + global _current_stream + self.prev_stream = _current_stream + _current_stream = cur_stream + + def __exit__(self, type: Any, value: Any, traceback: Any) -> None: + cur_stream = self.stream + if cur_stream is None: + return + + global _current_stream + _current_stream = self.prev_stream + + +def stream(stream: Stream) -> AbstractContextManager: + r"""Wrapper around the Context-manager StreamContext that + selects a given stream. + + N.B. This function only exists to facilitate device-agnostic code + """ + return StreamContext(stream) + + +def device_count() -> int: + r"""Returns number of CPU devices (not cores). Always 1. + + N.B. This function only exists to facilitate device-agnostic code + """ + return 1 + + +def set_device(device: _device_t) -> None: + r"""Sets the current device, in CPU we do nothing. + + N.B. This function only exists to facilitate device-agnostic code + """ + + +def current_device() -> str: + r"""Returns current device for cpu. Always 'cpu'. + + N.B. This function only exists to facilitate device-agnostic code + """ + return "cpu" + + +def is_initialized() -> bool: + r"""Returns True if the CPU is initialized. Always True. + + N.B. This function only exists to facilitate device-agnostic code + """ + return True diff --git a/phivenv/Lib/site-packages/torch/cpu/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cpu/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e76c9c7b142a3f1583facf3a46c57c68a4b6cf4b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cpu/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cpu/amp/__init__.py b/phivenv/Lib/site-packages/torch/cpu/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..973717d653ba9c47a5b63be6aa699dfe0d25b58f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cpu/amp/__init__.py @@ -0,0 +1,2 @@ +from .autocast_mode import autocast +from .grad_scaler import GradScaler diff --git a/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36edd27379c2a4168a0cf267ea2ac1db7abc3fd4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50bdae25611ec3a04be16c993fba32009a73290f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761f14e38bb021410f81e6912bf3a4b4564db8d8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/cpu/amp/autocast_mode.py b/phivenv/Lib/site-packages/torch/cpu/amp/autocast_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..26a2d9cfcd3a69dc0212addf3126506bddcaf233 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cpu/amp/autocast_mode.py @@ -0,0 +1,51 @@ +# mypy: allow-untyped-defs +from typing import Any +from typing_extensions import deprecated + +import torch + + +__all__ = ["autocast"] + + +class autocast(torch.amp.autocast_mode.autocast): + r""" + See :class:`torch.autocast`. + ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. + """ + + @deprecated( + "`torch.cpu.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cpu', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cpu" + self.fast_dtype = dtype + return + super().__init__( + "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if torch._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if torch._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if torch._jit_internal.is_scripting(): + return func + return super().__call__(func) diff --git a/phivenv/Lib/site-packages/torch/cpu/amp/grad_scaler.py b/phivenv/Lib/site-packages/torch/cpu/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..3e61c0b41ec440a20f35217ecba53a5b22eb1c93 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cpu/amp/grad_scaler.py @@ -0,0 +1,35 @@ +from typing_extensions import deprecated + +import torch + + +__all__ = ["GradScaler"] + + +class GradScaler(torch.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. + ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead. + """ + + @deprecated( + "`torch.cpu.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cpu', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cpu", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) diff --git a/phivenv/Lib/site-packages/torch/cuda/__init__.py b/phivenv/Lib/site-packages/torch/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da5ef3302675c2a6141b333b5b50bfe9c9d0e319 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/__init__.py @@ -0,0 +1,1932 @@ +# mypy: allow-untyped-defs +r""" +This package adds support for CUDA tensor types. + +It implements the same function as CPU tensors, but they utilize +GPUs for computation. + +It is lazily initialized, so you can always import it, and use +:func:`is_available()` to determine if your system supports CUDA. + +:ref:`cuda-semantics` has more details about working with CUDA. 
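+
+A minimal availability check (illustrative only)::
+
+    import torch
+
+    if torch.cuda.is_available():
+        x = torch.ones(3, device="cuda")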
+""" + +import importlib +import os +import sys +import threading +import traceback +import warnings +from functools import lru_cache +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union + +import torch +import torch._C +from torch import device as _device +from torch._utils import _dummy_type, _LazySeedTracker, classproperty + +from . import gds +from ._utils import _get_device_index +from .graphs import ( + CUDAGraph, + graph, + graph_pool_handle, + is_current_stream_capturing, + make_graphed_callables, +) +from .streams import Event, ExternalStream, Stream + + +if TYPE_CHECKING: + from torch.types import Device + +try: + from torch._C import _cudart # type: ignore[attr-defined] +except ImportError: + _cudart = None + +_initialized = False +_tls = threading.local() +_initialization_lock = threading.Lock() +_queued_calls: list[ + tuple[Callable[[], None], list[str]] +] = [] # don't invoke these until initialization occurs +_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) + +_HAS_PYNVML = False +_PYNVML_ERR = None +try: + from torch import version as _version + + try: + if not _version.hip: + import pynvml # type: ignore[import] + else: + import ctypes + from pathlib import Path + + # In ROCm (at least up through 6.3.2) there're 2 copies of libamd_smi.so: + # - One at lib/libamd_smi.so + # - One at share/amd_smi/amdsmi/libamd_smi.so + # + # The amdsmi python module hardcodes loading the second one in share- + # https://github.com/ROCm/amdsmi/blob/1d305dc9708e87080f64f668402887794cd46584/py-interface/amdsmi_wrapper.py#L174 + # + # See also https://github.com/ROCm/amdsmi/issues/72. + # + # This creates an ODR violation if the copy of libamd_smi.so from lib + # is also loaded (via `ld` linking, `LD_LIBRARY_PATH` or `rpath`). + # + # In order to avoid the violation we hook CDLL and try using the + # already loaded version of amdsmi, or any version in the processes + # rpath/LD_LIBRARY_PATH first, so that we only load a single copy + # of the .so. 
+ class _amdsmi_cdll_hook: + def __init__(self) -> None: + self.original_CDLL = ctypes.CDLL # type: ignore[misc,assignment] + paths = ["libamd_smi.so"] + if rocm_home := os.getenv("ROCM_HOME", os.getenv("ROCM_PATH")): + paths = [os.path.join(rocm_home, "lib/libamd_smi.so")] + paths + self.paths: list[str] = paths + + def hooked_CDLL( + self, name: Union[str, Path, None], *args: Any, **kwargs: Any + ) -> ctypes.CDLL: + if name and Path(name).name == "libamd_smi.so": + for path in self.paths: + try: + return self.original_CDLL(path, *args, **kwargs) + except OSError: + pass + return self.original_CDLL(name, *args, **kwargs) # type: ignore[arg-type] + + def __enter__(self) -> None: + ctypes.CDLL = self.hooked_CDLL # type: ignore[misc,assignment] + + def __exit__(self, type: Any, value: Any, traceback: Any) -> None: + ctypes.CDLL = self.original_CDLL # type: ignore[misc] + + with _amdsmi_cdll_hook(): + import amdsmi # type: ignore[import] + + _HAS_PYNVML = True + except ModuleNotFoundError: + pass + finally: + del _version +except ImportError as err: + _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later + +_lazy_seed_tracker = _LazySeedTracker() + +# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA +if hasattr(torch._C, "_CudaDeviceProperties"): + _CudaDeviceProperties = torch._C._CudaDeviceProperties +else: + _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc] + +if hasattr(torch._C, "_cuda_exchangeDevice"): + _exchange_device = torch._C._cuda_exchangeDevice +else: + + def _exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without CUDA support") + + +if hasattr(torch._C, "_cuda_maybeExchangeDevice"): + _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice +else: + + def _maybe_exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without CUDA support") + + +has_half: bool = True +has_magma: bool = torch._C._has_magma + +default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment] + + +def _is_compiled() -> bool: + r"""Return true if compile with CUDA support.""" + return hasattr(torch._C, "_cuda_getDeviceCount") + + +def _nvml_based_avail() -> bool: + return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1" + + +def is_available() -> bool: + r""" + Return a bool indicating if CUDA is currently available. + + .. note:: This function will NOT poison fork if the environment variable + ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` is set. For more details, see + :ref:`multiprocessing-poison-fork-note`. + """ + if not _is_compiled(): + return False + if _nvml_based_avail(): + # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by + # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization + # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`) + return device_count() > 0 + else: + # The default availability inspection never throws and returns 0 if the driver is missing or can't + # be initialized. 
This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver + # API via `cuInit` + return torch._C._cuda_getDeviceCount() > 0 + + +def is_bf16_supported(including_emulation: bool = True): + r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16.""" + # Check for ROCm, if true return true, no ROCM_VERSION check required, + # since it is supported on AMD GPU archs. + if torch.version.hip: + return True + + # If CUDA is not available, than it does not support bf16 either + if not is_available(): + return False + + device = torch.cuda.current_device() + + # Check for CUDA version and device compute capability. + # This is a fast way to check for it. + cuda_version = torch.version.cuda + if cuda_version is not None and torch.cuda.get_device_properties(device).major >= 8: + return True + + if not including_emulation: + return False + + # Finally try to create a bfloat16 device. + return _check_bf16_tensor_supported(device) + + +@lru_cache(maxsize=16) +def _check_bf16_tensor_supported(device: "Device"): + try: + torch.tensor([1.0], dtype=torch.bfloat16, device=device) + return True + except Exception: + return False + + +def is_tf32_supported() -> bool: + r"""Return a bool indicating if the current CUDA/ROCm device supports dtype tf32.""" + if torch.version.hip: + prop_name = torch.cuda.get_device_properties().gcnArchName + archs = ("gfx94", "gfx95") + for arch in archs: + if arch in prop_name: + return True + return False + + # Otherwise, tf32 is supported on CUDA platforms that natively (i.e. no emulation) + # support bfloat16. + return is_bf16_supported(including_emulation=False) + + +def _sleep(cycles): + torch._C._cuda_sleep(cycles) + + +def _extract_arch_version(arch_string: str): + """Extracts the architecture string from a CUDA version""" + base = arch_string.split("_")[1] + base = base.removesuffix("a") + return int(base) + + +def _check_capability(): + incompatible_gpu_warn = """ + Found GPU%d %s which is of cuda capability %d.%d. 
+ Minimum and Maximum cuda capability supported by this version of PyTorch is + (%d.%d) - (%d.%d) + """ + matched_cuda_warn = """ + Please install PyTorch with a following CUDA + configurations: {} following instructions at + https://pytorch.org/get-started/locally/ + """ + + # Binary CUDA_ARCHES SUPPORTED by PyTorch + CUDA_ARCHES_SUPPORTED = { + "12.6": {"min": 50, "max": 90}, + "12.8": {"min": 70, "max": 120}, + "12.9": {"min": 70, "max": 120}, + } + + if ( + torch.version.cuda is not None and torch.cuda.get_arch_list() + ): # on ROCm we don't want this check + for d in range(device_count()): + capability = get_device_capability(d) + major = capability[0] + minor = capability[1] + name = get_device_name(d) + current_arch = major * 10 + minor + min_arch = min( + (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), + default=50, + ) + max_arch = max( + (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), + default=50, + ) + if current_arch < min_arch or current_arch > max_arch: + warnings.warn( + incompatible_gpu_warn + % ( + d, + name, + major, + minor, + min_arch // 10, + min_arch % 10, + max_arch // 10, + max_arch % 10, + ) + ) + matched_arches = "" + for arch, arch_info in CUDA_ARCHES_SUPPORTED.items(): + if ( + current_arch >= arch_info["min"] + and current_arch <= arch_info["max"] + ): + matched_arches += f" {arch}" + if matched_arches != "": + warnings.warn(matched_cuda_warn.format(matched_arches)) + + +def _check_cubins(): + incompatible_device_warn = """ +{} with CUDA capability sm_{} is not compatible with the current PyTorch installation. +The current PyTorch install supports CUDA capabilities {}. +If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/ +""" + if torch.version.cuda is None: # on ROCm we don't want this check + return + arch_list = get_arch_list() + if len(arch_list) == 0: + return + supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch] + for idx in range(device_count()): + cap_major, cap_minor = get_device_capability(idx) + # NVIDIA GPU compute architectures are backward compatible within major version + supported = any(sm // 10 == cap_major for sm in supported_sm) + if not supported: + device_name = get_device_name(idx) + capability = cap_major * 10 + cap_minor + warnings.warn( + incompatible_device_warn.format( + device_name, capability, " ".join(arch_list), device_name + ) + ) + + +def is_initialized(): + r"""Return whether PyTorch's CUDA state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _lazy_call(callable, **kwargs): + with _initialization_lock: + if is_initialized(): + callable() + else: + # TODO(torch_deploy): this accesses linecache, which attempts to read the + # file system to get traceback info. Patch linecache or do something + # else here if this ends up being important. 
+ global _lazy_seed_tracker + if kwargs.get("seed_all", False): + _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack()) + elif kwargs.get("seed", False): + _lazy_seed_tracker.queue_seed(callable, traceback.format_stack()) + else: + # Don't store the actual traceback to avoid memory cycle + _queued_calls.append((callable, traceback.format_stack())) + + +_lazy_call(_check_capability) +_lazy_call(_check_cubins) + + +class DeferredCudaCallError(Exception): + pass + + +AcceleratorError = torch._C.AcceleratorError +OutOfMemoryError = torch._C.OutOfMemoryError + + +def init(): + r"""Initialize PyTorch's CUDA state. + + You may need to call this explicitly if you are interacting with + PyTorch via its C API, as Python bindings for CUDA functionality + will not be available until this initialization takes place. + Ordinary users should not need this, as all of PyTorch's CUDA methods + automatically initialize CUDA state on-demand. + + Does nothing if the CUDA state is already initialized. + """ + _lazy_init() + + +def _lazy_init(): + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checked locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize CUDA in forked subprocess. To use CUDA with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not hasattr(torch._C, "_cuda_getDeviceCount"): + raise AssertionError("Torch not compiled with CUDA enabled") + if _cudart is None: + raise AssertionError( + "libcudart functions unavailable. It looks like you have a broken build?" + ) + # This function throws if there's a driver initialization error, no GPUs + # are found or any other error occurs + if "CUDA_MODULE_LOADING" not in os.environ: + os.environ["CUDA_MODULE_LOADING"] = "LAZY" + torch._C._cuda_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True + + _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"CUDA call failed lazily at initialization with error: {str(e)}\n\n" + f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredCudaCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +def cudart(): + r"""Retrieves the CUDA runtime API module. + + + This function initializes the CUDA runtime environment if it is not already + initialized and returns the CUDA runtime API module (_cudart). The CUDA + runtime API module provides access to various CUDA runtime functions. + + Args: + ``None`` + + Returns: + module: The CUDA runtime API module (_cudart). + + Raises: + RuntimeError: If CUDA cannot be re-initialized in a forked subprocess. 
+ AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable. + + Example of CUDA operations with profiling: + >>> import torch + >>> from torch.cuda import cudart, check_error + >>> import os + >>> + >>> os.environ['CUDA_PROFILE'] = '1' + >>> + >>> def perform_cuda_operations_with_streams(): + >>> stream = torch.cuda.Stream() + >>> with torch.cuda.stream(stream): + >>> x = torch.randn(100, 100, device='cuda') + >>> y = torch.randn(100, 100, device='cuda') + >>> z = torch.mul(x, y) + >>> return z + >>> + >>> torch.cuda.synchronize() + >>> print("====== Start nsys profiling ======") + >>> check_error(cudart().cudaProfilerStart()) + >>> with torch.autograd.profiler.emit_nvtx(): + >>> result = perform_cuda_operations_with_streams() + >>> print("CUDA operations completed.") + >>> check_error(torch.cuda.cudart().cudaProfilerStop()) + >>> print("====== End nsys profiling ======") + + To run this example and save the profiling information, execute: + >>> $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py + + This command profiles the CUDA operations in the provided script and saves + the profiling information to a file named `trace_name.prof`. + The `--profile-from-start off` option ensures that profiling starts only + after the `cudaProfilerStart` call in the script. + The `--csv` and `--print-summary` options format the profiling output as a + CSV file and print a summary, respectively. + The `-o` option specifies the output file name, and the `-f` option forces the + overwrite of the output file if it already exists. + """ + _lazy_init() + return _cudart + + +class cudaStatus: + SUCCESS: int = 0 + ERROR_NOT_READY: int = 34 + + +class CudaError(RuntimeError): + def __init__(self, code: int) -> None: + msg = _cudart.cudaGetErrorString(_cudart.cudaError(code)) + super().__init__(f"{msg} ({code})") + + +def check_error(res: int) -> None: + if res != _cudart.cudaError.success: + raise CudaError(res) + + +class _DeviceGuard: + def __init__(self, index: int): + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.cuda._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + return False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.cuda._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + return False + + +class device_of(device): + r"""Context-manager that changes the current device to that of given object. + + You can use both tensors and storages as arguments. If a given object is + not allocated on a GPU, this is a no-op. + + Args: + obj (Tensor or Storage): object allocated on the selected device. + """ + + def __init__(self, obj): + idx = obj.get_device() if obj.is_cuda else -1 + super().__init__(idx) + + +def set_device(device: "Device") -> None: + r"""Set the current device. + + Usage of this function is discouraged in favor of :any:`device`. In most + cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable. 
+ + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._cuda_setDevice(device) + + +def get_device_name(device: "Device" = None) -> str: + r"""Get the name of a device. + + Args: + device (torch.device or int or str, optional): device for which to return the + name. This function is a no-op if this argument is a negative + integer. It uses the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Returns: + str: the name of the device + """ + return get_device_properties(device).name + + +def get_device_capability(device: "Device" = None) -> tuple[int, int]: + r"""Get the cuda capability of a device. + + Args: + device (torch.device or int or str, optional): device for which to return the + device capability. This function is a no-op if this argument is + a negative integer. It uses the current device, given by + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + + Returns: + tuple(int, int): the major and minor cuda capability of the device + """ + prop = get_device_properties(device) + return prop.major, prop.minor + + +def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties: + r"""Get the properties of a device. + + Args: + device (torch.device or int or str, optional): device for which to return the + properties of the device. It uses the current device, given by + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + + Returns: + _CudaDeviceProperties: the properties of the device + """ + _lazy_init() # will define _get_device_properties + device = _get_device_index(device, optional=True) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + return _get_device_properties(device) # type: ignore[name-defined] + + +def can_device_access_peer(device: "Device", peer_device: "Device") -> bool: + r"""Check if peer access between two devices is possible.""" + _lazy_init() + device = _get_device_index(device, optional=True) + peer_device = _get_device_index(peer_device) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + if peer_device < 0 or peer_device >= device_count(): + raise AssertionError("Invalid peer device id") + return torch._C._cuda_canDeviceAccessPeer(device, peer_device) + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All CUDA kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. 
+ """ + + cur_stream: Optional["torch.cuda.Stream"] + + def __init__(self, stream: Optional["torch.cuda.Stream"]): + self.stream = stream + self.idx = _get_device_index(None, True) + if not torch.jit.is_scripting(): + if self.idx is None: + self.idx = -1 + + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + ) + + def __enter__(self): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None or CUDA device not available + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.cuda.current_stream(None) + + # If the stream is not on the current device, then + # set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device) + torch.cuda.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no CUDA device available, return + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device + # and destination device + if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] + torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type] + torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type] + + +def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: + In eager mode stream is of type Stream class while in JIT it is + an object of the custom class ``torch.classes.cuda.Stream``. + """ + return StreamContext(stream) + + +def _set_stream_by_id(stream_id, device_index, device_type): + r"""set stream specified by the stream id, device index and + device type + + Args: stream_id (int): stream id in stream pool + device_index (int): device index in topo + device_type (int): enum device type + """ + torch._C._cuda_setStream( + stream_id=stream_id, + device_index=device_index, + device_type=device_type, + ) + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. + + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + _set_stream_by_id( + stream_id=stream.stream_id, + device_index=stream.device_index, + device_type=stream.device_type, + ) + + +def _parse_visible_devices() -> Union[list[int], list[str]]: + r"""Parse CUDA_VISIBLE_DEVICES environment variable.""" + var = os.getenv("CUDA_VISIBLE_DEVICES") + + if torch.version.hip: + hip_devices = os.getenv("HIP_VISIBLE_DEVICES") + rocr_devices = os.getenv("ROCR_VISIBLE_DEVICES") + + # You must take care if both HIP and ROCR env vars are set as they have + # different meanings. Both env vars accept either a list of ints or a + # list of UUIDs. The ROCR env var is processed first which then reduces + # the number of GPUs that HIP can select from. 
+ if rocr_devices is not None: + rocr_count = len(rocr_devices.split(",")) + if hip_devices is not None: + # sanity check if both env vars are set + if len(hip_devices.split(",")) > rocr_count: + raise RuntimeError( + "HIP_VISIBLE_DEVICES contains more devices than ROCR_VISIBLE_DEVICES" + ) + # HIP_VISIBLE_DEVICES is preferred over ROCR_VISIBLE_DEVICES + var = hip_devices + else: + return list(range(rocr_count)) + elif hip_devices is not None: + var = hip_devices + + if var is None: + return list(range(64)) + + def _strtoul(s: str) -> int: + """Return -1 or positive integer sequence string starts with.""" + if not s: + return -1 + for idx, c in enumerate(s): + if not (c.isdigit() or (idx == 0 and c in "+-")): + break + if idx + 1 == len(s): + idx += 1 + return int(s[:idx]) if idx > 0 else -1 + + def parse_list_with_prefix(lst: str, prefix: str) -> list[str]: + rcs: list[str] = [] + for elem in lst.split(","): + # Repeated id results in empty set + if elem in rcs: + return cast(list[str], []) + # Anything other but prefix is ignored + if not elem.startswith(prefix): + break + rcs.append(elem) + return rcs + + if var.startswith("GPU-"): + return parse_list_with_prefix(var, "GPU-") + if var.startswith("MIG-"): + return parse_list_with_prefix(var, "MIG-") + # CUDA_VISIBLE_DEVICES uses something like strtoul + # which makes `1gpu2,2ampere` is equivalent to `1,2` + rc: list[int] = [] + for elem in var.split(","): + x = _strtoul(elem.strip()) + # Repeated ordinal results in empty set + if x in rc: + return cast(list[int], []) + # Negative value aborts the sequence + if x < 0: + break + rc.append(x) + return rc + + +def _raw_device_count_amdsmi() -> int: + if not _HAS_PYNVML: # If amdsmi is not available + return -1 + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") + return -1 + socket_handles = amdsmi.amdsmi_get_processor_handles() + return len(socket_handles) + + +def _raw_device_count_nvml() -> int: + r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" + from ctypes import byref, c_int, CDLL + + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML") + return -1 + dev_count = c_int(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + if rc != 0: + warnings.warn("Can't get nvml device count") + return -1 + del nvml_h + return dev_count.value + + +def _raw_device_uuid_amdsmi() -> Optional[list[str]]: + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + + if not _HAS_PYNVML: # If amdsmi is not available + return None + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException: + warnings.warn("Can't initialize amdsmi") + return None + try: + socket_handles = amdsmi.amdsmi_get_processor_handles() + dev_count = len(socket_handles) + except amdsmi.AmdSmiException: + warnings.warn("Can't get amdsmi device count") + return None + uuids: list[str] = [] + for idx in range(dev_count): + try: + handler = amdsmi.amdsmi_get_processor_handles()[idx] + except amdsmi.AmdSmiException: + warnings.warn("Cannot get amd device handler") + return None + try: + uuid = amdsmi.amdsmi_get_gpu_asic_info(handler)["asic_serial"][ + 2: + ] # Removes 0x prefix from serial + except amdsmi.AmdSmiException: + warnings.warn("Cannot get uuid for amd device") + return None + uuids.append( + str(uuid).lower() + ) # Lower-case to match expected HIP_VISIBLE_DEVICES uuid input + return uuids + + +def 
_raw_device_uuid_nvml() -> Optional[list[str]]: + r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML") + return None + dev_count = c_int(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + if rc != 0: + warnings.warn("Can't get nvml device count") + return None + uuids: list[str] = [] + for idx in range(dev_count.value): + dev_id = c_void_p() + rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) + if rc != 0: + warnings.warn("Can't get device handle") + return None + buf_len = 96 + buf = create_string_buffer(buf_len) + rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) + if rc != 0: + warnings.warn("Can't get device UUID") + return None + uuids.append(buf.raw.decode("ascii").strip("\0")) + del nvml_h + return uuids + + +def _transform_uuid_to_ordinals(candidates: list[str], uuids: list[str]) -> list[int]: + r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs.""" + + def uuid_to_ordinal(candidate: str, uuids: list[str]) -> int: + best_match = -1 + for idx, uuid in enumerate(uuids): + if not uuid.startswith(candidate): + continue + # Ambiguous candidate + if best_match != -1: + return -1 + best_match = idx + return best_match + + rc: list[int] = [] + for candidate in candidates: + if torch.version.hip: + candidate = candidate.replace( + "GPU-", "", 1 + ) # Remove GPU-prefix to match amdsmi asic serial + idx = uuid_to_ordinal(candidate, uuids) + # First invalid ordinal stops parsing + if idx < 0: + break + # Duplicates result in empty set + if idx in rc: + return cast(list[int], []) + rc.append(idx) + return rc + + +def _device_count_amdsmi() -> int: + visible_devices = _parse_visible_devices() + if not visible_devices: + return 0 + try: + if type(visible_devices[0]) is str: + uuids = _raw_device_uuid_amdsmi() + if uuids is None: + return -1 + # Create string version of visible devices to avoid mypy warnings + visible_device_str = cast(list[str], visible_devices) + visible_devices = _transform_uuid_to_ordinals(visible_device_str, uuids) + else: + raw_cnt = _raw_device_count_amdsmi() + if raw_cnt <= 0: + return raw_cnt + # Trim the list up to a maximum available device + for idx, val in enumerate(visible_devices): + if cast(int, val) >= raw_cnt: + return idx + except OSError: + return -1 + except AttributeError: + return -1 + return len(visible_devices) + + +def _device_count_nvml() -> int: + r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. + + Negative value is returned if NVML discovery or initialization has failed. 
+ """ + visible_devices = _parse_visible_devices() + if not visible_devices: + return 0 + try: + if type(visible_devices[0]) is str: + # Skip MIG parsing + if visible_devices[0].startswith("MIG-"): + return -1 + uuids = _raw_device_uuid_nvml() + if uuids is None: + return -1 + visible_devices = _transform_uuid_to_ordinals( + cast(list[str], visible_devices), uuids + ) + else: + raw_cnt = _raw_device_count_nvml() + if raw_cnt <= 0: + return raw_cnt + # Trim the list up to a maximum available device + for idx, val in enumerate(visible_devices): + if cast(int, val) >= raw_cnt: + return idx + except OSError: + return -1 + except AttributeError: + return -1 + return len(visible_devices) + + +def _get_nvml_device_index(device: "Device") -> int: + r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.""" + idx = _get_device_index(device, optional=True) + visible_devices = _parse_visible_devices() + if type(visible_devices[0]) is str: + uuids = _raw_device_uuid_nvml() + if uuids is None: + raise RuntimeError("Can't get device UUIDs") + visible_devices = _transform_uuid_to_ordinals( + cast(list[str], visible_devices), uuids + ) + visible_devices = cast(list[int], visible_devices) + if idx < 0 or idx >= len(visible_devices): + raise RuntimeError( + f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})" + ) + return visible_devices[idx] + + +_cached_device_count: Optional[int] = None + + +def device_count() -> int: + r""" + Return the number of GPUs available. + + .. note:: This API will NOT posion fork if NVML discovery succeeds. + See :ref:`multiprocessing-poison-fork-note` for more details. + """ + global _cached_device_count + if not _is_compiled(): + return 0 + if _cached_device_count is not None: + return _cached_device_count + # bypass _device_count_nvml() if rocm (not supported) + nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml() + r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count + # NB: Do not cache the device count prior to CUDA initialization, because + # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES + # setting prior to CUDA initialization. + if _initialized: + _cached_device_count = r + return r + + +def get_arch_list() -> list[str]: + r"""Return list CUDA architectures this library was compiled for.""" + if not is_available(): + return [] + arch_flags = torch._C._cuda_getArchFlags() + if arch_flags is None: + return [] + return arch_flags.split() + + +def get_gencode_flags() -> str: + r"""Return NVCC gencode flags this library was compiled with.""" + arch_list = get_arch_list() + if len(arch_list) == 0: + return "" + arch_list_ = [arch.split("_") for arch in arch_list] + return " ".join( + [ + f"-gencode compute=compute_{arch},code={kind}_{arch}" + for (kind, arch) in arch_list_ + ] + ) + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + _lazy_init() + return torch._C._cuda_getDevice() + + +def synchronize(device: "Device" = None) -> None: + r"""Wait for all kernels in all streams on a CUDA device to complete. + + Args: + device (torch.device or int, optional): device for which to synchronize. + It uses the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + """ + _lazy_init() + with torch.cuda.device(device): + return torch._C._cuda_synchronize() + + +def ipc_collect(): + r"""Force collects GPU memory after it has been released by CUDA IPC. + + .. 
note:: + Checks if any sent CUDA tensors could be cleaned from the memory. Force + closes shared memory file used for reference counting if there is no + active counters. Useful when the producer process stopped actively sending + tensors and want to release unused memory. + """ + _lazy_init() + return torch._C._cuda_ipc_collect() + + +def current_stream(device: "Device" = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + """ + _lazy_init() + streamdata = torch._C._cuda_getCurrentStream( + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def default_stream(device: "Device" = None) -> Stream: + r"""Return the default :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the default :class:`Stream` for the current device, given by + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + """ + _lazy_init() + streamdata = torch._C._cuda_getDefaultStream( + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def get_stream_from_external(data_ptr: int, device: "Device" = None) -> Stream: + r"""Return a :class:`Stream` from an externally allocated CUDA stream. + + This function is used to wrap streams allocated in other libraries in order + to facilitate data exchange and multi-library interactions. + + .. note:: This function doesn't manage the stream life-cycle, it is the user + responsibility to keep the referenced stream alive while this returned + stream is being used. + + Args: + data_ptr(int): Integer representation of the `cudaStream_t` value that + is allocated externally. + device(torch.device or int, optional): the device where the stream + was originally allocated. If device is specified incorrectly, + subsequent launches using this stream may fail. + """ + _lazy_init() + streamdata = torch._C._cuda_getStreamFromExternal( + data_ptr, _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def current_blas_handle(): + r"""Return cublasHandle_t pointer to current cuBLAS handle""" + _lazy_init() + return torch._C._cuda_getCurrentBlasHandle() + + +def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: + r"""Set the debug mode for cuda synchronizing operations. + + Args: + debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations, + if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations. + + Warning: + This is an experimental feature, and not all synchronizing operations will trigger warning or error. In + particular, operations in torch.distributed and torch.sparse namespaces are not covered yet. 
+ """ + _lazy_init() + if isinstance(debug_mode, str): + if debug_mode == "default": + debug_mode = 0 + elif debug_mode == "warn": + debug_mode = 1 + elif debug_mode == "error": + debug_mode = 2 + else: + raise RuntimeError( + "invalid value of debug_mode, expected one of `default`, `warn`, `error`" + ) + + torch._C._cuda_set_sync_debug_mode(debug_mode) + + +def get_sync_debug_mode() -> int: + r"""Return current value of debug mode for cuda synchronizing operations.""" + _lazy_init() + return torch._C._cuda_get_sync_debug_mode() + + +def _get_pynvml_handler(device: "Device" = None): + if not _HAS_PYNVML: + raise ModuleNotFoundError( + "pynvml does not seem to be installed or it can't be imported." + ) from _PYNVML_ERR + from pynvml import NVMLError_DriverNotLoaded + + try: + pynvml.nvmlInit() + except NVMLError_DriverNotLoaded as e: + raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e + + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return handle + + +def _get_amdsmi_handler(device: "Device" = None): + if not _HAS_PYNVML: + raise ModuleNotFoundError( + "amdsmi does not seem to be installed or it can't be imported." + ) from _PYNVML_ERR + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + raise RuntimeError( + "amdsmi driver can't be loaded, requires >=ROCm5.6 installation" + ) from e + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return handle + + +def _get_amdsmi_device_index(device: "Device") -> int: + r"""Return the amdsmi index of the device, taking visible_devices into account.""" + idx = _get_device_index(device, optional=True) + visible_devices = _parse_visible_devices() + if type(visible_devices[0]) is str: + uuids = _raw_device_uuid_amdsmi() + if uuids is None: + raise RuntimeError("Can't get device UUIDs") + visible_devices_str = cast( + list[str], visible_devices + ) # Create str variable for mypy + visible_devices = _transform_uuid_to_ordinals(visible_devices_str, uuids) + idx_map = dict(enumerate(cast(list[int], visible_devices))) + if idx not in idx_map: + raise RuntimeError( + f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" + ) + return idx_map[idx] + + +def _get_amdsmi_device_memory_used(device: "Device" = None) -> int: + handle = _get_amdsmi_handler(device) + # amdsmi_get_gpu_vram_usage returns mem usage in megabytes + mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] + mem_bytes = mem_mega_bytes * 1024 * 1024 + return mem_bytes + + +def _get_amdsmi_memory_usage(device: "Device" = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"] + + +def _get_amdsmi_utilization(device: "Device" = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] + + +def _get_amdsmi_temperature(device: "Device" = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_temp_metric( + handle, + amdsmi.AmdSmiTemperatureType.JUNCTION, + amdsmi.AmdSmiTemperatureMetric.CURRENT, + ) + + +def _get_amdsmi_power_draw(device: "Device" = None) -> int: + handle = _get_amdsmi_handler(device) + socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] + if socket_power != "N/A": + return socket_power + else: + socket_power = amdsmi.amdsmi_get_power_info(handle)["current_socket_power"] + if socket_power != "N/A": + return socket_power + else: + return 0 + + 
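+# --- Illustrative usage sketch (editor's addition, not upstream code) --------
+# The amdsmi helpers in this section back the ROCm code paths of the public
+# query APIs defined further below (``device_memory_used``, ``memory_usage``,
+# ``utilization``, ``temperature``, ``power_draw``, ``clock_rate``); on CUDA
+# builds the same wrappers go through pynvml instead. A minimal example using
+# only the public wrappers:
+#
+#     import torch
+#     if torch.cuda.is_available():
+#         idx = torch.cuda.current_device()
+#         print(torch.cuda.utilization(idx))   # percent of time a kernel ran
+#         print(torch.cuda.temperature(idx))   # degrees Celsius
+#         print(torch.cuda.power_draw(idx))    # see power_draw docstring for units
+# ------------------------------------------------------------------------------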
+def _get_amdsmi_clock_rate(device: "Device" = None) -> int: + handle = _get_amdsmi_handler(device) + clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX) + if "cur_clk" in clock_info: # ROCm 6.2 deprecation + clock_rate = clock_info["cur_clk"] + else: + clock_rate = clock_info["clk"] + if clock_rate != "N/A": + return clock_rate + else: + return 0 + + +def device_memory_used(device: "Device" = None) -> int: + r"""Return used global (device) memory in bytes as given by `nvidia-smi` or `amd-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + """ + if not torch.version.hip: + handle = _get_pynvml_handler() + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetMemoryInfo(handle).used + else: + return _get_amdsmi_device_memory_used(device) + + +def memory_usage(device: "Device" = None) -> int: + r"""Return the percent of time over the past sample period during which global (device) + memory was being read or written as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler() + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).memory + else: + return _get_amdsmi_memory_usage(device) + + +def utilization(device: "Device" = None) -> int: + r"""Return the percent of time over the past sample period during which one or + more kernels was executing on the GPU as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + else: + return _get_amdsmi_utilization(device) + + +def temperature(device: "Device" = None) -> int: + r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades). + + The average temperature is computed based on past sample period as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + # 0 refers to the temperature sensor for the GPU die. 
+ return pynvml.nvmlDeviceGetTemperature(handle, 0) + else: + return _get_amdsmi_temperature(device) + + +def power_draw(device: "Device" = None) -> int: + r"""Return the average power draw of the GPU sensor in mW (MilliWatts) + over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetPowerUsage(handle) + else: + return _get_amdsmi_power_draw(device) + + +def clock_rate(device: "Device" = None) -> int: + r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetClockInfo(handle, 1) + else: + return _get_amdsmi_clock_rate(device) + + +def _get_device(device: Union[int, str, torch.device]) -> torch.device: + r"""Return the torch.device type object from the passed in device. + + Args: + device (torch.device or int): selected device. + """ + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + return device + + +def _get_generator(device: torch.device) -> torch._C.Generator: + r"""Return the CUDA Generator object for the given device. + + Args: + device (torch.device): selected device. + """ + idx = device.index + if idx is None: + idx = current_device() + return torch.cuda.default_generators[idx] + + +def _set_rng_state_offset( + offset: int, device: Union[int, str, torch.device] = "cuda" +) -> None: + r"""Set the random number generator state offset of the specified GPU. + + Args: + offset (int): The desired offset + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). + """ + final_device = _get_device(device) + + def cb(): + default_generator = _get_generator(final_device) + default_generator.set_offset(offset) + + _lazy_call(cb) + + +def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int: + r"""Return the random number generator state offset of the specified GPU. + + Args: + device (torch.device or int, optional): The device to return the RNG state offset of. + Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). + + .. warning:: + This function eagerly initializes CUDA. 
+ """ + _lazy_init() + final_device = _get_device(device) + default_generator = _get_generator(final_device) + return default_generator.get_offset() + + +from .memory import * # noqa: F403 +from .random import * # noqa: F403 + + +################################################################################ +# Define Storage and Tensor classes +################################################################################ + + +@staticmethod # type: ignore[misc] +def _lazy_new(cls, *args, **kwargs): + _lazy_init() + # We may need to call lazy init again if we are a forked child + # del _CudaBase.__new__ + return super(_CudaBase, cls).__new__(cls, *args, **kwargs) + + +class _CudaBase: + is_cuda = True + is_sparse = False + + def type(self, *args, **kwargs): + # We could use a Protocol here to tell mypy that self has `get_device` method + # but it is only available in the typing module on Python >= 3.8 + # or on typing_extensions module on Python >= 3.6 + with device(self.get_device()): # type: ignore[attr-defined] + return super().type(*args, **kwargs) # type: ignore[misc] + + __new__ = _lazy_new + + +from torch.storage import _LegacyStorage, _warn_typed_storage_removal + + +class _CudaLegacyStorage(_LegacyStorage): + @classmethod + def from_buffer(cls, *args, **kwargs): + _warn_typed_storage_removal() + raise RuntimeError("from_buffer: Not available for CUDA storage") + + @classmethod + def _new_with_weak_ptr(cls, *args, **kwargs): + raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage") + + @classmethod + def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None): + raise RuntimeError("_new_shared_filename: Not available for CUDA storage") + + +class ByteStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.uint8 + + +class DoubleStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.double + + +class FloatStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.float + + +class HalfStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.half + + +class LongStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.long + + +class IntStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.int + + +class ShortStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.short + + +class CharStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.int8 + + +class BoolStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.bool + + +class BFloat16Storage(_CudaLegacyStorage): + @classproperty + def dtype(self): + 
_warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.bfloat16 + + +class ComplexDoubleStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.cdouble + + +class ComplexFloatStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.cfloat + + +del _LegacyStorage +del _CudaLegacyStorage + +torch._storage_classes.add(DoubleStorage) +torch._storage_classes.add(FloatStorage) +torch._storage_classes.add(LongStorage) +torch._storage_classes.add(IntStorage) +torch._storage_classes.add(ShortStorage) +torch._storage_classes.add(CharStorage) +torch._storage_classes.add(ByteStorage) +torch._storage_classes.add(HalfStorage) +torch._storage_classes.add(BoolStorage) +torch._storage_classes.add(BFloat16Storage) +torch._storage_classes.add(ComplexDoubleStorage) +torch._storage_classes.add(ComplexFloatStorage) + + +class _WrappedTritonKernel: + """Just a simple wrapper to store some metadata for testing purposes.""" + + def __init__(self, kernel): + self.kernel = kernel + self.kernel_invoked = False + + def __call__(self, *args, **kwargs): + res = self.kernel(*args, **kwargs) + self.kernel_invoked = True + return res + + +def _register_triton_kernels(): + if torch._running_with_deploy(): + return + + @_WrappedTritonKernel + def kernel_impl(*args, **kwargs): + from torch.sparse._triton_ops import bsr_dense_mm + + return bsr_dense_mm(*args, skip_checks=True, **kwargs) + + @_WrappedTritonKernel + def addmm_kernel_impl(*args, **kwargs): + from torch.sparse._triton_ops import bsr_dense_addmm + + return bsr_dense_addmm(*args, skip_checks=True, **kwargs) + + has_triton = importlib.util.find_spec("triton") is not None + if has_triton: + torch._TritonLibrary.registerOp( + "_triton_bsr_dense_mm_out", + "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)", + kernel_impl, + "SparseCsrCUDA", + ) + + torch._TritonLibrary.registerOp( + "_triton_bsr_dense_addmm_out", + ( + "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense," + " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)" + ), + addmm_kernel_impl, + "SparseCsrCUDA", + ) + + +_lazy_call(_register_triton_kernels) + + +def _compile_kernel( + kernel_source: str, + kernel_name: str, + compute_capability: Optional[str] = None, + header_code: str = "", + cuda_include_dirs: Optional[list] = None, + nvcc_options: Optional[list] = None, +): + """ + Compiles a CUDA kernel using NVRTC and returns a callable function. + + This function is a wrapper for NVRTC that enables runtime compilation of CUDA kernels. + Note that this returns a raw CUDA kernel that operates on raw memory pointers. + To use this kernel as a proper PyTorch operator, you should wrap it following the guide at: + pytorch.org/tutorials/advanced/python_custom_ops.html + + Args: + kernel_source (str): The CUDA kernel source code as a string + kernel_name (str): The name of the kernel function to compile + compute_capability (str, optional): The compute capability to target (e.g., "86"). + If None, will detect from current device. 
+ header_code (str, optional): Additional header code to prepend to the kernel source + cuda_include_dirs (list, optional): List of directories containing CUDA headers + nvcc_options (list, optional): Additional options to pass to NVRTC + + Returns: + callable: A Python function that can be called with PyTorch tensor arguments to execute the kernel + + Example: + >>> # xdoctest: +SKIP + >>> kernel_code = ''' + extern "C" + __global__ void add_tensors(const float* a, const float* b, float* c, int n) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < n) + c[i] = a[i] + b[i]; + } + ''' + >>> add_kernel = torch.cuda.compile_kernel(kernel_code, "add_tensors") + >>> a = torch.randn(1024, device="cuda") + >>> b = torch.randn(1024, device="cuda") + >>> c = torch.empty_like(a) + >>> add_kernel(grid=(4,1,1), block=(256,1,1), args=[a, b, c, a.numel()]) + """ + import ctypes + + from torch.cuda._utils import _cuda_load_module, _nvrtc_compile + + # Compile the kernel to PTX + ptx = _nvrtc_compile( + kernel_source, + kernel_name, + compute_capability, + header_code, + cuda_include_dirs, + nvcc_options, + ) + + # Load the module and get the kernel + result = _cuda_load_module(ptx, [kernel_name]) + + if isinstance(result, dict): + return result[kernel_name] + else: + # This branch shouldn't be executed if kernel_names is provided, + # but MyPy needs this to understand type narrowing + return getattr(result, kernel_name) + + +from . import amp, jiterator, nvtx, profiler, sparse, tunable + + +__all__ = [ + # Typed storage and tensors + "BFloat16Storage", + "BFloat16Tensor", + "BoolStorage", + "BoolTensor", + "ByteStorage", + "ByteTensor", + "CharStorage", + "CharTensor", + "ComplexDoubleStorage", + "ComplexFloatStorage", + "DoubleStorage", + "DoubleTensor", + "FloatStorage", + "FloatTensor", + "HalfStorage", + "HalfTensor", + "IntStorage", + "IntTensor", + "LongStorage", + "LongTensor", + "ShortStorage", + "ShortTensor", + "CUDAGraph", + "CudaError", + "DeferredCudaCallError", + "Event", + "ExternalStream", + "Stream", + "StreamContext", + "amp", + "caching_allocator_alloc", + "caching_allocator_delete", + "caching_allocator_enable", + "can_device_access_peer", + "check_error", + "cudaStatus", + "cudart", + "current_blas_handle", + "current_device", + "current_stream", + "default_generators", + "default_stream", + "device", + "device_count", + "device_memory_used", + "device_of", + "empty_cache", + "get_allocator_backend", + "CUDAPluggableAllocator", + "change_current_allocator", + "get_arch_list", + "get_device_capability", + "get_device_name", + "get_device_properties", + "get_gencode_flags", + "get_per_process_memory_fraction", + "get_rng_state", + "get_rng_state_all", + "get_stream_from_external", + "get_sync_debug_mode", + "graph", + "graph_pool_handle", + "graphs", + "has_half", + "has_magma", + "host_memory_stats", + "host_memory_stats_as_nested_dict", + "init", + "initial_seed", + "ipc_collect", + "is_available", + "is_bf16_supported", + "is_current_stream_capturing", + "is_initialized", + "is_tf32_supported", + "jiterator", + "list_gpu_processes", + "make_graphed_callables", + "manual_seed", + "manual_seed_all", + "max_memory_allocated", + "max_memory_cached", + "max_memory_reserved", + "mem_get_info", + "memory", + "memory_allocated", + "memory_cached", + "memory_reserved", + "memory_snapshot", + "memory_stats", + "memory_stats_as_nested_dict", + "memory_summary", + "memory_usage", + "MemPool", + "use_mem_pool", + "temperature", + "power_draw", + "clock_rate", + "nccl", + "nvtx", + 
"profiler", + "random", + "reset_accumulated_host_memory_stats", + "reset_accumulated_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "reset_peak_host_memory_stats", + "reset_peak_memory_stats", + "seed", + "seed_all", + "set_device", + "set_per_process_memory_fraction", + "set_rng_state", + "set_rng_state_all", + "set_stream", + "set_sync_debug_mode", + "sparse", + "stream", + "streams", + "synchronize", + "tunable", + "utilization", +] diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c88743953901227524a1d5f9c184a5a30f65c761 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/_gpu_trace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_gpu_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb454af229cef6ccf7c7e0d0abb970d5d4ecbf95 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_gpu_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adeb6202c4bdb1fd8130cdceefc1c0c7e7397a9c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72385fd6b0659bd010098b339f0780c57b5c9093 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d3991391780465e127af84da42880a04de41058 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86eb4ca6b7b2d8ed8df743bf940474886655caa3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/comm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/comm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2952fa030d5769d8242cfa2593cd2e657e618bdb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/comm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/error.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/error.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..528c400d8c641771e9088335600c4ff35bed24a0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/error.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/cuda/__pycache__/gds.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/gds.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21824f4c4688300f8dd2f3992b8e5cc4aaad4edd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/gds.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/graphs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/graphs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cd4b14912950274181b2a7861ef45c7e9d6cf65 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/graphs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/jiterator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/jiterator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ac73a5f90fadb76444e24670bc6e39ac17859fd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/jiterator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/memory.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/memory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69c7879a67e7262ba49849df34a1502032a0a90b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/memory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/nccl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/nccl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048f23f0c3fff397a9efe9d6cb95cde939ac08da Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/nccl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/nvtx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/nvtx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c00922015c7c71b35c3f6dfb8dca7e400d59bbf4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/nvtx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103ff3f5a40d8275f8a0ad4ab4ff4526cf63fdcf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/random.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8ca837a27ee262ebdfa8226cbcd0adf86f975cd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/random.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/sparse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baa81e6bd01898514a63cb4ee4f6c5fcda670099 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/streams.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/streams.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2312d5676d5c185f8c2f9a069a5128ef2dda37e6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/streams.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/__pycache__/tunable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/__pycache__/tunable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1994927c7993b06f98d398f4793c6e8a49f66c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/__pycache__/tunable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/_gpu_trace.py b/phivenv/Lib/site-packages/torch/cuda/_gpu_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0a46bd569918a3b24af576db9d5447b453a2d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/_gpu_trace.py @@ -0,0 +1,73 @@ +from typing import Callable + +from torch._utils import CallbackRegistry + + +EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA event creation" +) +EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA event deletion" +) +EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( + "CUDA event record" +) +EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry("CUDA event wait") +MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA memory allocation" +) +MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA memory deallocation" +) +StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA stream creation" +) +DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry( + "CUDA device synchronization" +) +StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA stream synchronization" +) +EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA event synchronization" +) + + +def register_callback_for_event_creation(cb: Callable[[int], None]) -> None: + EventCreationCallbacks.add_callback(cb) + + +def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None: + EventDeletionCallbacks.add_callback(cb) + + +def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None: + EventRecordCallbacks.add_callback(cb) + + +def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None: + EventWaitCallbacks.add_callback(cb) + + +def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None: + MemoryAllocationCallbacks.add_callback(cb) + + +def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None: + MemoryDeallocationCallbacks.add_callback(cb) + + +def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None: + StreamCreationCallbacks.add_callback(cb) + + +def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None: + DeviceSynchronizationCallbacks.add_callback(cb) + + +def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None: + StreamSynchronizationCallbacks.add_callback(cb) + + +def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None: + EventSynchronizationCallbacks.add_callback(cb) diff --git a/phivenv/Lib/site-packages/torch/cuda/_memory_viz.py b/phivenv/Lib/site-packages/torch/cuda/_memory_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4b641ec92da813374b0480153d697c3e123a0b --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/cuda/_memory_viz.py @@ -0,0 +1,739 @@ +# mypy: allow-untyped-defs +import base64 +import io +import json +import operator +import os +import pickle +import subprocess +import sys +import warnings +from functools import lru_cache +from itertools import groupby +from typing import Any + + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + + +def _frame_fmt(f, full_filename=False): + i = f["line"] + fname = f["filename"] + if not full_filename: + fname = fname.split("/")[-1] + func = f["name"] + return f"{fname}:{i}:{func}" + + +@cache +def _frame_filter(name, filename): + omit_functions = [ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [ + _frame_fmt(f, full_filename) + for f in frames + if _frame_filter(f["name"], f["filename"]) + ] + + +def _block_extra_legacy(b): + if "history" in b: + frames = b["history"][0].get("frames", []) + real_size = b["history"][0]["real_size"] + else: + real_size = b.get("requested_size", b["size"]) + frames = [] + return frames, real_size + + +def _block_extra(b): + if "frames" not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b["frames"], b["requested_size"] + + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + if flamegraph_script is None: + flamegraph_script = f"/tmp/{os.getuid()}_flamegraph.pl" + if not os.path.exists(flamegraph_script): + import tempfile + import urllib.request + + print(f"Downloading flamegraph.pl to: {flamegraph_script}") + with tempfile.NamedTemporaryFile(mode="wb", suffix=".pl") as f: + urllib.request.urlretrieve( + "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl", + f.name, + ) + subprocess.check_call(["chmod", "+x", f.name]) + try: + os.rename(f.name, flamegraph_script) + except OSError: # noqa: B001,E722 + # Ok to skip, the file will be removed by tempfile + pass + args = [flamegraph_script, "--countname", "bytes"] + p = subprocess.Popen( + args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding="utf-8" + ) + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result + + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ";".join(_frames_fmt(frames, reverse=True)) + + for b in blocks: + if "history" not in b: + frames, accounted_for_size = _block_extra(b) + f.write( + f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n' + ) + else: + accounted_for_size = 0 + for h in b["history"]: + sz = h["real_size"] + accounted_for_size += sz + if "frames" in h: + frames = h["frames"] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b["size"] - accounted_for_size + if 
gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') + + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg["address"], seg["total_size"]) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f"only_before = {[a for a, _ in (before_segs - after_segs)]}") + print(f"only_after = {[a for a, _ in (after_segs - before_segs)]}") + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f"only_before;{_seg_info(seg)}", seg["blocks"]) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f"only_after;{_seg_info(seg)}", seg["blocks"]) + + return format_flamegraph(f.getvalue()) + + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + + +def calc_active(seg): + return sum(b["size"] for b in seg["blocks"] if b["state"] == "active_allocated") + + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = "" + if total != 0: + pct = (free_internal / total) * 100 + suffix = f" ({pct:.1f}% internal)" + return f"{Bytes(total)}{suffix}" + + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ + +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. 
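+
+    For example, ``print(segsum(torch.cuda.memory._snapshot()))`` prints one
+    row per segment that is at least one page large, followed by the overall
+    reserved/allocated/free totals and the legend below.
+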
+ Args: + data: snapshot dictionary created from _snapshot() + """ + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted( + data["segments"], key=lambda x: (x["total_size"], calc_active(x)) + ): + total_reserved += seg["total_size"] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg["blocks"]: + active = b["state"] == "active_allocated" + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b["size"] - allocated_size + else: + seg_free_external += b["size"] + + boffset += b["size"] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg["total_size"] - 1) // PAGE_SIZE + 1 + occupied = [" " for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = start_ + size + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord("a" if active else "A") + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != " ": + occupied[j] = "0123456789*"[int(frac[j] * 10)] + else: + occupied[j] = m + stream = "" if seg["stream"] == 0 else f', stream_{seg["stream"]}' + body = "".join(occupied) + assert ( + seg_free_external + seg_free_internal + seg_allocated == seg["total_size"] + ) + stream = f' stream_{seg["stream"]}' if seg["stream"] != 0 else "" + if seg["total_size"] >= PAGE_SIZE: + out.write( + f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f"{_report_free(seg_free_external, seg_free_internal)} free{stream}\n" + ) + out.write(f'segments: {len(data["segments"])}\n') + out.write(f"total_reserved: {Bytes(total_reserved)}\n") + out.write(f"total_allocated: {Bytes(total_allocated)}\n") + out.write(f"total_free: {_report_free(free_external, free_internal)}\n") + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals: list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names: list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data["segments"]): + saddr = seg["address"] + size = seg["allocated_size"] + if addr >= saddr and addr < saddr + size: + return f"seg_{i}", saddr + return None, None + + count = 0 + out.write(f"{len(entries)} entries\n") + + total_reserved = 0 + for seg in data["segments"]: + total_reserved += seg["total_size"] + + for count, e in enumerate(entries): + if e["action"] == "alloc": + addr, size = e["addr"], e["size"] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - seg_addr + out.write(f"{n} = {seg_name}[{offset}:{Bytes(size)}]\n") + allocation_addr_to_name[addr] = (n, size, 
count) + count += size + elif e["action"] == "free_requested": + addr, size = e["addr"], e["size"] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"del {name} # {Bytes(size)}\n") + elif e["action"] == "free_completed": + addr, size = e["addr"], e["size"] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"# free completed for {name} {Bytes(size)}\n") + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e["action"] == "segment_alloc": + addr, size = e["addr"], e["size"] + name = _name() + out.write(f"{name} = cudaMalloc({addr}, {Bytes(size)})\n") + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e["action"] == "segment_free": + addr, size = e["addr"], e["size"] + name = segment_addr_to_name.get(addr, addr) + out.write(f"cudaFree({name}) # {Bytes(size)}\n") + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e["action"] == "oom": + size = e["size"] + free = e["device_free"] + out.write( + f"raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n" + ) + else: + out.write(f"{e}\n") + out.write(f"TOTAL MEM: {Bytes(count)}") + + for i, d in enumerate(data["device_traces"]): + if d: + out.write(f"Device {i} ----------------\n") + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn( + "device argument is deprecated, plots now contain all device", + FutureWarning, + stacklevel=3, + ) + buffer = pickle.dumps(data) + buffer += b"\x00" * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode("utf-8") + + json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}]) + return _memory_viz_template.replace("$VIZ_KIND", repr(viz_kind)).replace( + "$SNAPSHOT", json_format + ) + + +def trace_plot(data, device=None, plot_segments=False): + """Generate a visualization over time of the memory usage recorded by the trace as an html file. + + Args: + data: Memory snapshot as generated from torch.cuda.memory._snapshot() + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations. + Defaults to False. + + Returns: + str: HTML of visualization + """ + return _format_viz( + data, + "Active Memory Timeline" + if not plot_segments + else "Active Cached Memory Timeline", + device, + ) + + +def _profile_to_snapshot(profile): + import torch + from torch._C._profiler import _EventType + from torch.profiler._memory_profiler import Action, TensorKey + + memory_profile = profile._memory_profile() + + allocation_stacks = {} + for event in memory_profile._op_tree.sorted_nodes: + if event.tag == _EventType.Allocation: + parent = event.parent + python_parents = [] + while parent: + if parent.tag in (_EventType.PyCall, _EventType.PyCCall): + python_parents.append(parent) + parent = parent.parent + key = TensorKey.from_allocation(event.extra_fields) + + # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor) + # key will be None. I should add some way to identify these, I just haven't yet. 
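+            # Only genuine allocations (positive alloc_size; free events carry a
+            # negative alloc_size) that map to a TensorKey get a recorded stack.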
+ if key and event.extra_fields.alloc_size > 0: + allocation_stacks[key] = python_parents + + device_count = torch.cuda.device_count() + snapshot: dict[str, list[Any]] = { + "device_traces": [[] for _ in range(device_count + 1)], + "segments": [ + { + "device": device, + "address": None, + "total_size": 0, + "stream": 0, + "blocks": [], + } + for device in range(device_count + 1) + ], + } + + def to_device(device): + if device.type == "cuda": + return device.index + else: + return device_count + + def allocate(size, tensor_key, version, during_trace=True): + device = to_device(tensor_key.device) + addr = tensor_key.storage.ptr + + seg = snapshot["segments"][device] # type: ignore[index] + if seg["address"] is None or seg["address"] > addr: + seg["address"] = addr + seg["total_size"] = max( + seg["total_size"], addr + size + ) # record max addr for now, we will make it the size later + category = memory_profile._categories.get(tensor_key, version) + category = category.name.lower() if category is not None else "unknown" + stack = allocation_stacks.get(tensor_key, ()) + stack = [{"filename": "none", "line": 0, "name": p.name} for p in stack] + r = { + "action": "alloc", + "addr": addr, + "size": size, + "stream": 0, + "frames": stack, + "category": category, + } + if during_trace: + snapshot["device_traces"][device].append(r) + return r + + def free(alloc, device): + for e in ("free_requested", "free_completed"): + snapshot["device_traces"][device].append( + { + "action": e, + "addr": alloc["addr"], + "size": alloc["size"], + "stream": 0, + "frames": alloc["frames"], + } + ) + + kv_to_elem = {} + + # create the device trace + for _time, action, (tensor_key, version), size in memory_profile.timeline: + if not isinstance(tensor_key, TensorKey): + continue + if action == Action.CREATE: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version) + elif action == Action.DESTROY: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + elif action == Action.INCREMENT_VERSION: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + kv_to_elem[(tensor_key, version + 1)] = allocate( + size, tensor_key, version + 1 + ) + elif action == Action.PREEXISTING: + kv_to_elem[(tensor_key, version)] = allocate( + size, tensor_key, version, during_trace=False + ) + + # create the final snapshot state + blocks_at_end = [ + (to_device(tensor_key.device), event["addr"], event["size"], event["frames"]) + for (tensor_key, version), event in kv_to_elem.items() + ] + for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)): + seg = snapshot["segments"][device] # type: ignore[index] + last_addr = seg["address"] + for _, addr, size, frames in blocks: + if last_addr < addr: + seg["blocks"].append({"size": addr - last_addr, "state": "inactive"}) + seg["blocks"].append( + { + "size": size, + "state": "active_allocated", + "requested_size": size, + "frames": frames, + } + ) + last_addr = addr + size + if last_addr < seg["total_size"]: + seg["blocks"].append( + {"size": seg["total_size"] - last_addr, "state": "inactive"} + ) + + snapshot["segments"] = [seg for seg in snapshot["segments"] if seg["blocks"]] # type: ignore[attr-defined] + for seg in snapshot["segments"]: # type: ignore[attr-defined, name-defined, no-redef] + seg["total_size"] -= seg["address"] + if not seg["blocks"]: + seg["blocks"].append({"size": seg["total_size"], "state": "inactive"}) + + return snapshot + + +def profile_plot(profile, device=None): + """Generate a visualization 
over time of the memory usage recorded by kineto memory profiling as an html file. + + Args: + profile: profile as generated by `torch.profiler.profile(profile_memory=True)` + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, "Active Memory Timeline", device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, "Allocator State History", device) + + +if __name__ == "__main__": + import os.path + + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find cuda/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = "torch.cuda.memory._snapshot()" + pickled = f"pickled memory statistics from {fn_name}" + parser = argparse.ArgumentParser( + description=f"Visualize memory dumps produced by {fn_name}" + ) + + subparsers = parser.add_subparsers(dest="action") + + def _output(p): + p.add_argument( + "-o", + "--output", + default="output.svg", + help="flamegraph svg (default: output.svg)", + ) + + description = "Prints overall allocation statistics and a visualization of how the allocators segments are currently filled." + stats_a = subparsers.add_parser("stats", description=description) + stats_a.add_argument("input", help=pickled) + + description = "Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style." + trace_a = subparsers.add_parser("trace", description=description) + trace_a.add_argument("input", help=pickled) + + description = "Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)" + segments_a = subparsers.add_parser("segments", description=description) + segments_a.add_argument("input", help=pickled) + _output(segments_a) + + description = ( + "Generate a flamegraph the program locations contributing to CUDA memory usage." + ) + memory_a = subparsers.add_parser("memory", description=description) + memory_a.add_argument("input", help=pickled) + _output(memory_a) + + description = ( + "Generate a flamegraph that shows segments (aka blocks) that have been added " + "or removed between two different memorys snapshots." 
+ ) + compare_a = subparsers.add_parser("compare", description=description) + compare_a.add_argument("before", help=pickled) + compare_a.add_argument("after", help=pickled) + _output(compare_a) + + plots = ( + ( + "trace_plot", + "Generate a visualization over time of the memory usage recorded by the trace as an html file.", + ), + ( + "segment_plot", + "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.", + ), + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument("input", help=pickled) + help = "visualize trace from this device (default: chooses the only device with trace info or errors)" + trace_plot_a.add_argument("-d", "--device", type=int, default=None, help=help) + help = "path to save the visualization(default: output.html)" + trace_plot_a.add_argument("-o", "--output", default="output.html", help=help) + if cmd == "trace_plot": + help = "visualize change to segments rather than individual allocations" + trace_plot_a.add_argument( + "-s", "--segments", action="store_true", help=help + ) + + args = parser.parse_args() + + def _read(name): + if name == "-": + f = sys.stdin.buffer + else: + f = open(name, "rb") + data = pickle.load(f) + if isinstance(data, list): # segments only... + data = {"segments": data, "traces": []} + return data + + def _write(name, data): + with open(name, "w") as f: + f.write(data) + + if args.action == "segments": + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == "memory": + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == "stats": + data = _read(args.input) + print(segsum(data)) + elif args.action == "trace": + data = _read(args.input) + print(trace(data)) + elif args.action == "compare": + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == "trace_plot": + data = _read(args.input) + _write( + args.output, + trace_plot(data, device=args.device, plot_segments=args.segments), + ) + elif args.action == "segment_plot": + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/phivenv/Lib/site-packages/torch/cuda/_pin_memory_utils.py b/phivenv/Lib/site-packages/torch/cuda/_pin_memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..faba7262af7de4c736660df78e9b0f5f9fc93b31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/_pin_memory_utils.py @@ -0,0 +1,24 @@ +import torch + + +def pin_memory(data_ptr: int, size: int) -> None: + cudart = torch.cuda.cudart() + succ = int( + cudart.cudaHostRegister( + data_ptr, + size, + 1, # lines up with 'cudaHostRegisterPortable' + ) + ) + + if succ != 0: + raise RuntimeError( + f"Registering memory failed with cudaError: {succ}." + " It's possible that this is an asynchronous error raised from a previous cuda operation." + " Consider launching with CUDA_LAUNCH_BLOCKING=1 to debug." 
+ ) + + +def unpin_memory(data_ptr: int) -> None: + succ = int(torch.cuda.cudart().cudaHostUnregister(data_ptr)) + assert succ == 0, f"Unpinning shared memory failed with error-code: {succ}" diff --git a/phivenv/Lib/site-packages/torch/cuda/_sanitizer.py b/phivenv/Lib/site-packages/torch/cuda/_sanitizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ed9ab89dbd6970d3ae41f0004de6ada22bacd9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/_sanitizer.py @@ -0,0 +1,662 @@ +# mypy: allow-untyped-defs +r""" +This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams. + +It stores information on accesses to tensors to determine if they are synchronized +or not. When enabled in a python program and a possible data race is detected, a +detailed warning will be printed and the program will exit. + +It can be enabled either by importing this module and calling +:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER`` +environment variable. +""" + +import enum +import functools +import inspect +import io +import logging +import re +import sys +import textwrap +import traceback +from collections.abc import Iterator +from dataclasses import dataclass, field +from typing import Any, Optional, TypeVar + +import torch +import torch.cuda._gpu_trace as gpu_trace +from torch.utils import _pytree as pytree +from torch.utils._python_dispatch import TorchDispatchMode + + +DEFAULT_STREAM_ID = 0 + +TK = TypeVar("TK") +TVa = TypeVar("TVa") +TVb = TypeVar("TVb") + +DataPtr = int +StreamId = int +EventId = int +SeqNum = int + +logger = logging.getLogger(__name__) + +# Note that this is only factories that take Tensor as input as they are +# the ones we care about. +FACTORY_FUNCTION_REGEX = re.compile("(new_.*|.*_like)") + + +class AccessType(enum.Enum): + READ = enum.auto() + WRITE = enum.auto() + + def __str__(self): + return "reading from" if self is AccessType.READ else "writing to" + + +@dataclass +class Access: + r"""Stores information about a single access to a tensor by a kernel. + + Args: + type: either AccessType.READ or AccessType.Write. + seq_num: the sequential number of the kernel performing the access. + stream: the stream id of the stream executing the kernel. + operator: the schema of the launched kernel, which lists the + arguments and return type. + aliases: the arguments in the schema this access corresponds to. + is_output: Whether the tensor was an output of the kernel. + stack_trace: the stack summary object captured during access. 
+ """ + + type: AccessType + seq_num: SeqNum + stream: StreamId + operator: str + aliases: list[str] + is_output: bool + stack_trace: traceback.StackSummary + + +class SynchronizationError(Exception): + """Base class for errors detected by CUDA Sanitizer.""" + + +class UnsynchronizedAccessError(SynchronizationError): + """Stores information about two unsynchronized accesses to one data pointer.""" + + def __init__( + self, + data_ptr: DataPtr, + allocation_stack_trace: Optional[traceback.StackSummary], + current_access: Access, + previous_access: Access, + ): + self.data_ptr = data_ptr + self.allocation_stack_trace = allocation_stack_trace + self.current_access = current_access + self.previous_access = previous_access + + def __str__(self): + def format_access(access: Access): + message.write(f"{access.operator}\n{access.type}") + if access.aliases: + message.write(" argument(s) " + ", ".join(access.aliases)) + if access.is_output: + message.write(", and to") + if access.is_output: + message.write(" the output") + message.write( + f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n" + ) + + with io.StringIO() as message: + message.write( + textwrap.dedent( + f"""\ + ============================ + CSAN detected a possible data race on tensor with data pointer {self.data_ptr} + Access by stream {self.current_access.stream} during kernel: + """ + ) + ) + format_access(self.current_access) + + message.write( + f"Previous access by stream {self.previous_access.stream} during kernel:\n" + ) + format_access(self.previous_access) + + if self.allocation_stack_trace: + message.write( + "Tensor was allocated with stack trace:\n" + f"{''.join(self.allocation_stack_trace.format())}" + ) + else: + message.write("Trace for tensor allocation not found.") + return message.getvalue() + + +class CUDASanitizerErrors(Exception): + """Wrapper class for errors reported by CUDA Sanitizer.""" + + def __init__(self, errors: list[SynchronizationError]): + self.errors = errors + + def __str__(self): + return f"detected {len(self.errors)} errors" + + +@dataclass +class TensorInfo: + r"""Stores information about a single tensor and recent accesses to it. + + Args: + allocation_stack_trace: the stack summary object captured during tensor + allocation. Can be ``None`` if the allocation wasn't caught by CSAN. + reads: list of read accesses to the tensor that were performed since + the last write. + write: the last write access to the tensor. + """ + + allocation_stack_trace: Optional[traceback.StackSummary] + reads: list[Access] = field(default_factory=list) + write: Optional[Access] = None + + +class _TensorsAccessed: + def __init__(self) -> None: + self.accesses: dict[DataPtr, TensorInfo] = {} + + def ensure_tensor_exists(self, data_ptr: DataPtr) -> None: + if data_ptr not in self.accesses: + logger.info( + "Found tensor with pointer: %s, but no matching tensor " + "allocation in the trace. Backfilling the trace now. " + "Perhaps the sanitizer was enabled after some torch operations?", + data_ptr, + ) + self.create_tensor(data_ptr, None) + + def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None: + if data_ptr in self.accesses: + logger.info( + "Found duplicate tensor allocation in the trace for tensor with " + "pointer: %s. Assuming the trace for tensor deallocation " + "wasn't caught and backfilling it now. 
" + "Perhaps the sanitizer was enabled after some torch operations?", + data_ptr, + ) + self.delete_tensor(data_ptr) + + def create_tensor( + self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary] + ) -> None: + self.accesses[data_ptr] = TensorInfo(stack_trace) + + def delete_tensor(self, data_ptr: DataPtr) -> None: + del self.accesses[data_ptr] + + def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool: + return True if self.accesses[data_ptr].reads else False + + def get_allocation_stack_trace( + self, data_ptr: DataPtr + ) -> Optional[traceback.StackSummary]: + return self.accesses[data_ptr].allocation_stack_trace + + def get_write(self, data_ptr: DataPtr) -> Optional[Access]: + return self.accesses[data_ptr].write + + def get_reads(self, data_ptr: DataPtr) -> list[Access]: + return self.accesses[data_ptr].reads + + def add_read(self, data_ptr: DataPtr, access: Access) -> None: + self.accesses[data_ptr].reads.append(access) + + def set_write(self, data_ptr: DataPtr, access: Access) -> None: + self.accesses[data_ptr].write = access + self.accesses[data_ptr].reads = [] + + +class StreamSynchronizations: + def __init__(self) -> None: + self.current_sync_states: dict[StreamId, dict[StreamId, SeqNum]] = {} + self.recorded_sync_states: dict[EventId, dict[StreamId, SeqNum]] = {} + self.host_sync_state: dict[StreamId, SeqNum] = {} + self.create_stream(DEFAULT_STREAM_ID) + + def _ensure_stream_exists(self, stream: StreamId) -> None: + if stream not in self.current_sync_states: + logger.info( + "Found Stream with id: %s, but no matching stream " + "creation in the trace. Backfilling the trace now. " + "Perhaps the sanitizer was enabled after some torch operations?", + stream, + ) + self.create_stream(stream) + + def _ensure_event_exists(self, event: EventId) -> None: + if event not in self.recorded_sync_states: + logger.info( + "Found Event with id: %s, but no matching event " + "creation in the trace. Backfilling the trace now. " + "Perhaps the sanitizer was enabled after some torch operations?", + event, + ) + self.create_event(event) + + def _ensure_event_does_not_exist(self, event: EventId) -> None: + if event in self.recorded_sync_states: + logger.info( + "Found duplicate event creation in the trace for event with " + "id: %s. Assuming the trace for event deletion wasn't caught " + "and backfilling it now. " + "Perhaps the sanitizer was enabled after some torch operations?", + event, + ) + self.delete_event(event) + + def create_stream(self, stream: StreamId) -> None: + if stream in self.current_sync_states: + logger.info( + "Found duplicate Stream creation in the trace for Stream with " + "id: %s. 
PyTorch Streams are only created once, so this " + "trace entry is ignored.", + stream, + ) + else: + self.host_sync_state[stream] = 0 + self.current_sync_states[stream] = self.host_sync_state.copy() + + def create_event(self, event: EventId) -> None: + self._ensure_event_does_not_exist(event) + self.recorded_sync_states[event] = {} + + def delete_event(self, event: EventId) -> None: + self._ensure_event_exists(event) + del self.recorded_sync_states[event] + + def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None: + self._ensure_stream_exists(stream) + self.current_sync_states[stream][stream] = seq_num + + def record_state(self, event: EventId, stream: StreamId) -> None: + self._ensure_event_exists(event) + self._ensure_stream_exists(stream) + self.recorded_sync_states[event] = self.current_sync_states[stream].copy() + + def _state_wait_for_other( + self, state: dict[StreamId, SeqNum], other: dict[StreamId, SeqNum] + ) -> None: + for stream, seq_num in other.items(): + state[stream] = max(state.get(stream, -1), seq_num) + + def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None: + self._ensure_stream_exists(stream) + self._ensure_event_exists(event) + self._state_wait_for_other( + self.current_sync_states[stream], self.recorded_sync_states[event] + ) + + def all_streams_wait_for_event(self, event: EventId) -> None: + self._ensure_event_exists(event) + for stream in self.current_sync_states.keys(): + self.stream_wait_for_event(stream, event) + + self._state_wait_for_other( + self.host_sync_state, self.recorded_sync_states[event] + ) + + def all_streams_wait_for_stream(self, stream: StreamId) -> None: + self._ensure_stream_exists(stream) + for state in self.current_sync_states.values(): + self._state_wait_for_other(state, self.current_sync_states[stream]) + + self._state_wait_for_other( + self.host_sync_state, self.current_sync_states[stream] + ) + + def sync_all_streams(self) -> None: + for stream, state in self.current_sync_states.items(): + self.host_sync_state[stream] = state[stream] + + for state in self.current_sync_states.values(): + self._state_wait_for_other(state, self.host_sync_state) + + def is_ordered_after( + self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId + ) -> bool: + self._ensure_stream_exists(current_stream) + self._ensure_stream_exists(other_stream) + return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1) + + +class EventHandler: + """Analyzes CSAN trace for synchronization errors. + + Stores information on each stream's synchronizations with other streams as well + as tensor accesses to determine whether a given kernel launch might cause a + data race. 
+ """ + + def __init__(self) -> None: + self.tensors_accessed = _TensorsAccessed() + self.syncs = StreamSynchronizations() + self.seq_num: SeqNum = 0 + + def _handle_kernel_launch( + self, + stream: StreamId, + read_only: set[DataPtr], + read_write: set[DataPtr], + outputs: set[DataPtr], + operator: str, + tensor_aliases: dict[int, list[str]], + ) -> list[SynchronizationError]: + def check_conflict( + data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access] + ) -> None: + if previous_access is None: + return + if not self.syncs.is_ordered_after( + current_access.stream, previous_access.seq_num, previous_access.stream + ): + error_list.append( + UnsynchronizedAccessError( + data_ptr, + self.tensors_accessed.get_allocation_stack_trace(data_ptr), + current_access, + previous_access, + ) + ) + + error_list: list[SynchronizationError] = [] + self.seq_num += 1 + self.syncs.update_seq_num(stream, self.seq_num) + stack_trace = traceback.StackSummary.extract( + traceback.walk_stack(inspect.currentframe()), lookup_lines=False + ) + # The stack trace generated in this way is in the inverse order, so it must be + # reversed. + stack_trace.reverse() + + for data_ptr in read_only: + self.tensors_accessed.ensure_tensor_exists(data_ptr) + current_access = Access( + AccessType.READ, + self.seq_num, + stream, + operator, + tensor_aliases[data_ptr], + data_ptr in outputs, + stack_trace, + ) + check_conflict( + data_ptr, current_access, self.tensors_accessed.get_write(data_ptr) + ) + self.tensors_accessed.add_read(data_ptr, current_access) + + for data_ptr in read_write: + self.tensors_accessed.ensure_tensor_exists(data_ptr) + current_access = Access( + AccessType.WRITE, + self.seq_num, + stream, + operator, + tensor_aliases[data_ptr], + data_ptr in outputs, + stack_trace, + ) + if self.tensors_accessed.were_there_reads_since_last_write(data_ptr): + for previous_access in self.tensors_accessed.get_reads(data_ptr): + check_conflict(data_ptr, current_access, previous_access) + else: + check_conflict( + data_ptr, current_access, self.tensors_accessed.get_write(data_ptr) + ) + self.tensors_accessed.set_write(data_ptr, current_access) + + return error_list + + def _handle_event_creation(self, event: EventId) -> None: + self.syncs.create_event(event) + + def _handle_event_deletion(self, event: EventId) -> None: + self.syncs.delete_event(event) + + def _handle_event_record(self, event: EventId, stream: StreamId) -> None: + self.syncs.record_state(event, stream) + + def _handle_event_wait(self, event: EventId, stream: StreamId) -> None: + self.syncs.stream_wait_for_event(stream, event) + + def _handle_memory_allocation(self, data_ptr: DataPtr) -> None: + self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr) + stack_trace = traceback.StackSummary.extract( + traceback.walk_stack(inspect.currentframe()), lookup_lines=False + ) + # The stack trace generated in this way is in the inverse order, so it must be + # reversed. 
+ stack_trace.reverse() + self.tensors_accessed.create_tensor( + data_ptr, + stack_trace, + ) + + def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None: + self.tensors_accessed.ensure_tensor_exists(data_ptr) + self.tensors_accessed.delete_tensor(data_ptr) + + def _handle_stream_creation(self, stream: StreamId) -> None: + self.syncs.create_stream(stream) + + def _handle_device_synchronization(self) -> None: + self.syncs.sync_all_streams() + + def _handle_stream_synchronization(self, stream: StreamId) -> None: + self.syncs.all_streams_wait_for_stream(stream) + + def _handle_event_synchronization(self, event: EventId) -> None: + self.syncs.all_streams_wait_for_event(event) + + +def zip_by_key(a: dict[TK, TVa], b: dict[TK, TVb]) -> Iterator[tuple[TK, TVa, TVb]]: + for arg, value in a.items(): + if arg in b: + yield arg, value, b[arg] + + +def zip_arguments( + schema: torch.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> Iterator[tuple[torch.Argument, Any]]: + schema_args = schema.arguments[: len(args)] + schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]} + + yield from zip(schema_args, args) + + for _, argument, value in zip_by_key(schema_kwargs, kwargs): + yield (argument, value) + + +class ArgumentHandler: + def __init__(self) -> None: + self.dataptrs_read: set[DataPtr] = set() + self.dataptrs_written: set[DataPtr] = set() + self.tensor_aliases: dict[DataPtr, list[str]] = {} + self.outputs: set[DataPtr] = set() + + def _handle_argument( + self, + value: Any, + is_write: bool, + metadata_only: bool, + name: Optional[str] = None, + is_output: bool = False, + ) -> None: + if isinstance(value, torch.Tensor) and value.is_cuda: + data_ptr = value.data_ptr() + if is_write: + self.dataptrs_written.add(data_ptr) + elif not metadata_only: + self.dataptrs_read.add(data_ptr) + + self.tensor_aliases.setdefault(data_ptr, []) + if name is not None: + self.tensor_aliases[data_ptr].append(name) + if is_output: + self.outputs.add(data_ptr) + + def parse_inputs( + self, + schema: torch.FunctionSchema, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + is_factory: bool, + ) -> None: + for argument, value in zip_arguments(schema, args, kwargs): + is_write = argument.alias_info is not None and argument.alias_info.is_write + # A change is metadata only if it is a view or a factory function that + # reads only metadata + metadata_only = is_factory or ( + argument.alias_info is not None and not argument.alias_info.is_write + ) + pytree.tree_map_( + functools.partial( + self._handle_argument, + is_write=is_write, + name=argument.name, + metadata_only=metadata_only, + ), + value, + ) + + def parse_outputs( + self, schema: torch.FunctionSchema, outputs: Any, *, is_factory: bool + ) -> None: + for res, value in zip(schema.returns, (outputs,)): + metadata_only = is_factory or ( + res.alias_info is not None and not res.alias_info.is_write + ) + pytree.tree_map_( + functools.partial( + self._handle_argument, + is_write=not metadata_only, + is_output=True, + metadata_only=metadata_only, + ), + value, + ) + + +class CUDASanitizerDispatchMode(TorchDispatchMode): + def __init__(self) -> None: + self.event_handler = EventHandler() + torch._C._activate_gpu_trace() + gpu_trace.register_callback_for_event_creation( + self.event_handler._handle_event_creation + ) + gpu_trace.register_callback_for_event_deletion( + self.event_handler._handle_event_deletion + ) + gpu_trace.register_callback_for_event_record( + self.event_handler._handle_event_record + ) + 
gpu_trace.register_callback_for_event_wait( + self.event_handler._handle_event_wait + ) + gpu_trace.register_callback_for_memory_allocation( + self.event_handler._handle_memory_allocation + ) + gpu_trace.register_callback_for_memory_deallocation( + self.event_handler._handle_memory_deallocation + ) + gpu_trace.register_callback_for_stream_creation( + self.event_handler._handle_stream_creation + ) + gpu_trace.register_callback_for_device_synchronization( + self.event_handler._handle_device_synchronization + ) + gpu_trace.register_callback_for_stream_synchronization( + self.event_handler._handle_stream_synchronization + ) + gpu_trace.register_callback_for_event_synchronization( + self.event_handler._handle_event_synchronization + ) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + is_factory = bool(FACTORY_FUNCTION_REGEX.match(func._schema.name)) + + argument_handler = ArgumentHandler() + argument_handler.parse_inputs(func._schema, args, kwargs, is_factory=is_factory) + + outputs = func(*args, **kwargs) + + argument_handler.parse_outputs(func._schema, outputs, is_factory=is_factory) + errors = self.event_handler._handle_kernel_launch( + torch.cuda.current_stream().cuda_stream, + argument_handler.dataptrs_read - argument_handler.dataptrs_written, + argument_handler.dataptrs_written, + argument_handler.outputs, + func._schema, + argument_handler.tensor_aliases, + ) + if errors: + for error in errors: + print(error, file=sys.stderr) + raise CUDASanitizerErrors(errors) + + return outputs + + +class CUDASanitizer: + """Manages the lifetime of a CUDASanitizer dispatch mode object. + + The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode + context manager in the enable function/destructor, respectively. This is to + explicitly set the lifetime of the dispatch mode object to that of the application. + This approach was deemed more elegant than using the atexit module. + """ + + def __init__(self) -> None: + self.dispatch = CUDASanitizerDispatchMode() + self.enabled = False + + def enable(self): + self.dispatch.__enter__() + self.enabled = True + + def disable(self): + self.dispatch.__exit__(None, None, None) + self.enabled = False + + def __del__(self): + # Since this object lifetime is linked to the `torch.cuda._sanitizer` python + # module, it often gets deleted as part of the overall `torch` module cleanup + # At that time, depending on CPython version, the torch.* module might be in + # different states of being already cleaned up. + # Similarly other imports might already have been cleaned up so `sys` might + # be already gone as well. + # Skip exiting the mode if it outlived the runtime. + if (sys is not None) and (not sys.is_finalizing()) and self.enabled: + self.disable() + + +def enable_cuda_sanitizer(): + """Enable CUDA Sanitizer. + + The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions + for synchronization errors. All data races found will be printed to the standard + error output along with stack traces of suspected causes. For best results, the + sanitizer should be enabled at the very beginning of the program. 
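+
+    A minimal usage sketch (illustrative)::
+
+        import torch.cuda._sanitizer as csan
+
+        csan.enable_cuda_sanitizer()
+        # ... later, an unsynchronized read/write of the same tensor on two
+        # different streams is reported on stderr and raises CUDASanitizerErrors.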
+ """ + cuda_sanitizer.enable() + + +cuda_sanitizer = CUDASanitizer() diff --git a/phivenv/Lib/site-packages/torch/cuda/_utils.py b/phivenv/Lib/site-packages/torch/cuda/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45ea572332444d7ddaf627d30581d08f364349ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/_utils.py @@ -0,0 +1,363 @@ +import ctypes +import sys +from typing import Any, Optional, Union + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +# Load CUDA driver and NVRTC +def _get_cuda_library() -> ctypes.CDLL: + if sys.platform == "win32": + return ctypes.CDLL("nvcuda.dll") + else: # Unix-based systems + return ctypes.CDLL("libcuda.so.1") + + +# Helper: check CUDA errors +def _check_cuda(result: int) -> None: + if result == 0: + return + err_str = ctypes.c_char_p() + libcuda = _get_cuda_library() # Get reference to CUDA library + libcuda.cuGetErrorString(result, ctypes.byref(err_str)) + error_message = ( + err_str.value.decode() if err_str.value is not None else "Unknown CUDA error" + ) + raise RuntimeError(f"CUDA error: {error_message}") + + +def _get_nvrtc_library() -> ctypes.CDLL: + # Since PyTorch already loads NVRTC, we can use the system library + # which should be compatible with PyTorch's version + if sys.platform == "win32": + return ctypes.CDLL("nvrtc64_120_0.dll") + else: + return ctypes.CDLL("libnvrtc.so") + + +def _nvrtc_compile( + kernel_source: str, + kernel_name: str, + compute_capability: Optional[str] = None, + header_code: str = "", + cuda_include_dirs: Optional[list] = None, + nvcc_options: Optional[list] = None, +) -> bytes: + """ + Compiles a CUDA kernel using NVRTC and returns the PTX code. + + Args: + kernel_source (str): The CUDA kernel source code as a string + kernel_name (str): The name of the kernel function to compile + compute_capability (str, None): The compute capability to target (e.g., "86"). + If None, will detect from current device. 
+ header_code (str, optional): Additional header code to prepend to the kernel source + cuda_include_dirs (list, None): List of directories containing CUDA headers + nvcc_options (list, None): Additional options to pass to NVRTC + + Returns: + str: The compiled PTX code + """ + # Ensure CUDA is initialized + import torch.cuda + + # Load NVRTC library + libnvrtc = _get_nvrtc_library() + + # NVRTC constants + NVRTC_SUCCESS = 0 + + # Helper: check NVRTC errors + def check_nvrtc(result: int) -> None: + if result != NVRTC_SUCCESS: + err_str = ctypes.c_char_p() + libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str)) + error_message = ( + err_str.value.decode() + if err_str.value is not None + else "Unknown CUDA error" + ) + raise RuntimeError(f"CUDA error: {error_message}") + + # Add 'extern "C"' if not already present to ensure C linkage + if not kernel_source.strip().startswith('extern "C"'): + kernel_source = f'extern "C" {kernel_source}' + + # Combine header code and kernel source + if header_code: + full_source = header_code + "\n" + kernel_source + else: + full_source = kernel_source + + # Convert source to bytes + source_bytes = full_source.encode("utf-8") + + # Get compute capability if not provided + if compute_capability is None: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + compute_capability = f"{props.major}{props.minor}" + + # Prepare compilation options + options = [] + options.append(f"--gpu-architecture=sm_{compute_capability}".encode()) + + # Add custom include directories + if cuda_include_dirs: + for directory in cuda_include_dirs: + options.append(f"-I{directory}".encode()) + + # Add custom NVCC options + if nvcc_options: + for option in nvcc_options: + options.append(option.encode("utf-8")) + + # TODO: Should we refactor flags into a common place? 
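+    # NVRTC does not accept every flag that nvcc does, so the shared flag list
+    # is reused below with the incompatible entries filtered out.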
+ from torch.utils.cpp_extension import COMMON_NVCC_FLAGS + + # Filter out flags not supported by NVRTC + nvrtc_compatible_flags = [ + flag for flag in COMMON_NVCC_FLAGS if flag != "--expt-relaxed-constexpr" + ] + options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags]) + + # Convert options to C array + num_options = len(options) + options_array = (ctypes.c_char_p * num_options)(*options) + + # Create program + prog = ctypes.c_void_p() + check_nvrtc( + libnvrtc.nvrtcCreateProgram( + ctypes.byref(prog), + source_bytes, + f"{kernel_name}.cu".encode(), + 0, + None, + None, + ) + ) + + # Compile program + res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array) + + # Handle compilation errors + if res != NVRTC_SUCCESS: + # Get log + log_size = ctypes.c_size_t() + libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size)) + log = ctypes.create_string_buffer(log_size.value) + libnvrtc.nvrtcGetProgramLog(prog, log) + raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}") + + # Get PTX + ptx_size = ctypes.c_size_t() + check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size))) + ptx = ctypes.create_string_buffer(ptx_size.value) + check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx)) + libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog)) + + return ptx.value + + +class _CudaModule: + def __init__(self, module: ctypes.c_void_p) -> None: + self._module = module + self._kernels: dict[str, _CudaKernel] = {} + + def __getattr__(self, name: str) -> "_CudaKernel": + if name in self._kernels: + return self._kernels[name] + + # Import the CUDA library inside the method + from torch.cuda._utils import _get_cuda_library + + libcuda = _get_cuda_library() + + func = ctypes.c_void_p() + try: + _check_cuda( + libcuda.cuModuleGetFunction( + ctypes.byref(func), self._module, name.encode("utf-8") + ) + ) + kernel = _CudaKernel(func, self._module) + self._kernels[name] = kernel + return kernel + + except RuntimeError as err: + raise AttributeError(f"No kernel named '{name}' in this module") from err + + +class _CudaKernel: + """ + Represents a compiled CUDA kernel that can be called with PyTorch tensors. + """ + + def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None: + self.func = func + self.module = module + + def __call__( + self, + grid: tuple[int, int, int] = (1, 1, 1), + block: tuple[int, int, int] = (1, 1, 1), + args: Optional[list] = None, + shared_mem: int = 0, + stream: Optional[Any] = None, + ) -> None: + """ + Call the compiled CUDA kernel + + Args: + grid (tuple): Grid dimensions (grid_x, grid_y, grid_z) + block (tuple): Block dimensions (block_x, block_y, block_z) + args (list): List of arguments to pass to the kernel. + PyTorch tensor arguments will be automatically converted to pointers. + shared_mem (int): Shared memory size in bytes + stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream. 
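+
+        Example (illustrative sketch; ``source`` holds a kernel named
+        ``add_one``, both hypothetical)::
+
+            ptx = _nvrtc_compile(source, "add_one")
+            mod = _cuda_load_module(ptx)
+            x = torch.zeros(1024, device="cuda")
+            mod.add_one(grid=(4, 1, 1), block=(256, 1, 1), args=[x, x.numel()])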
+ """ + import torch + + libcuda = torch.cuda._utils._get_cuda_library() + + if not args: + args = [] + + # Process arguments and convert tensors to pointers + processed_args: list[ctypes.c_void_p] = [] + c_args = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()): + raise ValueError( + "All tensor arguments must be CUDA tensors or pinned CPU tensors" + ) + # Get pointer to tensor data + ptr = ctypes.c_void_p(arg.data_ptr()) + processed_args.append(ptr) + c_args.append(ctypes.byref(ptr)) + elif isinstance(arg, int): + # Convert integers to C int + c_int = ctypes.c_int(arg) + # Store the C int for reference keeping, not in processed_args + c_args.append(ctypes.byref(c_int)) + # TODO: Python floats are actually doubles + elif isinstance(arg, float): + # Convert floats to C float + c_float = ctypes.c_float(arg) + # Store the C float for reference keeping, not in processed_args + c_args.append(ctypes.byref(c_float)) + else: + raise TypeError(f"Unsupported argument type: {type(arg)}") + + # Convert to array of void pointers + c_args_array = (ctypes.c_void_p * len(c_args))() + for i, arg in enumerate(c_args): + c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p) + + # Get the stream + if stream is None: + # Defer import to avoid circular imports + import torch.cuda + + stream = torch.cuda.current_stream() + + _check_cuda( + libcuda.cuLaunchKernel( + self.func, + grid[0], + grid[1], + grid[2], + block[0], + block[1], + block[2], + shared_mem, + stream._as_parameter_, + c_args_array, + None, + ) + ) + + +def _cuda_load_module( + ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None +) -> Union[_CudaModule, dict[str, "_CudaKernel"]]: + """ + Loads a CUDA module from PTX code and returns a module object that can access kernels. + + Args: + ptx (bytes or str): The PTX code to load + kernel_names (list, optional): List of kernel names to extract from the module. + If None, will return a module object with __getattr__. + + Returns: + object: If kernel_names is None, returns a module object with __getattr__ to access kernels. + If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects. + """ + # Ensure CUDA is initialized + import torch.cuda + + # Load CUDA driver library + libcuda = _get_cuda_library() + + # Convert PTX to bytes if it's a string + if isinstance(ptx, str): + ptx = ptx.encode("utf-8") + + # Load PTX module + module = ctypes.c_void_p() + # Get the current stream without directly importing torch.cuda at module level + stream = torch.cuda.current_stream() + with stream: + _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx)) + + if not kernel_names: + return _CudaModule(module) + + # Return specific kernels + kernels = {} + for name in kernel_names: + func = ctypes.c_void_p() + _check_cuda( + libcuda.cuModuleGetFunction( + ctypes.byref(func), module, name.encode("utf-8") + ) + ) + kernels[name] = _CudaKernel(func, module) + return kernels + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a CUDA device. Note that for a CUDA device without a specified index, + i.e., ``torch.device('cuda')``, this will return the current default CUDA + device if :attr:`optional` is ``True``. 
If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default CUDA + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["cuda", "cpu"]: + raise ValueError(f"Expected a cuda or cpu device, but got: {device}") + elif device.type != "cuda": + raise ValueError(f"Expected a cuda device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.cuda.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/__init__.py b/phivenv/Lib/site-packages/torch/cuda/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f35b9d4a29804a5efd041e5daa238878158d98 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/amp/__init__.py @@ -0,0 +1,12 @@ +from .autocast_mode import autocast, custom_bwd, custom_fwd +from .common import amp_definitely_not_available +from .grad_scaler import GradScaler + + +__all__ = [ + "amp_definitely_not_available", + "autocast", + "custom_bwd", + "custom_fwd", + "GradScaler", +] diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..830a5a497932be36dbd29e526d53711391c7f563 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c33dde4fec99a6f7649eec9fc42f727f425952fe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75df260d250df8e597d4972d0812641e0247fd16 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94aadc71d58dbf271b0954c366b5ef7f6300b6ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/autocast_mode.py b/phivenv/Lib/site-packages/torch/cuda/amp/autocast_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..d2fa2bc9e6d3a40974fd6d668b829d98d4d00eff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/amp/autocast_mode.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +import functools +from typing import Any +from typing_extensions import deprecated + +import torch + + +__all__ = ["autocast", "custom_fwd", "custom_bwd"] + + +class 
autocast(torch.amp.autocast_mode.autocast): + r"""See :class:`torch.autocast`. + + ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead. + """ + + @deprecated( + "`torch.cuda.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cuda', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.float16, + cache_enabled: bool = True, + ): + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cuda" + self.fast_dtype = dtype + return + super().__init__( + "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if torch._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if torch._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if torch._jit_internal.is_scripting(): + return func + return super().__call__(func) + + +# Preserved only for BC reasons +@deprecated( + "`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. " + "Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.", + category=FutureWarning, +) +def _cast(value, dtype): + return torch.amp.autocast_mode._cast(value, "cuda", dtype) + + +@deprecated( + "`torch.cuda.amp.custom_fwd(args...)` is deprecated. " + "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_fwd(fwd=None, *, cast_inputs=None): + """ + ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use + ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(torch.amp.custom_fwd, device_type="cuda")( + fwd=fwd, cast_inputs=cast_inputs + ) + + +@deprecated( + "`torch.cuda.amp.custom_bwd(args...)` is deprecated. " + "Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_bwd(bwd): + """ + ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use + ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead. 
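+
+    A minimal migration sketch (illustrative)::
+
+        class Scale(torch.autograd.Function):
+            @staticmethod
+            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+            def forward(ctx, x):
+                return x * 2
+
+            @staticmethod
+            @torch.amp.custom_bwd(device_type="cuda")
+            def backward(ctx, grad_output):
+                return grad_output * 2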
+ """ + return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd) diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/common.py b/phivenv/Lib/site-packages/torch/cuda/amp/common.py new file mode 100644 index 0000000000000000000000000000000000000000..1f98700a14b84e10e2962855c9548c99a7ff4c69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/amp/common.py @@ -0,0 +1,11 @@ +# mypy: allow-untyped-defs +from importlib.util import find_spec + +import torch + + +__all__ = ["amp_definitely_not_available"] + + +def amp_definitely_not_available(): + return not (torch.cuda.is_available() or find_spec("torch_xla")) diff --git a/phivenv/Lib/site-packages/torch/cuda/amp/grad_scaler.py b/phivenv/Lib/site-packages/torch/cuda/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..255d72f167c5a759c83c50da319133f271b52d60 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/amp/grad_scaler.py @@ -0,0 +1,38 @@ +from typing_extensions import deprecated + +import torch + +# We need to keep this unused import for BC reasons +from torch.amp.grad_scaler import OptState # noqa: F401 + + +__all__ = ["GradScaler"] + + +class GradScaler(torch.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. + ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead. + """ + + @deprecated( + "`torch.cuda.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cuda', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cuda", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) diff --git a/phivenv/Lib/site-packages/torch/cuda/comm.py b/phivenv/Lib/site-packages/torch/cuda/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..94b2e67b1d68ca7290c901151ee5decc4fd23aaa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/comm.py @@ -0,0 +1,19 @@ +# The functions here have been moved to torch.nn.parallel.comm +from torch.nn.parallel.comm import ( + broadcast, + broadcast_coalesced, + gather, + reduce_add, + reduce_add_coalesced, + scatter, +) + + +__all__ = [ + "broadcast", + "broadcast_coalesced", + "reduce_add", + "reduce_add_coalesced", + "scatter", + "gather", +] diff --git a/phivenv/Lib/site-packages/torch/cuda/error.py b/phivenv/Lib/site-packages/torch/cuda/error.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/cuda/gds.py b/phivenv/Lib/site-packages/torch/cuda/gds.py new file mode 100644 index 0000000000000000000000000000000000000000..38405a6c289c29d6c5116b10b8103cc22ff4fce4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/gds.py @@ -0,0 +1,166 @@ +import os +import sys +from typing import Callable, Optional + +import torch +from torch.types import Storage + + +__all__: list[str] = [ + "gds_register_buffer", + "gds_deregister_buffer", + "GdsFile", +] + + +def _dummy_fn(name: str) -> Callable: + def fn(*args, **kwargs): # type: ignore[no-untyped-def] + raise RuntimeError(f"torch._C.{name} is not supported on this platform") + + return fn + + +if not hasattr(torch._C, "_gds_register_buffer"): + assert not hasattr(torch._C, "_gds_deregister_buffer") + 
assert not hasattr(torch._C, "_gds_register_handle") + assert not hasattr(torch._C, "_gds_deregister_handle") + assert not hasattr(torch._C, "_gds_load_storage") + assert not hasattr(torch._C, "_gds_save_storage") + # Define functions + torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer") + torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer") + torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle") + torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle") + torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage") + torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage") + + +def gds_register_buffer(s: Storage) -> None: + """Registers a storage on a CUDA device as a cufile buffer. + + Example:: + + >>> # xdoctest: +SKIP("gds filesystem requirements") + >>> src = torch.randn(1024, device="cuda") + >>> s = src.untyped_storage() + >>> gds_register_buffer(s) + + Args: + s (Storage): Buffer to register. + """ + torch._C._gds_register_buffer(s) + + +def gds_deregister_buffer(s: Storage) -> None: + """Deregisters a previously registered storage on a CUDA device as a cufile buffer. + + Example:: + + >>> # xdoctest: +SKIP("gds filesystem requirements") + >>> src = torch.randn(1024, device="cuda") + >>> s = src.untyped_storage() + >>> gds_register_buffer(s) + >>> gds_deregister_buffer(s) + + Args: + s (Storage): Buffer to register. + """ + torch._C._gds_deregister_buffer(s) + + +class GdsFile: + r"""Wrapper around cuFile. + + cuFile is a file-like interface to the GPUDirect Storage (GDS) API. + + See the `cufile docs `_ + for more details. + + Args: + filename (str): Name of the file to open. + flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will + be added automatically. + + Example:: + + >>> # xdoctest: +SKIP("gds filesystem requirements") + >>> src1 = torch.randn(1024, device="cuda") + >>> src2 = torch.randn(2, 1024, device="cuda") + >>> file = torch.cuda.gds.GdsFile(f, os.O_CREAT | os.O_RDWR) + >>> file.save_storage(src1.untyped_storage(), offset=0) + >>> file.save_storage(src2.untyped_storage(), offset=src1.nbytes) + >>> dest1 = torch.empty(1024, device="cuda") + >>> dest2 = torch.empty(2, 1024, device="cuda") + >>> file.load_storage(dest1.untyped_storage(), offset=0) + >>> file.load_storage(dest2.untyped_storage(), offset=src1.nbytes) + >>> torch.equal(src1, dest1) + True + >>> torch.equal(src2, dest2) + True + + """ + + def __init__(self, filename: str, flags: int): + if sys.platform == "win32": + raise RuntimeError("GdsFile is not supported on this platform.") + self.filename = filename + self.flags = flags + self.fd = os.open(filename, flags | os.O_DIRECT) # type: ignore[attr-defined] + self.handle: Optional[int] = None + self.register_handle() + + def __del__(self) -> None: + if self.handle is not None: + self.deregister_handle() + os.close(self.fd) + + def register_handle(self) -> None: + """Registers file descriptor to cuFile Driver. + + This is a wrapper around ``cuFileHandleRegister``. + """ + assert ( + self.handle is None + ), "Cannot register a handle that is already registered." + self.handle = torch._C._gds_register_handle(self.fd) + + def deregister_handle(self) -> None: + """Deregisters file descriptor from cuFile Driver. + + This is a wrapper around ``cuFileHandleDeregister``. + """ + assert ( + self.handle is not None + ), "Cannot deregister a handle that is not registered." 
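+        # Wraps ``cuFileHandleDeregister``; the file descriptor itself is closed
+        # separately in ``__del__``.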
+ torch._C._gds_deregister_handle(self.handle) + self.handle = None + + def load_storage(self, storage: Storage, offset: int = 0) -> None: + """Loads data from the file into the storage. + + This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data + will be loaded from the file at ``offset`` into the storage. + + Args: + storage (Storage): Storage to load data into. + offset (int, optional): Offset into the file to start loading from. (Default: 0) + """ + assert ( + self.handle is not None + ), "Cannot load data from a file that is not registered." + torch._C._gds_load_storage(self.handle, storage, offset) + + def save_storage(self, storage: Storage, offset: int = 0) -> None: + """Saves data from the storage into the file. + + This is a wrapper around ``cuFileWrite``. All bytes of the storage + will be written to the file at ``offset``. + + Args: + storage (Storage): Storage to save data from. + offset (int, optional): Offset into the file to start saving to. (Default: 0) + """ + assert ( + self.handle is not None + ), "Cannot save data to a file that is not registered." + torch._C._gds_save_storage(self.handle, storage, offset) diff --git a/phivenv/Lib/site-packages/torch/cuda/graphs.py b/phivenv/Lib/site-packages/torch/cuda/graphs.py new file mode 100644 index 0000000000000000000000000000000000000000..d102798d4503b0142ec3611a131f1276b238ab19 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/graphs.py @@ -0,0 +1,526 @@ +# mypy: allow-untyped-defs +import gc +import typing + +import torch + +from .._utils import _dummy_type + + +if not hasattr(torch._C, "_CudaStreamBase"): + # Define dummy base classes + torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph") + torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle") + torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type( + "_cuda_isCurrentStreamCapturing" + ) + +from torch._C import ( # noqa: F401 + _cuda_isCurrentStreamCapturing, + _CUDAGraph, + _graph_pool_handle, +) + + +def is_current_stream_capturing(): + r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise. + + If a CUDA context does not exist on the current device, returns False without initializing the context. + """ + return _cuda_isCurrentStreamCapturing() + + +# Python shim helps Sphinx process docstrings more reliably. +def graph_pool_handle(): + r"""Return an opaque token representing the id of a graph memory pool. + + See :ref:`Graph memory management`. + + .. warning:: + This API is in beta and may change in future releases. + """ + return _graph_pool_handle() + + +# Python shim helps Sphinx process docstrings more reliably. +class CUDAGraph(torch._C._CUDAGraph): + r"""Wrapper around a CUDA graph. + + Arguments: + keep_graph (bool, optional): If ``keep_graph=False``, the + cudaGraphExec_t will be instantiated on GPU at the end of + ``capture_end`` and the underlying cudaGraph_t will be + destroyed. Users who want to query or otherwise modify the + underlying cudaGraph_t before instantiatiation can set + ``keep_graph=True`` and access it via ``raw_cuda_graph`` after + ``capture_end``. Note that the cudaGraphExec_t will not be + instantiated at the end of ``capture_end`` in this + case. Instead, it wil be instantiated via an explicit called + to ``instantiate`` or automatically on the first call to + ``replay`` if ``instantiate`` was not already called. 
Calling + ``instantiate`` manually before ``replay`` is recommended to + prevent increased latency on the first call to ``replay``. It + is allowed to modify the raw cudaGraph_t after first calling + ``instantiate``, but the user must call ``instantiate`` again + manually to make sure the instantiated graph has these + changes. Pytorch has no means of tracking these changes. + + .. warning:: + This API is in beta and may change in future releases. + + """ + + def __new__(cls, keep_graph=False): + return super().__new__(cls, keep_graph) + + def capture_begin(self, pool=None, capture_error_mode="global"): + r"""Begin capturing CUDA work on the current stream. + + Typically, you shouldn't call ``capture_begin`` yourself. + Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`, + which call ``capture_begin`` internally. + + Arguments: + pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting + unless you're familiar with `cudaStreamCaptureMode `_ + """ # noqa: B950 + super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) + + def capture_end(self): + r"""End CUDA graph capture on the current stream. + + After ``capture_end``, ``replay`` may be called on this instance. + + Typically, you shouldn't call ``capture_end`` yourself. + Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`, + which call ``capture_end`` internally. + """ + super().capture_end() + + def instantiate(self): + r"""Instantiate the CUDA graph. Will be called by + ``capture_end`` if ``keep_graph=False``, or by ``replay`` if + ``keep_graph=True`` and ``instantiate`` has not already been + explicitly called. Does not destroy the cudaGraph_t returned + by ``raw_cuda_graph``. + """ + super().instantiate() + + def replay(self): + r"""Replay the CUDA work captured by this graph.""" + super().replay() + + def reset(self): + r"""Delete the graph currently held by this instance.""" + super().reset() + + def pool(self): + r"""Return an opaque token representing the id of this graph's memory pool. + + This id can optionally be passed to another graph's ``capture_begin``, + which hints the other graph may share the same memory pool. + """ + return super().pool() + + def enable_debug_mode(self): + r"""Enable debugging mode for CUDAGraph.debug_dump.""" + return super().enable_debug_mode() + + def debug_dump(self, debug_path): + r""" + Arguments: + debug_path (required): Path to dump the graph to. + + Calls a debugging function to dump the graph if the debugging is + enabled via CUDAGraph.enable_debug_mode() + """ + return super().debug_dump(debug_path) + + def raw_cuda_graph(self): + r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True. 
+ + See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_ + """ # noqa: B950 + return super().raw_cuda_graph() + + +class graph: + r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay. + + See :ref:`CUDA Graphs ` for a general introduction, + detailed use, and constraints. + + Arguments: + cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture. + pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or + :meth:`other_Graph_instance.pool()`) hinting this graph's capture + may share memory from the specified pool. See :ref:`Graph memory management`. + stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context. + If not supplied, ``graph`` sets its own internal side stream as the current stream in the context. + capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting + unless you're familiar with `cudaStreamCaptureMode `_ + + .. note:: + For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture + used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture. + + .. warning:: + This API is in beta and may change in future releases. + + .. _cudaStreamCaptureMode: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + """ # noqa: B950 + + default_capture_stream: typing.Optional["torch.cuda.Stream"] = None + + def __init__( + self, + cuda_graph, + pool=None, + stream=None, + capture_error_mode: str = "global", + ): + # Lazy-init of default_capture_stream helps avoid circular-import errors. + # Not thread safe, but graphs already have the general (explicitly documented) + # restriction that only one capture may be underway at a time in the process. 
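+        # ``default_capture_stream`` is a class attribute shared by every ``graph``
+        # context in this process; it is created only on first use.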
+ if self.__class__.default_capture_stream is None: + self.__class__.default_capture_stream = torch.cuda.Stream() + + self.pool = () if pool is None else (pool,) + self.capture_stream = ( + stream if stream is not None else self.__class__.default_capture_stream + ) + assert self.capture_stream is not None + self.stream_ctx = torch.cuda.stream(self.capture_stream) + self.cuda_graph = cuda_graph + self.capture_error_mode = capture_error_mode + + def __enter__(self): + # Free as much memory as we can for the graph + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + # Stackoverflow seems comfortable with this pattern + # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487 + self.stream_ctx.__enter__() + + self.cuda_graph.capture_begin( + *self.pool, capture_error_mode=self.capture_error_mode + ) + + def __exit__(self, exc_type, exc_value, traceback): + self.cuda_graph.capture_end() + self.stream_ctx.__exit__(exc_type, exc_value, traceback) + # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() + + +def make_graphed_callables( + callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None +): + r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions. + + Each graphed callable's forward pass runs its source callable's + forward CUDA work as a CUDA graph inside a single autograd node. + + The graphed callable's forward pass also appends + a backward node to the autograd graph. During backward, this node runs the + callable's backward work as a CUDA graph. + + Therefore, each graphed callable should be a drop-in replacement for its source callable + in an autograd-enabled training loop. + + See :ref:`Partial-network capture` for detailed use and constraints. + + If you pass a tuple of several callables, their captures will use the same memory pool. + See :ref:`Graph memory management` for when this is appropriate. + + Arguments: + callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph. + See :ref:`Graph memory management` for when passing a tuple of callables + is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order + they'll run in the live workload. + sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable. + If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors. + If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors. + num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs + 11 iterations for warm up. Default: ``3``. + allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs + (and therefore their grad is always zero) is an error. Defaults to False. + pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + .. note:: + The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state + that's expected for the corresponding real input in the training loop. + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + ``sample_args`` for each callable must contain only Tensors. Other types are not allowed. + + .. 
warning:: + Returned callables do not support higher order differentiation (e.g., double backward). + + .. warning:: + In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters + may be trainable. Buffers must have ``requires_grad=False``. + + .. warning:: + After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`, + you may not add or remove any of that Module's parameters or buffers. + + .. warning:: + :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks + registered on them at the time they are passed. However, registering hooks on modules *after* passing them + through :func:`~torch.cuda.make_graphed_callables` is allowed. + + .. warning:: + When running a graphed callable, you must pass its arguments in the same order and format + they appeared in that callable's ``sample_args``. + + .. warning:: + The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled + caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`. + """ + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." + ) + + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + + for c, args in zip(callables, sample_args): + if isinstance(c, torch.nn.Module): + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. However, registering hooks " + + "on modules after passing them through make_graphed_callables is allowed." + ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + + "``requires_grad=False``." + ) + flatten_arg = torch.utils._pytree.arg_tree_leaves(*args) + flatten_sample_args.append(tuple(flatten_arg)) + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in flatten_sample_args] + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] + + fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + + mempool = graph_pool_handle() if pool is None else pool + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. 
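+    # The warmup iterations below run on a throwaway side stream; their outputs
+    # and gradients are deleted afterwards rather than reused in the captures below.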
+ torch.cuda.synchronize() + with torch.cuda.stream(torch.cuda.Stream()): + for func, args, static_input_surface in zip( + callables, sample_args, per_callable_static_input_surfaces + ): + grad_inputs, outputs, outputs_grad = None, None, None + for _ in range(num_warmup_iters): + outputs = torch.utils._pytree.tree_leaves(func(*args)) + outputs_grad = tuple(o for o in outputs if o.requires_grad) + if len(outputs_grad) > 0: + grad_inputs = torch.autograd.grad( + outputs=outputs_grad, + inputs=tuple( + i for i in static_input_surface if i.requires_grad + ), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) + for v in [outputs, outputs_grad, grad_inputs]: + del v + + torch.cuda.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_unflatten_spec = [] + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + + flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs) + per_callable_static_outputs.append(tuple(flatten_outputs)) + per_callable_output_unflatten_spec.append(spec) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + ): + # For now, assumes all static_outputs require grad + # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad." + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + + outputs_grad = tuple(o for o in static_outputs if o.requires_grad) + grad_inputs = None + if len(outputs_grad) > 0: + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=outputs_grad, + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad. + # I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad and grad_inputs is not None: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs.reverse() + per_callable_static_grad_inputs.reverse() + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. 
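+    # Each graphed callable is exposed through a torch.autograd.Function defined in
+    # the helper below: forward copies fresh user args into the static input surface
+    # and replays the forward graph, while backward copies incoming grads into the
+    # static grad outputs and replays the backward graph.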
+ + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): + # At this stage, only the user args may (potentially) be new tensors. + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + assert len(grads) == len(static_grad_outputs) + for g, grad in zip(static_grad_outputs, grads): + if g is not None: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args) + out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) + return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callables + ret = [] + for i, func in enumerate(callables): + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) + + if isinstance(func, torch.nn.Module): + + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + return graphed(*user_args) + else: + return orig_fwd(*user_args) + + return new_fwd + + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] + ret.append(func) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) diff --git a/phivenv/Lib/site-packages/torch/cuda/jiterator.py b/phivenv/Lib/site-packages/torch/cuda/jiterator.py new file mode 100644 index 0000000000000000000000000000000000000000..df7e71a6ee52714e43b66eec90d64b52396ef786 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/jiterator.py @@ -0,0 +1,187 @@ +# mypy: allow-untyped-defs +import re +from typing import Callable + +import torch +from torch import Tensor + + +__all__: list[str] = [] + + +class _CodeParser: + def __init__(self, code_string: str): + optional_ws = r"\s*" + required_ws = r"\s+" + template_params = r"(?P\<.+\>)" + return_type = r"(?P\w+)" + function_name = r"(?P\w+)" + function_params = r"(?P\(.+\))" + function_body = r"(?P\{.+\})" + + pattern = ( + optional_ws + + "template" + + optional_ws + + template_params + + 
optional_ws + + return_type + + required_ws + + function_name + + optional_ws + + function_params + + optional_ws + + function_body + + optional_ws + ) + + result = re.match( + pattern, code_string, re.DOTALL + ) # DOTALL for matching multiline + + if result is None: + raise Exception( # noqa: TRY002 + f"Couldn't parse code, please check correctness:\n {code_string}" + ) + + self.template_params = result["template_params"] + self.return_type = result["return_type"] + self.function_name = result["function_name"] + self.function_params = result["function_params"] + self.function_body = result["function_body"] + + +class _JittedFunction: + def __init__( + self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs + ): + self.code_string = code_string + + assert ( + return_by_ref or num_outputs == 1 + ), "Return by value only works for single output. " + self.return_by_ref = return_by_ref + self.num_outputs = num_outputs + + parsed_code = _CodeParser(code_string) + self.kernel_name = parsed_code.function_name + + self.kwargs_dict = kwargs + self.is_cuda_available = torch.cuda.is_available() + + def __call__(self, *tensors: Tensor, **kwargs): + # Jiterator follow torch.cuda's lazy initialization behavior + # Defer checking cuda's availability at the function invocation time + assert ( + self.is_cuda_available + ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available." + + assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." + + expanded_kwargs = self.kwargs_dict.copy() + for key, value in kwargs.items(): + if key in self.kwargs_dict: + expanded_kwargs[key] = value + else: + raise KeyError(f"{key} is not declared in function definition") + + return torch._C._cuda_jiterator_compile_and_launch_kernel( + self.code_string, + self.kernel_name, + self.return_by_ref, + self.num_outputs, + tensors, + expanded_kwargs, + ) + + +def _create_jit_fn(code_string: str, **kwargs) -> Callable: + """ + Create a jiterator-generated cuda kernel for an elementwise op. + + The code string has to be a valid CUDA function that describes the computation for a single element. The code + string has to follow the c++ template pattern, as shown in the example below. This function will be inlined + into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as + local temp dir. + + Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion. + + Args: + code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value. + kwargs (Dict, optional): Keyword arguments for generated function + + Example:: + + code_string = "template T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" + jitted_fn = create_jit_fn(code_string, alpha=1.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') + # invoke jitted function like a regular python function + result = jitted_fn(a, b, alpha=3.14) + + code_string also allows multiple function definitions, and the last function will be treated as the entry function. 
+ + Example:: + + code_string = "template T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" + code_string += "template T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" + jitted_fn = create_jit_fn(code_string, val=0.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') + # invoke jitted function like a regular python function + result = jitted_fn(a, b) # using default val=0.0 + + Jiterator can be used together with python registration to override an operator's cuda kernel. + Following example is overriding gelu's cuda kernel with relu. + + Example:: + + code_string = "template T my_gelu(T a) { return a > 0 ? a : 0; }" + my_gelu = create_jit_fn(code_string) + my_lib = torch.library.Library("aten", "IMPL") + my_lib.impl('aten::gelu', my_gelu, "CUDA") + # torch.nn.GELU and torch.nn.function.gelu are now overridden + a = torch.rand(3, device='cuda') + torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a)) + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + This API only supports up to 8 inputs and 1 output + + .. warning:: + All input tensors must live in CUDA device + """ + return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs) + + +def _create_multi_output_jit_fn( + code_string: str, num_outputs: int, **kwargs +) -> Callable: + """ + Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs. + + Args: + code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference. + num_outputs(int): number of outputs return by the kernel + kwargs (Dict, optional): Keyword arguments for generated function + + Example:: + + code_string = "template void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" + jitted_fn = create_jit_fn(code_string, alpha=1.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') + # invoke jitted function like a regular python function + result = jitted_fn(a, b, alpha=3.14) + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + This API only supports up to 8 inputs and 8 outputs + """ + return _JittedFunction( + code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs + ) diff --git a/phivenv/Lib/site-packages/torch/cuda/memory.py b/phivenv/Lib/site-packages/torch/cuda/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..24455a70652eeab4d67c0d5dae6e844ead9fd5de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/memory.py @@ -0,0 +1,1237 @@ +# mypy: allow-untyped-defs +r"""This package adds support for device memory management implemented in CUDA.""" + +import collections +import contextlib +import ctypes +import pickle +import sys +import warnings +from inspect import signature +from typing import Any, Literal, Optional, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +from torch import _C +from torch._utils import _dummy_type + +from . 
import ( + _get_amdsmi_device_index, + _get_device_index, + _get_nvml_device_index, + _lazy_init, + is_initialized, +) +from ._memory_viz import memory as _memory, segments as _segments + + +if TYPE_CHECKING: + from torch.types import Device + + +__all__ = [ + "caching_allocator_alloc", + "caching_allocator_delete", + "caching_allocator_enable", + "get_per_process_memory_fraction", + "set_per_process_memory_fraction", + "empty_cache", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "host_memory_stats", + "host_memory_stats_as_nested_dict", + "reset_accumulated_host_memory_stats", + "reset_peak_host_memory_stats", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "memory_cached", + "max_memory_cached", + "memory_snapshot", + "memory_summary", + "list_gpu_processes", + "mem_get_info", + "get_allocator_backend", + "CUDAPluggableAllocator", + "change_current_allocator", + "MemPool", + "use_mem_pool", +] + + +if not hasattr(torch._C, "_cuda_CUDAAllocator"): + # Define dummy base classes + torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator") + + +if not hasattr(torch._C, "_MemPool"): + # Define dummy base classes + torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool") + torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type( + "_cuda_beginAllocateToPool" + ) + torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type( + "_cuda_beginAllocateCurrentThreadToPool" + ) + torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type( + "_cuda_endAllocateToPool" + ) + torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool") + +from torch._C import ( # noqa: F401 + _cuda_beginAllocateCurrentThreadToPool, + _cuda_beginAllocateToPool, + _cuda_CUDAAllocator, + _cuda_endAllocateToPool, + _cuda_releasePool, + _MemPool, +) + + +def _host_allocator(): + _lazy_init() + return torch._C._cuda_cudaHostAllocator() + + +@contextlib.contextmanager +def _free_mutex(): + torch._C._cuda_lock_mutex() + try: + yield + finally: + torch._C._cuda_unlock_mutex() + + +def caching_allocator_alloc(size, device: "Device" = None, stream=None): + r"""Perform a memory allocation using the CUDA memory allocator. + + Memory is allocated for a given device and a stream, this + function is intended to be used for interoperability with other + frameworks. Allocated memory is released through + :func:`~torch.cuda.caching_allocator_delete`. + + Args: + size (int): number of bytes to be allocated. + device (torch.device or int, optional): selected device. If it is + ``None`` the default CUDA device is used. + stream (torch.cuda.Stream or int, optional): selected stream. If is ``None`` then + the default stream for the selected device is used. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. 
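+
+    Example::
+
+        >>> # xdoctest: +SKIP("requires CUDA")
+        >>> # Minimal usage sketch: allocate 1024 bytes and free them again.
+        >>> ptr = torch.cuda.caching_allocator_alloc(1024)
+        >>> torch.cuda.caching_allocator_delete(ptr)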
+ """ + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + if stream is None: + stream = torch.cuda.current_stream(device) + if isinstance(stream, torch.cuda.streams.Stream): + stream = stream.cuda_stream + if not isinstance(stream, int): + raise TypeError( + "Invalid type for stream argument, must be " + "`torch.cuda.Stream` or `int` representing a pointer " + "to a existing stream" + ) + with torch.cuda.device(device): + return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream) + + +def caching_allocator_delete(mem_ptr): + r"""Delete memory allocated using the CUDA memory allocator. + + Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`. + is freed here. The associated device and stream are tracked inside + the allocator. + + Args: + mem_ptr (int): memory address to be freed by the allocator. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr) + + +def caching_allocator_enable(value: bool = True) -> None: + r"""Enable or disable the CUDA memory allocator. On by default.""" + if is_initialized(): + torch._C._cuda_cudaCachingAllocator_enable(value) + + +def set_per_process_memory_fraction(fraction, device: "Device" = None) -> None: + r"""Set memory fraction for a process. + + The fraction is used to limit an caching allocator to allocated memory on a CUDA device. + The allowed value equals the total visible memory multiplied fraction. + If trying to allocate more than the allowed value in a process, will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction. + device (torch.device or int, optional): selected device. If it is + ``None`` the default CUDA device is used. + .. note:: + In general, the total available free memory is less than the total capacity. + """ + _lazy_init() + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + if not isinstance(fraction, float): + raise TypeError("Invalid type for fraction argument, must be `float`") + if fraction < 0 or fraction > 1: + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1") + + torch._C._cuda_setMemoryFraction(fraction, device) + + +def get_per_process_memory_fraction(device: "Device" = None) -> float: + r"""Get memory fraction for a process. + + Args: + device (torch.device or int, optional): selected device. If it is + ``None`` the default CUDA device is used. + Returns: + memory fraction, in range 0~1. Allowed memory equals total_memory * fraction. + """ + _lazy_init() + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + return torch._C._cuda_getMemoryFraction(device) + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other GPU application and visible in + `nvidia-smi`. + + .. note:: + :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU + memory available for PyTorch. However, it may help reduce fragmentation + of GPU memory in certain cases. See :ref:`cuda-memory-management` for + more details about GPU memory management. + """ + if is_initialized(): + torch._C._cuda_emptyCache() + + +def memory_stats(device: "Device" = None) -> dict[str, Any]: + r"""Return a dictionary of CUDA memory allocator statistics for a given device. 
+ + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from ``cudaMalloc()``. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of October 2019, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of October 2019, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both + cuMemMap and cudaMalloc. + - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap + and cudaFree. + + The caching allocator can be configured via ENV to not split blocks larger than a + defined size (see Memory Management section of the Cuda Semantics documentation). + This helps avoid memory fragmentation but may have a performance + penalty. Additional outputs to assist with tuning and evaluating impact: + + - ``"max_split_size"``: blocks above this size will not be split. + - ``"oversize_allocations.{current,peak,allocated,freed}"``: + number of over-size allocation requests received by the memory allocator. + - ``"oversize_segments.{current,peak,allocated,freed}"``: + number of over-size reserved segments from ``cudaMalloc()``. + + The caching allocator can be configured via ENV to round memory allocations in order + to reduce fragmentation. Sometimes the overhead from rounding can be higher than + the fragmentation it helps reduce. The following stat can be used to check if + rounding adds too much overhead: + + - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + memory requested by client code, compare this with allocated_bytes to check if + allocation rounding adds too much overhead. + + Args: + device (torch.device or int, optional): selected device. 
Returns + statistics for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + + .. note:: + With :ref:`backend:cudaMallocAsync`, some stats are not + meaningful, and are always reported as zero. + """ + result = [] + + def _recurse_add_to_result(prefix, obj): + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." + for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = memory_stats_as_nested_dict(device=device) + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def memory_stats_as_nested_dict(device: "Device" = None) -> dict[str, Any]: + r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + device = _get_device_index(device, optional=True) + return torch._C._cuda_memoryStats(device) + + +def reset_accumulated_memory_stats(device: "Device" = None) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator. + + See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict, as well as + `"num_alloc_retries"` and `"num_ooms"`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + return torch._C._cuda_resetAccumulatedMemoryStats(device) + + +def reset_peak_memory_stats(device: "Device" = None) -> None: + r"""Reset the "peak" stats tracked by the CUDA memory allocator. + + See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + return torch._C._cuda_resetPeakMemoryStats(device) + + +def host_memory_stats() -> dict[str, Any]: + r"""Return a dictionary of CUDA memory allocator statistics for a given device. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{current,peak,allocated,freed}"``: + number of reserved segments from ``cudaMalloc()``. + - ``"reserved_bytes.{current,peak,allocated,freed}"``: + amount of reserved memory. + + For these core statistics, values are broken down as follows. + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. 
+ + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_host_alloc"``: number of CUDA allocation calls. This includes both + cudaHostAlloc and cudaHostRegister. + - ``"num_host_free"``: number of CUDA free calls. This includes both cudaHostFree + and cudaHostUnregister. + + Finally, we also provide some simple timing counters: + + - ``"host_alloc_time.{total,max,min,count,avg}"``: + timing of allocation requests going through CUDA calls. + - ``"host_free_time.{total,max,min,count,avg}"``: + timing of free requests going through CUDA calls. + + For these timing statistics, values are broken down as follows. + + Metric type: + + - ``total``: total time spent. + - ``max``: maximum value per call. + - ``min``: minimum value per call. + - ``count``: number of times it was called. + - ``avg``: average time per call. + """ + result = [] + + def _recurse_add_to_result(prefix, obj): + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." + for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = host_memory_stats_as_nested_dict() + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def host_memory_stats_as_nested_dict() -> dict[str, Any]: + r"""Return the result of :func:`~torch.cuda.host_memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + return torch._C._cuda_hostMemoryStats() + + +def reset_accumulated_host_memory_stats() -> None: + r"""Reset the "accumulated" (historical) stats tracked by the host memory allocator. + + See :func:`~torch.cuda.host_memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict. + """ + return torch._C._cuda_resetAccumulatedHostMemoryStats() + + +def reset_peak_host_memory_stats() -> None: + r"""Reset the "peak" stats tracked by the host memory allocator. + + See :func:`~torch.cuda.host_memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + """ + return torch._C._cuda_resetPeakHostMemoryStats() + + +def reset_max_memory_allocated(device: "Device" = None) -> None: + r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device. + + See :func:`~torch.cuda.max_memory_allocated` for details. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. warning:: + This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets + /all/ peak memory stats. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + warnings.warn( + "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, " + "which resets /all/ peak memory stats.", + FutureWarning, + ) + return reset_peak_memory_stats(device=device) + + +def reset_max_memory_cached(device: "Device" = None) -> None: + r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + + See :func:`~torch.cuda.max_memory_cached` for details. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. 
warning:: + This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets + /all/ peak memory stats. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + warnings.warn( + "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, " + "which resets /all/ peak memory stats.", + FutureWarning, + ) + return reset_peak_memory_stats(device=device) + + +def memory_allocated(device: "Device" = None) -> int: + r"""Return the current GPU memory occupied by tensors in bytes for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + This is likely less than the amount shown in `nvidia-smi` since some + unused memory can be held by the caching allocator and some context + needs to be created on GPU. See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + return memory_stats(device=device).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device: "Device" = None) -> int: + r"""Return the maximum GPU memory occupied by tensors in bytes for a given device. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. For example, these two + functions can measure the peak allocated memory usage of each iteration in a + training loop. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device: "Device" = None) -> int: + r"""Return the current GPU memory managed by the caching allocator in bytes for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device: "Device" = None) -> int: + r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. For example, these two functions + can measure the peak cached memory amount of each iteration in a training + loop. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. 
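+
+    Example::
+
+        >>> # xdoctest: +SKIP("requires CUDA")
+        >>> # ``loader`` and ``train_step`` are placeholders for a real training loop.
+        >>> for batch in loader:
+        ...     torch.cuda.reset_peak_memory_stats()
+        ...     train_step(batch)
+        ...     print(torch.cuda.max_memory_reserved())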
+ """ + return memory_stats(device=device).get("reserved_bytes.all.peak", 0) + + +@deprecated( + "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`", + category=FutureWarning, +) +def memory_cached(device: "Device" = None) -> int: + r"""Deprecated; see :func:`~torch.cuda.memory_reserved`.""" + return memory_reserved(device=device) + + +@deprecated( + "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`", + category=FutureWarning, +) +def max_memory_cached(device: "Device" = None) -> int: + r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`.""" + return max_memory_reserved(device=device) + + +def memory_snapshot(mempool_id=None): + r"""Return a snapshot of the CUDA memory allocator state across all devices. + + Interpreting the output of this function requires familiarity with the + memory allocator internals. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return torch._C._cuda_memorySnapshot(mempool_id)["segments"] + + +def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str: + r"""Return a human-readable printout of the current memory allocator statistics for a given device. + + This can be useful to display periodically during training, or when + handling out-of-memory exceptions. + + Args: + device (torch.device or int, optional): selected device. Returns + printout for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + abbreviated (bool, optional): whether to return an abbreviated summary + (default: False). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + stats = memory_stats(device=device) + + def _format_size(sz, pref_sz): + prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"] + prefix = prefixes[0] + for new_prefix in prefixes[1:]: + if pref_sz < 768 * 1024: + break + prefix = new_prefix + sz //= 1024 + pref_sz /= 1024 + return f"{sz:6d} {prefix}" + + def _format_count(cnt, pref_cnt): + prefixes = [" ", "K", "M"] + prefix = prefixes[0] + for new_prefix in prefixes[1:]: + if pref_cnt < 750 * 1000: + break + prefix = new_prefix + cnt //= 1000 + pref_cnt /= 1000 + return f"{cnt:7d} {prefix} " + + metrics_to_display = [ + ("allocated_bytes", "Allocated memory", _format_size), + ("active_bytes", "Active memory", _format_size), + ("requested_bytes", "Requested memory", _format_size), + ("reserved_bytes", "GPU reserved memory", _format_size), + ("inactive_split_bytes", "Non-releasable memory", _format_size), + ("allocation", "Allocations", _format_count), + ("active", "Active allocs", _format_count), + ("segment", "GPU reserved segments", _format_count), + ("inactive_split", "Non-releasable allocs", _format_count), + ] + + lines = [] + lines.append("=" * 75) + lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ") + lines.append("-" * 75) + lines.append( + " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} " + ) + lines.append("=" * 75) + lines.append( + " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed " + ) + + for metric_key, metric_name, formatter in metrics_to_display: + lines.append("-" * 75) + submetrics = [("all", metric_name)] + if not abbreviated: + submetrics.append(("large_pool", " from large pool")) + submetrics.append(("small_pool", " from small pool")) + + current_prefval, peak_prefval, 
allocated_prefval, freed_prefval = ( + None, + None, + None, + None, + ) + + for submetric_key, submetric_name in submetrics: + prefix = metric_key + "." + submetric_key + "." + + current = stats[prefix + "current"] + peak = stats[prefix + "peak"] + allocated = stats[prefix + "allocated"] + freed = stats[prefix + "freed"] + + if current_prefval is None: + current_prefval = current + peak_prefval = peak + allocated_prefval = allocated + freed_prefval = freed + + lines.append( + f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | " + f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ", + ) + + metrics_to_display = [ + ("oversize_allocations", "Oversize allocations", _format_count), + ("oversize_segments", "Oversize GPU segments", _format_count), + ] + + for metric_key, metric_name, formatter in metrics_to_display: + lines.append("-" * 75) + + prefix = metric_key + "." + + current = stats[prefix + "current"] + peak = stats[prefix + "peak"] + allocated = stats[prefix + "allocated"] + freed = stats[prefix + "freed"] + + lines.append( + f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | " + f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ", + ) + + lines.append("=" * 75) + + fmt_dict = {"_": "", "device": device} + for k, v in stats.items(): + fmt_dict[k.replace(".", "-")] = v + return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n" + + +def list_gpu_processes(device: "Device" = None) -> str: + r"""Return a human-readable printout of the running processes and their GPU memory use for a given device. + + This can be useful to display periodically during training, or when + handling out-of-memory exceptions. + + Args: + device (torch.device or int, optional): selected device. Returns + printout for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + """ + if not torch.version.hip: + try: + import pynvml # type: ignore[import] + except ModuleNotFoundError: + return "pynvml module not found, please install pynvml" + from pynvml import NVMLError_DriverNotLoaded + + try: + pynvml.nvmlInit() + except NVMLError_DriverNotLoaded: + return "cuda driver can't be loaded, is cuda enabled?" + + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + else: + try: + import amdsmi # type: ignore[import] + except ModuleNotFoundError: + return "amdsmi module not found, please install amdsmi" + try: + amdsmi.amdsmi_init() # type: ignore[attr-defined] + except amdsmi.AmdSmiException: # type: ignore[attr-defined] + return "amdsmi driver can't be loaded, is ROCm installed?" 
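+
+        # On ROCm builds, per-process GPU memory information is queried through
+        # amdsmi rather than pynvml.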
+ + device = _get_amdsmi_device_index(device) + + try: + handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined] + procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined] + except amdsmi.AmdSmiException: # type: ignore[attr-defined] + return "amdsmi cannot list processes from other users" + + lines = [] + lines.append(f"GPU:{device}") + if len(procs) == 0: + lines.append("no processes are running") + for p in procs: + if not torch.version.hip: + mem = p.usedGpuMemory / (1024 * 1024) + pid = p.pid + else: + try: + proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined] + except AttributeError: + # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7 + # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi + proc_info = p + mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024) + pid = proc_info["pid"] + lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory") + return "\n".join(lines) + + +def mem_get_info(device: "Device" = None) -> tuple[int, int]: + r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default) or if the device index is not specified. + + .. note:: + See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + if device is None: + device = torch.cuda.current_device() + # optional=True allows `device = torch.device('cuda')` for which device.index is None + device = _get_device_index(device, optional=True) + return torch.cuda.cudart().cudaMemGetInfo(device) + + +def _record_memory_history_legacy( + enabled: bool, + record_context=True, + trace_alloc_max_entries=1, + trace_alloc_record_context=False, + device: "Device" = None, + record_context_cpp=False, + clear_history=False, + compile_context=False, + global_record_annotations=False, +): + _C._cuda_record_memory_history_legacy( # type: ignore[call-arg] + enabled, + record_context, + trace_alloc_max_entries, + trace_alloc_record_context, + record_context_cpp, + clear_history, + compile_context, + global_record_annotations, + ) + + +def _record_memory_history( + enabled: Literal[None, "state", "all"] = "all", *args, **kwargs +) -> None: + """Enable recording of stack traces associated with memory + allocations, so you can tell what allocated any piece of memory in + :func:`torch.cuda.memory._snapshot()`. + + In addition to keeping stack traces with each current allocation and free, + this will also enable recording of a history of all alloc/free events. + + Use :func:`torch.cuda.memory._snapshot()` to retrieve this information, + and the tools in `_memory_viz.py` to visualize snapshots. + + Buffer behavior + --------------- + + This will store up to `max_entries` instances of `TraceEntry` when enabled. + Python trace collection defaults to `sys.maxsize`, meaning long-running + or indefinitely running jobs should set a reasonable limit to avoid excessive + memory use. Expect each entry to be several KB. + + Longer running workflows or those with smaller `max_entries` values will only + store the last accumulated `max_entries` entries, meaning new entries overwrite + older entries. + + C++ implementation for reference to ring buffer implemenation: + + .. 
code-block:: cpp + + if (record_history) { + if (alloc_trace->size() < alloc_trace_max_entries_) { + alloc_trace->emplace_back(te); + } else { + (*alloc_trace)[alloc_trace_next++] = te; + if (alloc_trace_next == alloc_trace_max_entries_) { + alloc_trace_next = 0; + } + } + } + + Latency impact + -------------- + + The Python trace collection is fast (2us per trace), so you may consider + enabling this on production jobs if you anticipate ever having to debug + memory issues. + + C++ trace collection is also fast (~50ns/frame), which for many typical programs + works out to ~2us per trace, but can vary depending on stack depth. + + Args: + enabled (Literal[None, "state", "all"], optional): + `None`, disable recording memory history. + `"state"`, keep information for currenly allocated memory. + `"all"`, additionally keep a history of all alloc/free calls. + Defaults to "all". + context (Literal[None, "state", "alloc", "all"], optional): + `None`, Do not record any tracebacks. + `"state"`, Record tracebacks for currently allocated memory. + `"alloc"`, additionally keep tracebacks for alloc calls. + `"all"`, additionally keep tracebacks for free calls. + Defaults to "all". + stacks (Literal["python", "all"], optional): + `"python"`, include Python, TorchScript, and inductor frames in tracebacks + `"all"`, additionally include C++ frames + Defaults to "all". + max_entries (int, optional): Keep a maximum of `max_entries` + alloc/free events in the recorded history recorded. + """ + if isinstance(enabled, bool): + return _record_memory_history_legacy(enabled, *args, **kwargs) + else: + return _record_memory_history_impl(enabled, *args, **kwargs) + + +def _record_memory_history_impl( + enabled: Optional[str] = "all", + context: Optional[str] = "all", + stacks: str = "all", + max_entries: int = sys.maxsize, + device: "Device" = None, + clear_history: bool = False, + compile_context: bool = False, + global_record_annotations: bool = False, +): + _C._cuda_record_memory_history( # type: ignore[call-arg] + enabled, + context, + stacks, + max_entries, + clear_history, + compile_context, + global_record_annotations, + ) + + +_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined] + + +def _snapshot(device: "Device" = None): + """Save a snapshot of CUDA memory state at the time it was called. + + The state is represented as a dictionary with the following structure. + + .. code-block:: python + + class Snapshot(TypedDict): + segments : List[Segment] + device_traces: List[List[TraceEntry]] + + class Segment(TypedDict): + # Segments are memory returned from a cudaMalloc call. + # The size of reserved memory is the sum of all Segments. + # Segments are cached and reused for future allocations. + # If the reuse is smaller than the segment, the segment + # is split into more then one Block. + # empty_cache() frees Segments that are entirely inactive. + address: int + total_size: int # cudaMalloc'd size of segment + stream: int + segment_type: Literal['small', 'large'] # 'large' (>1MB) + allocated_size: int # size of memory in use + active_size: int # size of memory in use or in active_awaiting_free state + blocks : List[Block] + + class Block(TypedDict): + # A piece of memory returned from the allocator, or + # current cached but inactive. 
+ size: int + requested_size: int # size requested during malloc, may be smaller than + # size due to rounding + address: int + state: Literal['active_allocated', # used by a tensor + 'active_awaiting_free', # waiting for another stream to finish using + # this, then it will become free + 'inactive',] # free for reuse + frames: List[Frame] # stack trace from where the allocation occurred + + class Frame(TypedDict): + filename: str + line: int + name: str + + class TraceEntry(TypedDict): + # When `torch.cuda.memory._record_memory_history()` is enabled, + # the snapshot will contain TraceEntry objects that record each + # action the allocator took. + action: Literal[ + 'alloc' # memory allocated + 'free_requested', # the allocated received a call to free memory + 'free_completed', # the memory that was requested to be freed is now + # able to be used in future allocation calls + 'segment_alloc', # the caching allocator ask cudaMalloc for more memory + # and added it as a segment in its cache + 'segment_free', # the caching allocator called cudaFree to return memory + # to cuda possibly trying free up memory to + # allocate more segments or because empty_caches was called + 'oom', # the allocator threw an OOM exception. 'size' is + # the requested number of bytes that did not succeed + 'snapshot' # the allocator generated a memory snapshot + # useful to coorelate a previously taken + # snapshot with this trace + ] + addr: int # not present for OOM + frames: List[Frame] + size: int + stream: int + device_free: int # only present for OOM, the amount of + # memory cuda still reports to be free + + Returns: + The Snapshot dictionary object + """ + return _C._cuda_memorySnapshot(None) + + +def _dump_snapshot(filename="dump_snapshot.pickle"): + """ + Save a pickled version of the `torch.memory._snapshot()` dictionary to a file. + + This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz + + Snapshot file sizes scale with `max_entries` and stack trace depth per entry, + with several KB per entry. These can easily be in the GB range for longer running + workflows with large `max_entries`. + + Args: + filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle". + """ + s = _snapshot() + with open(filename, "wb") as f: + pickle.dump(s, f) + + +def _save_segment_usage(filename="output.svg", snapshot=None): + if snapshot is None: + snapshot = _snapshot() + with open(filename, "w") as f: + f.write(_segments(snapshot)) + + +def _save_memory_usage(filename="output.svg", snapshot=None): + if snapshot is None: + snapshot = _snapshot() + with open(filename, "w") as f: + f.write(_memory(snapshot)) + + +def _set_allocator_settings(env: str): + return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env) + + +def get_allocator_backend() -> str: + r"""Return a string describing the active allocator backend as set by + ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are + ``native`` (PyTorch's native caching allocator) and `cudaMallocAsync`` + (CUDA's built-in asynchronous allocator). + + .. note:: + See :ref:`cuda-memory-management` for details on choosing the allocator backend. 
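+
+    Example (a minimal sketch; the return value shown is illustrative and
+    depends on how ``PYTORCH_CUDA_ALLOC_CONF`` is set)::
+
+        >>> # xdoctest: +SKIP("output depends on allocator configuration")
+        >>> torch.cuda.get_allocator_backend()
+        'native'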
+ """ + return torch._C._cuda_getAllocatorBackend() + + +class _CUDAAllocator: + r"""Wrapper over internal CUDA memory allocators.""" + + def __init__(self, allocator: torch._C._cuda_CUDAAllocator): + self._allocator = allocator + + def allocator(self): + return self._allocator + + +class CUDAPluggableAllocator(_CUDAAllocator): + r"""CUDA memory allocator loaded from a so file.""" + + def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str): + r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes. + + To change the active allocator use the :func:`torch.memory.cuda.change_current_allocator` function. + + Args: + path_to_so_file(str): Path in the filesystem to the `.so` file containing + the allocator functions + alloc_fn_name(str): Name of the function to perform the memory allocation + in the so file. The signature must be: + void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream); + free_fn_name(str): Name of the function to perform the memory release + in the so file. The signature must be: + void free_fn_name(void* ptr, size_t size, cudaStream_t stream); + + .. warning:: + This is currently supported only in unix OSs + + .. note:: + See :ref:`cuda-memory-management` for details on creating and using a custom allocator + """ + allocator = ctypes.CDLL(path_to_so_file) + alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value + free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value + assert alloc_fn is not None + assert free_fn is not None + self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn) + + +def change_current_allocator(allocator: _CUDAAllocator) -> None: + r"""Change the currently used memory allocator to be the one provided. + + If the current allocator has already been used/initialized, this function will error. + + + Args: + allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one. + .. note:: + See :ref:`cuda-memory-management` for details on creating and using a custom allocator + """ + torch._C._cuda_changeCurrentAllocator(allocator.allocator()) + + +def _get_current_allocator() -> _CUDAAllocator: + r"""Return the allocator being currently used. + + .. note:: + See :ref:`cuda-memory-management` for details on creating and using a custom allocator + """ + return _CUDAAllocator(torch._C._cuda_getAllocator()) + + +class MemPool(_MemPool): + r"""MemPool represents a pool of memory in a caching allocator. Currently, + it's just the ID of the pool object maintained in the CUDACachingAllocator. + + Args: + allocator(torch._C._cuda_CUDAAllocator, optional): a + torch._C._cuda_CUDAAllocator object that can be used to + define how memory gets allocated in the pool. If :attr:`allocator` + is ``None`` (default), memory allocation follows the default/ + current configuration of the CUDACachingAllocator. + use_on_oom(bool): a bool that indicates if this pool can be used + as a last resort if a memory allocation outside of the pool fails due + to Out Of Memory. This is False by default. + symmetric(bool): a bool that indicates if this pool is symmetrical + across ranks. This is False by default. 
+ """ + + def __init__( + self, + allocator: Optional[_cuda_CUDAAllocator] = None, + use_on_oom: bool = False, + symmetric: bool = False, + ): + super().__init__(allocator, True, use_on_oom, symmetric) + + @property + def id(self) -> tuple[int, int]: + r"""Returns the ID of this pool as a tuple of two ints.""" + return super().id + + @property + def is_symmetric(self) -> bool: + r"""Returns whether this pool is used for NCCL's symmetric memory.""" + return super().is_symmetric + + @property + def allocator(self) -> Optional[_cuda_CUDAAllocator]: + r"""Returns the allocator this MemPool routes allocations to.""" + return super().allocator + + def use_count(self) -> int: + r"""Returns the reference count of this pool.""" + return super().use_count() + + def snapshot(self): + r"""Return a snapshot of the CUDA memory allocator pool state across all + devices. + + Interpreting the output of this function requires familiarity with the + memory allocator internals. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + snapshot = torch.cuda.memory_snapshot(self.id) + return snapshot + + +@contextlib.contextmanager +def use_mem_pool(pool: MemPool, device: "Device" = None): + r"""A context manager that routes allocations to a given pool. + + Args: + pool(torch.cuda.MemPool): a MemPool object to be made active so that + allocations route to this pool. + device (torch.device or int, optional): selected device. Uses MemPool on + the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + This context manager makes only current thread's allocations route to + the given pool. If a new thread is spawned inside the context manager + (e.g. by calling backward) the allocations in that thread will not + route to the given pool. + """ + device_index = ( + torch.cuda.current_device() if device is None else _get_device_index(device) + ) + _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id) + try: + yield + finally: + _cuda_endAllocateToPool(device_index, pool.id) + _cuda_releasePool(device_index, pool.id) diff --git a/phivenv/Lib/site-packages/torch/cuda/nccl.py b/phivenv/Lib/site-packages/torch/cuda/nccl.py new file mode 100644 index 0000000000000000000000000000000000000000..470beca2fa67cfa1f70f35496b139a62810dc065 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/nccl.py @@ -0,0 +1,152 @@ +# mypy: allow-untyped-defs +import collections +import warnings +from collections.abc import Sequence +from typing import Optional, Union + +import torch.cuda + + +__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"] + +SUM = 0 # ncclRedOp_t + + +def is_available(tensors): + if not hasattr(torch._C, "_nccl_all_reduce"): + warnings.warn("PyTorch is not compiled with NCCL support") + return False + + devices = set() + for tensor in tensors: + if tensor.is_sparse: + return False + if not tensor.is_contiguous(): + return False + if not tensor.is_cuda: + return False + device = tensor.get_device() + if device in devices: + return False + devices.add(device) + + return True + + +def version(): + """ + Returns the version of the NCCL. + + + This function returns a tuple containing the major, minor, and patch version numbers of the NCCL. + The suffix is also included in the tuple if a version suffix exists. + Returns: + tuple: The version information of the NCCL. 
+ """ + ver = torch._C._nccl_version() + major = ver >> 32 + minor = (ver >> 16) & 65535 + patch = ver & 65535 + suffix = torch._C._nccl_version_suffix().decode("utf-8") + if suffix == "": + return (major, minor, patch) + else: + return (major, minor, patch, suffix) + + +def unique_id(): + return torch._C._nccl_unique_id() + + +def init_rank(num_ranks, uid, rank): + return torch._C._nccl_init_rank(num_ranks, uid, rank) + + +def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None: + if not isinstance(inputs, collections.abc.Container) or isinstance( + inputs, torch.Tensor + ): + raise TypeError("Inputs should be a collection of tensors") + + +def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None): + _check_sequence_type(inputs) + if outputs is None: + outputs = inputs + _check_sequence_type(outputs) + torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms) + + +# `output` used to be `outputs`, taking in a list of tensors. So we have two +# arguments for BC reasons. +def reduce( + inputs: Sequence[torch.Tensor], + output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, + root: int = 0, + op: int = SUM, + streams: Optional[Sequence[torch.cuda.Stream]] = None, + comms=None, + *, + outputs: Optional[Sequence[torch.Tensor]] = None, +) -> None: + _check_sequence_type(inputs) + _output: torch.Tensor + if outputs is not None: + if output is not None: + raise ValueError( + "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in " + "favor of 'output', taking in a single output tensor. The signature of reduce is: " + "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)." + ) + else: + warnings.warn( + "`nccl.reduce` with an output tensor list is deprecated. " + "Please specify a single output tensor with argument 'output' instead instead.", + FutureWarning, + stacklevel=2, + ) + _output = outputs[root] + elif not isinstance(output, torch.Tensor) and isinstance( + output, collections.abc.Sequence + ): + # User called old API with positional arguments of list of output tensors. + warnings.warn( + "nccl.reduce with an output tensor list is deprecated. 
" + "Please specify a single output tensor.", + FutureWarning, + stacklevel=2, + ) + _output = output[root] + else: + _output = inputs[root] if output is None else output + torch._C._nccl_reduce(inputs, _output, root, op, streams, comms) + + +def broadcast( + inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None +) -> None: + _check_sequence_type(inputs) + torch._C._nccl_broadcast(inputs, root, streams, comms) + + +def all_gather( + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + streams=None, + comms=None, +) -> None: + _check_sequence_type(inputs) + _check_sequence_type(outputs) + torch._C._nccl_all_gather(inputs, outputs, streams, comms) + + +def reduce_scatter( + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + op: int = SUM, + streams=None, + comms=None, +) -> None: + _check_sequence_type(inputs) + _check_sequence_type(outputs) + torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms) diff --git a/phivenv/Lib/site-packages/torch/cuda/nvtx.py b/phivenv/Lib/site-packages/torch/cuda/nvtx.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0da812028a4406e59770d7c94f59403a26ff80 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/nvtx.py @@ -0,0 +1,125 @@ +# mypy: allow-untyped-defs +r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling.""" + +from contextlib import contextmanager + + +try: + from torch._C import _nvtx +except ImportError: + + class _NVTXStub: + @staticmethod + def _fail(*args, **kwargs): + raise RuntimeError( + "NVTX functions not installed. Are you sure you have a CUDA build?" + ) + + rangePushA = _fail + rangePop = _fail + markA = _fail + + _nvtx = _NVTXStub() # type: ignore[assignment] + +__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"] + + +def range_push(msg): + """ + Push a range onto a stack of nested range span. Returns zero-based depth of the range that is started. + + Args: + msg (str): ASCII message to associate with range + """ + return _nvtx.rangePushA(msg) + + +def range_pop(): + """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended.""" + return _nvtx.rangePop() + + +def range_start(msg) -> int: + """ + Mark the start of a range with string message. It returns an unique handle + for this range to pass to the corresponding call to rangeEnd(). + + A key difference between this and range_push/range_pop is that the + range_start/range_end version supports range across threads (start on one + thread and end on another thread). + + Returns: A range handle (uint64_t) that can be passed to range_end(). + + Args: + msg (str): ASCII message to associate with the range. + """ + return _nvtx.rangeStartA(msg) + + +def range_end(range_id) -> None: + """ + Mark the end of a range for a given range_id. + + Args: + range_id (int): an unique handle for the start range. + """ + _nvtx.rangeEnd(range_id) + + +def _device_range_start(msg: str, stream: int = 0) -> object: + """ + Marks the start of a range with string message. + It returns an opaque heap-allocated handle for this range + to pass to the corresponding call to device_range_end(). + + A key difference between this and range_start is that the + range_start marks the range right away, while _device_range_start + marks the start of the range as soon as all the tasks on the + CUDA stream are completed. + + Returns: An opaque heap-allocated handle that should be passed to _device_range_end(). 
+ + Args: + msg (str): ASCII message to associate with the range. + stream (int): CUDA stream id. + """ + return _nvtx.deviceRangeStart(msg, stream) + + +def _device_range_end(range_handle: object, stream: int = 0) -> None: + """ + Mark the end of a range for a given range_handle as soon as all the tasks + on the CUDA stream are completed. + + Args: + range_handle: an unique handle for the start range. + stream (int): CUDA stream id. + """ + _nvtx.deviceRangeEnd(range_handle, stream) + + +def mark(msg): + """ + Describe an instantaneous event that occurred at some point. + + Args: + msg (str): ASCII message to associate with the event. + """ + return _nvtx.markA(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + Args: + msg (str): message to associate with the range + """ + range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + range_pop() diff --git a/phivenv/Lib/site-packages/torch/cuda/profiler.py b/phivenv/Lib/site-packages/torch/cuda/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..d99ef0ff3cdc4541f124aac7461f1bdf02631a2d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/profiler.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +import contextlib +import tempfile + +import torch + +from . import check_error, cudart + + +__all__ = ["init", "start", "stop", "profile"] + +DEFAULT_FLAGS = [ + "gpustarttimestamp", + "gpuendtimestamp", + "gridsize3d", + "threadblocksize", + "streamid", + "enableonstart 0", + "conckerneltrace", +] + + +def init(output_file, flags=None, output_mode="key_value"): + rt = cudart() + if not hasattr(rt, "cudaOutputMode"): + raise AssertionError("HIP does not support profiler initialization!") + if ( + hasattr(torch.version, "cuda") + and torch.version.cuda is not None + and int(torch.version.cuda.split(".")[0]) >= 12 + ): + # Check https://github.com/pytorch/pytorch/pull/91118 + # cudaProfilerInitialize is no longer needed after CUDA 12 + raise AssertionError("CUDA12+ does not need profiler initialization!") + flags = DEFAULT_FLAGS if flags is None else flags + if output_mode == "key_value": + output_mode_enum = rt.cudaOutputMode.KeyValuePair + elif output_mode == "csv": + output_mode_enum = rt.cudaOutputMode.CSV + else: + raise RuntimeError( + "supported CUDA profiler output modes are: key_value and csv" + ) + with tempfile.NamedTemporaryFile(delete=True) as f: + f.write(b"\n".join(f.encode("ascii") for f in flags)) + f.flush() + check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum)) + + +def start(): + r"""Starts cuda profiler data collection. + + .. warning:: + Raises CudaError in case of it is unable to start the profiler. + """ + check_error(cudart().cudaProfilerStart()) + + +def stop(): + r"""Stops cuda profiler data collection. + + .. warning:: + Raises CudaError in case of it is unable to stop the profiler. + """ + check_error(cudart().cudaProfilerStop()) + + +@contextlib.contextmanager +def profile(): + """ + Enable profiling. + + Context Manager to enabling profile collection by the active profiling tool from CUDA backend. + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> model = torch.nn.Linear(20, 30).cuda() + >>> inputs = torch.randn(128, 20).cuda() + >>> with torch.cuda.profiler.profile() as prof: + ... 
model(inputs) + """ + try: + start() + yield + finally: + stop() diff --git a/phivenv/Lib/site-packages/torch/cuda/random.py b/phivenv/Lib/site-packages/torch/cuda/random.py new file mode 100644 index 0000000000000000000000000000000000000000..d5dee5052d284c0309d2bbd2224f5688726402d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/random.py @@ -0,0 +1,184 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterable +from typing import Union + +import torch +from torch import Tensor + +from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized + + +__all__ = [ + "get_rng_state", + "get_rng_state_all", + "set_rng_state", + "set_rng_state_all", + "manual_seed", + "manual_seed_all", + "seed", + "seed_all", + "initial_seed", +] + + +def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor: + r"""Return the random number generator state of the specified GPU as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). + + .. warning:: + This function eagerly initializes CUDA. + """ + _lazy_init() + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch.cuda.default_generators[idx] + return default_generator.get_state() + + +def get_rng_state_all() -> list[Tensor]: + r"""Return a list of ByteTensor representing the random number states of all devices.""" + results = [get_rng_state(i) for i in range(device_count())] + return results + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "cuda" +) -> None: + r"""Set the random number generator state of the specified GPU. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). + """ + if not is_initialized(): + with torch._C._DisableFuncTorch(): + # Clone the state because the callback will be triggered + # later when CUDA is lazy initialized. + new_state = new_state.clone(memory_format=torch.contiguous_format) + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + + def cb(): + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +def set_rng_state_all(new_states: Iterable[Tensor]) -> None: + r"""Set the random number generator state of all devices. + + Args: + new_states (Iterable of torch.ByteTensor): The desired state for each device. + """ + for i, state in enumerate(new_states): + set_rng_state(state, i) + + +def manual_seed(seed: int) -> None: + r"""Set the seed for generating random numbers for the current GPU. + + It's safe to call this function if CUDA is not available; in that + case, it is silently ignored. + + Args: + seed (int): The desired seed. + + .. warning:: + If you are working with a multi-GPU model, this function is insufficient + to get determinism. To seed all GPUs, use :func:`manual_seed_all`. 
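+
+    Example (a minimal sketch)::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> torch.cuda.manual_seed(42)      # seeds the current GPU only
+        >>> torch.cuda.manual_seed_all(42)  # seeds every visible GPU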
+ """ + seed = int(seed) + + def cb(): + idx = current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.manual_seed(seed) + + _lazy_call(cb, seed=True) + + +def manual_seed_all(seed: int) -> None: + r"""Set the seed for generating random numbers on all GPUs. + + It's safe to call this function if CUDA is not available; in that + case, it is silently ignored. + + Args: + seed (int): The desired seed. + """ + seed = int(seed) + + def cb(): + for i in range(device_count()): + default_generator = torch.cuda.default_generators[i] + default_generator.manual_seed(seed) + + _lazy_call(cb, seed_all=True) + + +def seed() -> None: + r"""Set the seed for generating random numbers to a random number for the current GPU. + + It's safe to call this function if CUDA is not available; in that + case, it is silently ignored. + + .. warning:: + If you are working with a multi-GPU model, this function will only initialize + the seed on one GPU. To initialize all GPUs, use :func:`seed_all`. + """ + + def cb(): + idx = current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.seed() + + _lazy_call(cb) + + +def seed_all() -> None: + r"""Set the seed for generating random numbers to a random number on all GPUs. + + It's safe to call this function if CUDA is not available; in that + case, it is silently ignored. + """ + + def cb(): + random_seed = 0 + seeded = False + for i in range(device_count()): + default_generator = torch.cuda.default_generators[i] + if not seeded: + default_generator.seed() + random_seed = default_generator.initial_seed() + seeded = True + else: + default_generator.manual_seed(random_seed) + + _lazy_call(cb) + + +def initial_seed() -> int: + r"""Return the current random seed of the current GPU. + + .. warning:: + This function eagerly initializes CUDA. + """ + _lazy_init() + idx = current_device() + default_generator = torch.cuda.default_generators[idx] + return default_generator.initial_seed() diff --git a/phivenv/Lib/site-packages/torch/cuda/sparse.py b/phivenv/Lib/site-packages/torch/cuda/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..702def052945ad1bd54ab221ae517a9c01361e34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/sparse.py @@ -0,0 +1 @@ +# The Tensor classes are added to this module by python_tensor.cpp diff --git a/phivenv/Lib/site-packages/torch/cuda/streams.py b/phivenv/Lib/site-packages/torch/cuda/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..058371553ea6ce6b7715405a4d2a4ce4c1d2474b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/streams.py @@ -0,0 +1,248 @@ +# mypy: allow-untyped-defs +import ctypes + +import torch +from torch._utils import _dummy_type + + +if not hasattr(torch._C, "_CudaStreamBase"): + # Define dummy base classes + torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase") + torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase") + + +class Stream(torch._C._CudaStreamBase): + r"""Wrapper around a CUDA stream. + + A CUDA stream is a linear sequence of execution that belongs to a specific + device, independent from other streams. It supports with statement as a + context manager to ensure the operators within the with block are running + on the corresponding stream. See :ref:`cuda-semantics` for details. + + Args: + device(torch.device or int, optional): a device on which to allocate + the stream. 
If :attr:`device` is ``None`` (default) or a negative + integer, this will use the current device. + priority(int, optional): priority of the stream, which can be positive, 0, or negative. + A lower number indicates a higher priority. By default, the priority is set to 0. + If the value falls outside of the allowed priority range, it will automatically be + mapped to the nearest valid priority (lowest for large positive numbers or + highest for large negative numbers). + + """ + + def __new__(cls, device=None, priority=0, **kwargs): + # setting device manager is expensive, so we avoid it unless necessary + if device is None or ("stream_id" in kwargs and "device_index" in kwargs): + return super().__new__(cls, priority=priority, **kwargs) + else: + with torch.cuda.device(device): + return super().__new__(cls, priority=priority, **kwargs) + + def wait_event(self, event) -> None: + r"""Make all future work submitted to the stream wait for an event. + + Args: + event (torch.cuda.Event): an event to wait for. + + .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see + `CUDA Stream documentation`_ for more info. + + This function returns without waiting for :attr:`event`: only future + operations are affected. + + .. _CUDA Stream documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html + """ + event.wait(self) + + def wait_stream(self, stream) -> None: + r"""Synchronize with another stream. + + All future work submitted to this stream will wait until all kernels + submitted to a given stream at the time of call complete. + + Args: + stream (Stream): a stream to synchronize. + + .. note:: This function returns without waiting for currently enqueued + kernels in :attr:`stream`: only future operations are affected. + """ + self.wait_event(stream.record_event()) + + def record_event(self, event=None): + r"""Record an event. + + Args: + event (torch.cuda.Event, optional): event to record. If not given, a new one + will be allocated. + + Returns: + Recorded event. + """ + if event is None: + event = Event() + event.record(self) + return event + + def query(self) -> bool: + r"""Check if all the work submitted has been completed. + + Returns: + A boolean indicating if all kernels in this stream are completed. + """ + return super().query() + + def synchronize(self) -> None: + r"""Wait for all the kernels in this stream to complete. + + .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see + `CUDA Stream documentation`_ for more info. + """ + super().synchronize() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.cuda_stream) + + def __eq__(self, o) -> bool: + if isinstance(o, Stream): + return super().__eq__(o) + return False + + def __hash__(self): + return hash((self.cuda_stream, self.device)) + + def __repr__(self): + return f"" + + +class ExternalStream(Stream): + r"""Wrapper around an externally allocated CUDA stream. + + This class is used to wrap streams allocated in other libraries in order + to facilitate data exchange and multi-library interactions. + + .. note:: This class doesn't manage the stream life-cycle, it is the user + responsibility to keep the referenced stream alive while this class is + being used. + + Args: + stream_ptr(int): Integer representation of the `cudaStream_t` value. + allocated externally. + device(torch.device or int, optional): the device where the stream + was originally allocated. If device is specified incorrectly, + subsequent launches using this stream may fail. 
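+
+    Example (a minimal sketch; the raw pointer here comes from a stream PyTorch
+    itself created, purely for illustration)::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> s = torch.cuda.Stream()
+        >>> ext = torch.cuda.ExternalStream(s.cuda_stream)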
+ """ + + def __new__(cls, stream_ptr, device=None, **kwargs): + with torch.cuda.device(device): + return super().__new__(cls, stream_ptr=stream_ptr, **kwargs) + + +class Event(torch._C._CudaEventBase): + r"""Wrapper around a CUDA event. + + CUDA events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize CUDA + streams. + + The underlying CUDA events are lazily initialized when the event is first + recorded or exported to another process. After creation, only streams on the + same device may record the event. However, streams on any device can wait on + the event. + + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) + interprocess (bool): if ``True``, the event can be shared between processes + (default: ``False``) + external (bool, optional): indicates whether this event should create event record and event wait nodes, or create an internal cross-stream dependency, when captured in a cuda graph. See `cross-stream dependencies `_, `cudaEventRecordExternal `_, and `cudaEventWaitExternal `_ for more information about internal vs. external events. (default: ``False``) + + .. _CUDA Event Documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html + """ # noqa: B950 + + def __new__( + cls, enable_timing=False, blocking=False, interprocess=False, external=False + ): + return super().__new__( + cls, + enable_timing=enable_timing, + blocking=blocking, + interprocess=interprocess, + external=external, + ) + + @classmethod + def from_ipc_handle(cls, device, handle): + r"""Reconstruct an event from an IPC handle on the given device.""" + return super().from_ipc_handle(device, handle) + + def record(self, stream=None): + r"""Record the event in a given stream. + + Uses ``torch.cuda.current_stream()`` if no stream is specified. The + stream's device must match the event's device. + """ + if stream is None: + stream = torch.cuda.current_stream() + super().record(stream) + + def wait(self, stream=None) -> None: + r"""Make all future work submitted to the given stream wait for this event. + + Use ``torch.cuda.current_stream()`` if no stream is specified. + + .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see + `CUDA Event documentation`_ for more info. + """ + if stream is None: + stream = torch.cuda.current_stream() + super().wait(stream) + + def query(self): + r"""Check if all work currently captured by event has completed. + + Returns: + A boolean indicating if all work currently captured by event has + completed. + """ + return super().query() + + def elapsed_time(self, end_event): + r"""Return the time elapsed. + + Time reported in milliseconds after the event was recorded and + before the end_event was recorded. + """ + return super().elapsed_time(end_event) + + def synchronize(self) -> None: + r"""Wait for the event to complete. + + Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + + .. note:: This is a wrapper around ``cudaEventSynchronize()``: see + `CUDA Event documentation`_ for more info. + """ + super().synchronize() + + def ipc_handle(self): + r"""Return an IPC handle of this event. + + If not recorded yet, the event will use the current device. 
+ """ + return super().ipc_handle() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.cuda_event) + + def __repr__(self) -> str: + if self.cuda_event: + return f"" + else: + return "" diff --git a/phivenv/Lib/site-packages/torch/cuda/tunable.py b/phivenv/Lib/site-packages/torch/cuda/tunable.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4f6459dfd9480f6d7f7932c3ea284c041f2a88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/cuda/tunable.py @@ -0,0 +1,821 @@ +r""" +This module exposes a TunableOp interface. + +Some operations, such as GEMMs, could be implemented using more than one library +or more than one technique. For example, a GEMM could be implemented for CUDA or +ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and +hipblaslt libraries allow the user to query for all possible algorithms and then +choose one. How does one know which implementation is the fastest and should be +chosen? That's what TunableOp provides. + +Enabling TunableOp and Tuning Separately +======================================== + +The TunableOp feature is enabled separately from enabling the tuning phase +itself. Enabling TunableOp means that PyTorch will replace any standard +operators with their Tunable implementations. Any call to a TunableOp first +checks whether it has already been tuned for the given operator inputs. If so, +it will immediately call the tuned operation; no further tuning will take place +even when the tuning setting is enabled. Instead if no tuning result is found, +and tuning is enabled, the TunableOp will benchmark every registered +implementation of that operator for the given set of inputs and select the +fastest. + +File Input and Output +===================== + +The first time any TunableOp is invoked, the internal database of tuned +operations will be prepared by attempting to read the results from the given +file. The default filename is 'tunableop_results.csv'. To support tuning when +multiple GPUs are used across multiple processes, the GPU device ordinal is +automatically inserted into the filename to avoid multiple processes overwriting +the same file. + +If tuning is enabled and new tunings are discovered during the course of your +workload, it will also write out to this same filename with all tunings, both +the ones it read in at startup as well as the new ones found at runtime. This +can be used, for example, to build up a tunings file across many workloads by +reusing the same file. The output file is automatically created when the +application terminates. This behavior can be controlled by the C++ and Python +APIs but not the environment variables. + +Assuming you specified a filename, you'll end up with a CSV file with contents +like so:: + + Validator,PT_VERSION,2.2.0 + Validator,ROCM_VERSION,6.0.0.0-12969-1544e39 + Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7 + Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty + GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262 + GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216,0.033 + +Note the "Validator" lines. If you change a library version, or ROCm version, or +PyTorch version, TunableOp will detect this and reject the tunings file because +the prior tunings are likely affected by other software changes. + +The remaining lines are the tuned solutions for each TunableOp encountered +during your execution. Each line consists of 4 comma-separated fields: operator +name, operator parameters, solution name, and average execution time. 
The +execution time is an optional field. The CSV file can be edited, but with +caution. For example, the solution name (field 3) can be changed to "Default" +and it will fall back to the original PyTorch untuned implementation. Or, in the +case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution +index you can override the solution that TunableOp selected by replacing the +value. The operator name and parameters (fields 1 and 2) are internally named +and should not be modified. In the case of GemmTunableOp, field 1 indicates the +datatype and whether the inputs are transposed (T) or not (N) and field 2 +indicates the M, N, K input shapes. + +There is an option to enable verbose output but it is only recommended for +debugging purposes. This will produce a lot of diagnostic messages but may be +useful to see if TunableOp is being used at all. Otherwise, TunableOp is +completely silent, besides file output, unless there is a warning or error +during its use. The verbose option is only available by setting the environment +variable PYTORCH_TUNABLEOP_VEROBSE=1. + +A Note on Tuning Behavior, Warmup, and Cache Effects +==================================================== + +Tuning an operator consists of iterating through the list or registered +implementations and profiling each one. The profile is established by running a +single implementation in a loop multiple times and taking the average execution +time. There is also an optional warmup phase prior to tuning that can help with +reaching stable power states by the hardware. During tuning of a workload the +various hardware caches will more likely produce hits than when not tuning. +There are options for flushing the instruction cache and rotate the input tensors +which might help produce a more faithful profile of the tuned operator as if the +operator were run within a larger workload instead of in a tight, repetitive loop. + +By default, each possible solution for a given operator will be run for either +100 iterations or as many iterations that can be run within 30ms, whichever is +smaller, and its average execution will be calculated. The fastest solution +among all that were successfully profiled will be chosen. A profile might fail +if the given solution doesn't achieve the same accuracy as the default +implementation or if the solution returns an error code. + +Current Tunable Operators +========================= + +TunableGemm for ROCm +-------------------- + +Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of +PyTorch will function correctly when using TunableOp but the only solution +available to CUDA builds is the 'Default' implementation i.e. the original +cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm() +or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a +given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation across both rocblas and hipblaslt. + +Offline Tuning +============== + +Motivation +---------- +There are several use cases for offline tuning. + +One use case involves a workload with a high-memory utilization, where regular tuning might lead to running out of memory. + +Another use case is for compute-intensive workloads. In such cases, it is more resource-efficient to collect +the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries. 
+ +Workflow +-------- +There are basically two steps: +1) Set the environment variables to collect the untuned GEMMs; this will generate ``tunableop_untuned0.csv``: + +.. code-block:: python + + PYTORCH_TUNABLEOP_ENABLED=1 + PYTORCH_TUNABLEOP_TUNING=0 + PYTORCH_TUNABLEOP_RECORD_UNTUNED=1 + ... + +2) Run a Python script that reads the ``tunableop_untuned0.csv`` and generates the ``tunableop_results0.csv``, like this: + +.. code-block:: python + + import torch.cuda.tunable as tunable + import os + + os.putenv('PYTORCH_TUNABLEOP_ENABLED', '1') + os.putenv('PYTORCH_TUNABLEOP_TUNING', '1') + os.putenv('PYTORCH_TUNABLEOP_RECORD_UNTUNED', '0') + tunable.tune_gemm_in_file("tunableop_untuned0.csv") + + +It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs +within a single node. In the first step, the GEMMs are gathered and duplicate GEMMs are eliminated. +Next, the GEMMs are distributed to different GPUs for tuning. After all GEMMs are tuned, the results from +all the GPUs are gathered into a single file whose base filename has ``_full0`` appended to it +(for example ``tunableop_results_full0.csv``). Finally, this new file, containing the gathered results, will be +duplicated N times, once for each GPU, as a convenience to the user, who can then run the workload with the tuned +configuration on N GPUs. + +.. code-block:: python + + if __name__ == "__main__": + num_gpus = 8 # number of GPUs that will be used during the tuning process + tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus) + +Note that the usage of the ``mgpu_tune_gemm_in_file`` API is different from its single-GPU counterpart +(``tune_gemm_in_file``). The body of the Python script that calls the API must be wrapped in ``main()`` as shown, +due to the use of the concurrent futures module. The argument to ``mgpu_tune_gemm_in_file`` must contain a wildcard +expression (``?`` or ``*``) to generate the list of untuned files containing the GEMMs to be processed. ``num_gpus`` +must be between 1 and the total number of GPUs available. + +Tuning Context +============== + +The behavior of TunableOp is currently manipulated through environment +variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the +torch.cuda.tunable Python interfaces. The environment variables take precedence +over any setting you manipulate using the C++ or Python APIs. + +Environment Variable Interface +------------------------------ +Environment variables are cached the first time they are read. You cannot use the +environment variable interface programmatically since the settings become fixed. +Use the C++ or Python APIs instead.
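+
+Python Interface Example
+------------------------
+
+A minimal sketch of driving TunableOp through the Python API instead of
+environment variables (the filename below is illustrative):
+
+.. code-block:: python
+
+    import torch.cuda.tunable as tunable
+
+    tunable.enable(True)                    # master switch for TunableOp
+    tunable.tuning_enable(True)             # allow new tunings to be collected
+    tunable.set_filename("my_tunings.csv")  # where results are read and written
+    # ... run the workload ...
+    tunable.write_file()                    # flush results explicitly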
+ +""" +import concurrent.futures +import glob +import multiprocessing as mp +import os +import shutil +import warnings +from typing import Optional + +import torch + + +__all__ = [ + "enable", + "is_enabled", + "tuning_enable", + "tuning_is_enabled", + "record_untuned_enable", + "record_untuned_is_enabled", + "set_max_tuning_duration", + "get_max_tuning_duration", + "set_max_tuning_iterations", + "get_max_tuning_iterations", + "set_filename", + "get_filename", + "get_results", + "get_validators", + "write_file_on_exit", + "write_file", + "read_file", + "tune_gemm_in_file", + "mgpu_tune_gemm_in_file", + "set_rotating_buffer_size", + "get_rotating_buffer_size", +] + + +def enable(val: bool = True) -> None: + r"""This is the big on/off switch for all TunableOp implementations.""" + torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined] + + +def is_enabled() -> bool: + r"""Returns whether the TunableOp feature is enabled.""" + return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined] + + +def tuning_enable(val: bool = True) -> None: + r"""Enable tuning of TunableOp implementations. + + When enabled, if a tuned entry isn't found, run the tuning step and record + the entry. + """ + torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined] + + +def tuning_is_enabled() -> bool: + r"""Returns whether TunableOp implementations can be tuned.""" + return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined] + + +def record_untuned_enable(val: bool = True) -> None: + r"""Enable recording untuned of TunableOp perations for offline tuning. + + When enabled, if a tuned entry isn't found, write it to the untuned file. + """ + torch._C._cuda_record_untuned_enable(val) # type: ignore[attr-defined] + + +def record_untuned_is_enabled() -> bool: + r"""Returns whether TunableOp operations are recorded for offline tuning.""" + return torch._C._cuda_record_untuned_is_enabled() # type: ignore[attr-defined] + + +def set_max_tuning_duration(duration: int) -> None: + r"""Set max time in milliseconds to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined] + + +def get_max_tuning_duration() -> int: + r"""Get max time to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined] + + +def set_max_tuning_iterations(iterations: int) -> None: + r"""Set max number of iterations to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined] + + +def get_max_tuning_iterations() -> int: + r"""Get max iterations to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined] + + +def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: + r"""Set the filename to use for input/output of tuning results. + + If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal + will be added to the given filename automatically. This can be used in a + 1-process-per-gpu cenario to ensure all processes write to a separate file. 
+ """ + torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined] + + +def get_filename() -> str: + r"""Get the results filename.""" + return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined] + + +def get_results() -> tuple[str, str, str, float]: + r"""Return all TunableOp results.""" + return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined] + + +def get_validators() -> tuple[str, str]: + r"""Return the TunableOp validators.""" + return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined] + + +def write_file_on_exit(val: bool) -> None: + r"""During Tuning Context destruction, write file to disk. + + This is useful as a final flush of your results to disk if your application + terminates as result of normal operation or an error. Manual flushing of + your results can be achieved by manually calling ``write_file()``.""" + torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined] + + +def write_file(filename: Optional[str] = None) -> bool: + r"""Write results to a CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined] + + +def read_file(filename: Optional[str] = None) -> bool: + r"""Read results from a TunableOp CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined] + + +def set_rotating_buffer_size(buffer_size: int) -> None: + r"""Set rotating buffer size to this value in MB, if the buffer size is greater than zero. + + If less than zero, query L2 cache size. If equal to zero, means deactivate rotating buffer. + """ + return torch._C._cuda_tunableop_set_rotating_buffer_size(buffer_size) # type: ignore[attr-defined] + + +def get_rotating_buffer_size() -> int: + r"""Get the rotating buffer size in kilobytes.""" + return torch._C._cuda_tunableop_get_rotating_buffer_size() # type: ignore[attr-defined] + + +def tune_gemm_in_file(filename: str) -> None: + r"""tune GEMM in file.""" + + assert is_enabled() + assert tuning_is_enabled() + + deviceid = torch.cuda.current_device() + + with open(filename) as file: + for line in file: + if line.startswith(("Gemm", "ScaledGemm")): + _process_single_offline_gemm(line, deviceid) + + +def _gather_unique_untuned_gemm_from_files(filename_pattern: str) -> set[str]: + r"""Process multiple untuned results file and return a set with duplicates removed.""" + unique_gemm_entries = set() # set will avoid duplicates + + for file_path in glob.glob(filename_pattern): + with open(file_path) as file: + for line in file: + if line.startswith(("Gemm", "ScaledGemm")): + unique_gemm_entries.add(line) + + return unique_gemm_entries + + +def _gather_tunableop_results() -> None: + r"""Gather results from multiple tunableop results file and create a single file.""" + gemm_lines = set() + validator_lines = [] + + # Need to allow for the possibility that results filename was + # set with the Python API instead of with environment variable. + # Also possible that results filename was not set at all. 
+ # There are several test cases to check, but ultimately we + # need a glob-able expression + results_filename = get_filename() # Note empty string could be returned here + + if ( + results_filename is not None and results_filename != "" + ): # Case were the Python API was used to set the filename + dot_pos = results_filename.find(".") + if dot_pos != -1 and dot_pos > 0: + # Replace the character just to the left of the dot + filename_pattern = ( + results_filename[: dot_pos - 1] + "?" + results_filename[dot_pos:] + ) + else: + filename_pattern = "" # Needed to make linter happy + else: # Case where the environment variable was used to set the filename. + results_filename_env = os.getenv("PYTORCH_TUNABLEOP_FILENAME") + if results_filename_env is None or results_filename_env == "": + filename_pattern = "tunableop_results?.csv" + elif "%d" in results_filename_env: + filename_pattern = results_filename_env.replace("%d", "?") + else: + filename_pattern = results_filename_env.replace(".", "?.") + + assert "?" in filename_pattern + + FirstFile = False + matching_files = glob.glob(filename_pattern) + num_matching_files = len(matching_files) + for file_path in matching_files: + with open(file_path) as file: + for line in file: + if line.startswith("Validator"): + if not (FirstFile): + # Only read Validator from first file + validator_lines.append(line) + else: + gemm_lines.add(line) + + FirstFile = True + + output_file = filename_pattern.replace("?", "_full0") + + with open(output_file, "w") as out_file: + for line in validator_lines: + out_file.write(line) + for line in gemm_lines: + out_file.write(line) + + # Create num_matching_copies of the results file + for i in range(1, num_matching_files): + duplicate_file = output_file.replace("0", str(i)) + shutil.copy(output_file, duplicate_file) + + +def _create_matrices( + m: int, + n: int, + k: int, + lda: int, + ldb: int, + ldc: int, + transA: bool, + transB: bool, + dtypeA: torch.dtype, + deviceid: str, + dtypeB: Optional[torch.dtype] = None, + randn: bool = True, + subMatrix: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Helper function for _process_single_offline_gemm. + Creates matrices that are then consumed by one of the Torch GEMM APIs. + """ + # Fill parameters set for use with ScaledGEMM + fillA = 0.25 + fillB = 0.75 + + if dtypeB is None: + dtypeB = dtypeA + + if subMatrix: + # User reference for understanding leading dimension: + # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f + # TO DO: According to lines 108 - 133, there is no lower bound on rowsA, + # but there is a restriction on rowsB. Using this formula for now as it + # seems to work for all UTs. 
+ rowsA = rowsB = max(ldc, k) + + if randn: + matA = torch.randn(rowsA, lda, dtype=dtypeA, device=deviceid) + matB = torch.randn(rowsB, ldb, dtype=dtypeA, device=deviceid) + else: + matA = torch.full((rowsA, lda), fillA, dtype=dtypeB, device=deviceid) + matB = torch.full((rowsB, ldb), fillB, dtype=dtypeB, device=deviceid) + + subA = matA[:k, :m].t() if transA else matA[:m, :k] + subB = matB[:n, :k].t() if transB else matB[:k, :n] + return subA, subB + else: + if randn: + matA = ( + torch.rand(k, m, dtype=dtypeA, device=deviceid).t() + if transA + else torch.rand(m, k, dtype=dtypeA, device=deviceid) + ) + matB = ( + torch.rand(n, k, dtype=dtypeB, device=deviceid).t() + if transB + else torch.rand(k, n, dtype=dtypeB, device=deviceid) + ) + else: + matA = ( + torch.full((k, m), fillA, dtype=dtypeA, device=deviceid).t() + if transA + else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid) + ) + matB = ( + torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t() + if transB + else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid) + ) + return matA, matB + + +def _create_batch_matrices( + m: int, + n: int, + k: int, + b: int, + lda: int, + ldb: int, + ldc: int, + transA: bool, + transB: bool, + dtype: torch.dtype, + deviceid: str, + subMatrix: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Helper function for _process_single_offline_gemm. + Creates batch matrices that are then consumed by one of the Torch GEMM APIs. + Similar to _create_matrices but for 3D batch matrices. + """ + if subMatrix: + # User reference for understanding leading dimension: + # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f + # TO DO: According to lines 108 - 133, there is no lower bound on rowsA, + # but there is a restriction on rowsB. Using this formula for now as it + # seems to work for all UTs. 
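+        # Same sub-matrix layout trick as in _create_matrices, with a leading
+        # batch dimension of size b prepended to both operands below.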
+ rowsA = rowsB = max(ldc, k) + + matA = torch.randn(b, rowsA, lda, dtype=dtype, device=deviceid) + matB = torch.randn(b, rowsB, ldb, dtype=dtype, device=deviceid) + + subA = matA[:b, :k, :m].transpose(1, 2) if transA else matA[:b, :m, :k] + subB = matB[:b, :n, :k].transpose(1, 2) if transB else matB[:b, :k, :n] + return subA, subB + else: + matA = ( + torch.rand(b, k, m, dtype=dtype, device=deviceid) + if transA + else torch.rand(b, m, k, dtype=dtype, device=deviceid) + ) + matB = ( + torch.rand(b, n, k, dtype=dtype, device=deviceid) + if transB + else torch.rand(b, k, n, dtype=dtype, device=deviceid) + ) + matA = matA.transpose(1, 2) if transA else matA + matB = matB.transpose(1, 2) if transB else matB + return matA, matB + + +def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None: + r"""Process a single untuned GEMM.""" + + deviceid = "cuda:" + str(gpu_id) + + dtype_dict = { + "float": torch.float32, + "tf32": torch.float32, + "double": torch.float64, + "BFloat16": torch.bfloat16, + "Half": torch.half, + "c10::complex": torch.complex128, + "c10::complex": torch.complex64, + "Float8_e4m3fn": torch.float8_e4m3fn, + "Float8_e5m2": torch.float8_e5m2, + "Float8_e4m3fnuz": torch.float8_e4m3fnuz, + "Float8_e5m2fnuz": torch.float8_e5m2fnuz, + } + + untuned_gemm = untuned_gemm_line.strip().split(",")[:] + + underscore_count = untuned_gemm[0].count("_") + + # Initialize dtype to make linter happy + dtype = None + dtypeA = None + dtypeB = None + dtypeC = None + + # Extract BLAS parameters + if underscore_count == 2: + [op_sig, data_type, layout] = untuned_gemm[0].split("_") + transB = layout[0] == "T" + transA = layout[1] == "T" + dtype = dtype_dict.get(data_type) + if data_type == "tf32": + # User must still set HIPBLASLT_ALLOW_TF32=1 + torch.backends.cuda.matmul.allow_tf32 = True + else: + torch.backends.cuda.matmul.allow_tf32 = False + + else: # ScaledGEMM + count = untuned_gemm[0].count("_") + assert count in [6, 7] + untuned_gemm_temp = untuned_gemm[0].split("_") + # dtypeC = might not be FP8 type, keep track + # of the the number of underscores + op_sig = untuned_gemm_temp[0] + data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2] + data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4] + if count == 7: + data_typeC = untuned_gemm_temp[5] + "_" + untuned_gemm_temp[6] + else: + data_typeC = untuned_gemm_temp[5] + transB = untuned_gemm_temp[count][0] == "T" + transA = untuned_gemm_temp[count][1] == "T" + dtypeA = dtype_dict.get(data_typeA) + dtypeB = dtype_dict.get(data_typeB) + dtypeC = dtype_dict.get(data_typeC) + + untuned_gemm_temp = untuned_gemm[1].split("_") + [n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]] + if op_sig == "GemmStridedBatchedTunableOp": + assert untuned_gemm_temp[6] == "ld" + [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]] + else: + assert untuned_gemm_temp[4] == "ld" + [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]] + + # Detect subMatrix case + if all(item in [n, m, k] for item in [lda, ldb, ldc]): + subMatrix = False + else: + subMatrix = True + + if op_sig == "GemmTunableOp": + # Warnings for unsupported cases: + if m == 1 or n == 1 or k == 1: + if (not transA) and (not transB): + pass # case is supported + elif transA and n == 1: + pass # case is supported + else: + warnings.warn( + "Offline tuning is not supported for this GEMM. Use online tuning instead. 
" + + f"Skipped tuning for: {untuned_gemm[1]}" + ) + return + + # Resolve linter issue + if dtype is None or not isinstance(dtype, torch.dtype): + raise TypeError(f"dtype must be a torch.dtype, but got {dtype}") + + matA, matB = _create_matrices( + m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix + ) + torch.mm(matA, matB) + + elif op_sig == "GemmStridedBatchedTunableOp": + # Warnings for unsupported cases: + if m == 1 or n == 1 or k == 1: + warnings.warn( + "Offline tuning is not support for this GEMM. Use online tuning instead. " + + f"Skipped tuning for: {untuned_gemm[1]}" + ) + return + + [b] = [int(g) for g in untuned_gemm_temp[5:6]] + + # Resolve linter issue + if dtype is None or not isinstance(dtype, torch.dtype): + raise TypeError(f"dtype must be a torch.dtype, but got {dtype}") + + matA, matB = _create_batch_matrices( + m, + n, + k, + b, + lda, + ldb, + ldc, + transA, + transB, + dtype, + deviceid, + subMatrix=subMatrix, + ) + torch.bmm(matA, matB) + elif op_sig == "ScaledGemmTunableOp": + # Only combination supported by PyTorch + assert transB is True + assert transA is False + + # Resolve linter issue + if dtypeA is None or not isinstance(dtypeA, torch.dtype): + raise TypeError(f"dtype must be a torch.dtype, but got {dtypeA}") + + matA, matB = _create_matrices( + m, + n, + k, + lda, + ldb, + ldc, + transA, + transB, + dtypeA, + deviceid, + dtypeB=dtypeB, + randn=False, + subMatrix=subMatrix, + ) + + assert untuned_gemm_temp[8] == "rw" + if untuned_gemm_temp[9] == "1": + rowwise = True + else: + rowwise = False + if rowwise: + scaleA = ( + torch.ones((1, m), device=deviceid) + if transA + else torch.ones((m, 1), device=deviceid) + ) + scaleB = ( + torch.ones((1, n), device=deviceid) + if transB + else torch.ones((n, 1), device=deviceid) + ) + else: + scaleA = torch.tensor(0.8, device=deviceid) + scaleB = torch.tensor(0.9, device=deviceid) + + assert untuned_gemm_temp[10] == "bias" + if untuned_gemm_temp[11] == "None": # no bias vector + torch._scaled_mm( + matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC + ) + else: # bias vector present + fillbias = 0.10 + bias_dtype = dtype_dict.get(untuned_gemm_temp[11]) + bias = ( + torch.full((n,), fillbias, dtype=bias_dtype, device=deviceid) + if transB + else torch.full((m,), fillbias, dtype=bias_dtype, device=deviceid) + ) + torch._scaled_mm( + matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC, bias=bias + ) + + elif op_sig == "GemmAndBiasTunableOp": + # y = x*A^T + b + assert transA != transB + + # Resolve linter issue + if dtype is None or not isinstance(dtype, torch.dtype): + raise TypeError(f"dtype must be a torch.dtype, but got {dtype}") + + bias = torch.rand(n, dtype=dtype, device=deviceid) + + X, matA = _create_matrices( + m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix + ) + matA = matA.t() + torch.nn.functional.linear(X, matA, bias) + else: + warnings.warn(f"error: unknown op {op_sig}") + + +def _check_tuning_assertions() -> None: + r"""Helper function for multi-GPU tuning case. Need to check that TunableOp feature + is enabled and that tuning is enabled. + """ + + if is_enabled() is False: + warnings.warn("TunableOp was disabled. 
Trying to enable now.") + enable(True) + assert is_enabled() is True + assert tuning_is_enabled() is True + assert record_untuned_is_enabled() is False + + +def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: + r"""Process one or more files and distribute work over one or more GPUs.""" + unique_gemm_entries = _gather_unique_untuned_gemm_from_files(filename_pattern) + + total_gpus = torch.cuda.device_count() + + assert 1 <= num_gpus <= total_gpus + + mp_context = mp.get_context("spawn") + + futures = [] # empty list to hold futures + flush_results = [] # empty list to hold futures + + # GEMM are assigned to GPUs in a round robin manner + h = 0 + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_gpus, + mp_context=mp_context, + initializer=_check_tuning_assertions, + ) as executor: + # The workers are a separate process. TunableOp will be + # enabled in the child processes if PYTORCH_TUNABLEOP_ENABLED=1 + # In the initializer, we also try to enable TunableOP if th + # environment variable was NOT set. + + for line in unique_gemm_entries: + future = executor.submit(_process_single_offline_gemm, line, h) + futures.append(future) + h = (h + 1) % num_gpus + + for future in concurrent.futures.as_completed(futures): + future.result() + + for g in range(num_gpus): + flush_result = executor.submit(write_file) + flush_results.append(flush_result) + + for flush_result in concurrent.futures.as_completed(flush_results): + flush_result.result() + + torch.cuda.synchronize() + + _gather_tunableop_results() diff --git a/phivenv/Lib/site-packages/torch/distributed/__init__.py b/phivenv/Lib/site-packages/torch/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e95b14e7373bb752d4dbd4d72bb0ffce6a3170cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/__init__.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +import logging +import pdb +import sys +import traceback +import typing + +import torch + + +log = logging.getLogger(__name__) + + +def is_available() -> bool: + """ + Return ``True`` if the distributed package is available. + + Otherwise, + ``torch.distributed`` does not expose any other APIs. Currently, + ``torch.distributed`` is available on Linux, MacOS and Windows. Set + ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. + Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, + ``USE_DISTRIBUTED=0`` for MacOS. 
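+
+    A minimal illustrative guard::
+
+        >>> import torch.distributed as dist
+        >>> if dist.is_available():
+        ...     print("torch.distributed can be used on this build")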
+ """ + return hasattr(torch._C, "_c10d_init") + + +if is_available() and not torch._C._c10d_init(): + raise RuntimeError("Failed to initialize torch.distributed") + +# Custom Runtime Errors thrown from the distributed package +DistError = torch._C._DistError +DistBackendError = torch._C._DistBackendError +DistNetworkError = torch._C._DistNetworkError +DistStoreError = torch._C._DistStoreError +QueueEmptyError = torch._C._DistQueueEmptyError + +if is_available(): + from torch._C._distributed_c10d import ( + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, + _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, + DebugLevel, + FileStore, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work as _Work, + ) + + class _DistributedPdb(pdb.Pdb): + """ + Supports using PDB from inside a multiprocessing child process. + + Usage: + _DistributedPdb().set_trace() + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + _breakpoint_cache: dict[int, typing.Any] = {} + + def breakpoint(rank: int = 0, skip: int = 0): + """ + Set a breakpoint, but only on a single rank. All other ranks will wait for you to be + done with the breakpoint before continuing. + + Args: + rank (int): Which rank to break on. Default: ``0`` + skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. + """ + if skip > 0: + key = hash(str(traceback.format_exc())) + counter = _breakpoint_cache.get(key, 0) + 1 + _breakpoint_cache[key] = counter + if counter <= skip: + log.warning("Skip the breakpoint, counter=%d", counter) + return + + if get_rank() == rank: + pdb = _DistributedPdb() + pdb.message( + "\n!!! ATTENTION !!!\n\n" + f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" + ) + pdb.set_trace() + # If Meta/Python keys are in the TLS, we want to make sure that we ignore them + # and hit the (default) CPU/CUDA implementation of barrier. + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + torch._C._set_meta_in_tls_dispatch_include(False) + try: + barrier() + finally: + torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) + del guard + + if sys.platform != "win32": + from torch._C._distributed_c10d import HashStore + + from .device_mesh import DeviceMesh, init_device_mesh + + # Variables prefixed with underscore are not auto imported + # See the comment in `distributed_c10d.py` above `_backend` on why we expose + # this. + from .distributed_c10d import * # noqa: F403 + from .distributed_c10d import ( + _all_gather_base, + _coalescing_manager, + _CoalescingManager, + _create_process_group_wrapper, + _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, + _time_estimator, + get_node_local_rank, + ) + from .remote_device import _remote_device + from .rendezvous import ( + _create_store_from_options, + register_rendezvous_handler, + rendezvous, + ) + + set_debug_level_from_env() + +else: + # This stub is sufficient to get + # python test/test_public_bindings.py -k test_correct_module_names + # working even when USE_DISTRIBUTED=0. 
Feel free to add more + # stubs as necessary. + # We cannot define stubs directly because they confuse pyre + + class _ProcessGroupStub: + pass + + sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] diff --git a/phivenv/Lib/site-packages/torch/distributed/_checkpointable.py b/phivenv/Lib/site-packages/torch/distributed/_checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..01afad0b0746b09694d987873457910140acff75 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_checkpointable.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing_extensions import Protocol, runtime_checkable + +import torch + + +@runtime_checkable +class _Checkpointable(Protocol): # noqa: PYI046 + """ + Interface for checkpointable objects. + Implemented as a protocol, implicit subtyping is supported so subclasses do not need to inherit this explicitly. + This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface. + """ + + def __create_write_items__(self, fqn: str, object: object) -> list[object]: + """ + Return a list of WriteItems based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_write_items is not implemented" + ) + + def __create_chunk_list__(self) -> list[object]: + """ + Return a list of `ChunkStorageMetadata` based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_chunk_list is not implemented" + ) + + def __get_tensor_shard__(self, index: int) -> torch.Tensor: + """ + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + raise NotImplementedError( + "_Checkpointable._get_tensor_shard is not implemented" + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable_state.py b/phivenv/Lib/site-packages/torch/distributed/_composable_state.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2fd104e8c7a2d54f0737ee64ff58e932cb9527 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable_state.py @@ -0,0 +1,44 @@ +import weakref +from typing import cast, Optional + +import torch.nn as nn + + +class _State: + pass + + +_module_state_mapping: weakref.WeakKeyDictionary[ + nn.Module, weakref.ReferenceType[_State] +] = weakref.WeakKeyDictionary() + + +def _insert_module_state(module: nn.Module, state: _State) -> None: + global _module_state_mapping + assert module not in _module_state_mapping, f"Inserting {module} more than once." + _module_state_mapping[module] = weakref.ref(state) + + +def _get_module_state(module: nn.Module) -> Optional[_State]: + """ + Return the ``_State`` in ``model``. + + Given a ``module``, this API finds out if the module is also a ``_State`` + instance or if the module is managed by a composable API. If the module + is also a ``_State``, ``module`` will be casted to ``_State` and returned. + If it is managed by a composable API, the corresponding ``_State`` will + be returned. 
+ """ + global _module_state_mapping + if isinstance(module, _State): + return cast(_State, module) + else: + # https://github.com/pytorch/pytorch/issues/107054 + if module in _module_state_mapping: + state_ref = _module_state_mapping[module] + state = state_ref() + if state is None: + raise AssertionError("State has already been garbage collected") + return state + else: + return None diff --git a/phivenv/Lib/site-packages/torch/distributed/_functional_collectives.py b/phivenv/Lib/site-packages/torch/distributed/_functional_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f4368de806890e179579eafafb9df5845e57cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_functional_collectives.py @@ -0,0 +1,1204 @@ +# mypy: allow-untyped-defs +import contextlib +import sys +import warnings +from typing import Any, cast, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d +from torch.distributed.device_mesh import DeviceMesh +from torch.fx.experimental.proxy_tensor import get_proxy_mode + +from . import _functional_collectives_impl as fun_col_impl + + +try: + from torch.utils._cxx_pytree import tree_map_only +except ImportError: + from torch.utils._pytree import tree_map_only # type: ignore[no-redef] + + +if torch._running_with_deploy(): + + def is_torchdynamo_compiling(): + """Can't import torchdynamo in torchdeploy builds currently.""" + return False + +else: + try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling + except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" + ) + + def is_torchdynamo_compiling(): + return False + + +""" +New traceable, functional collectives. +RFC: https://github.com/pytorch/pytorch/issues/93173 + + compiler: trace these ops with plain-old-data schemas, then choose how to lower them. + eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses, + automatically calling .wait() on underlying/hidden async 'work' obj only when fed to + a downstream op. + +Issues: +* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files +* Proper support for eager requires inplace ops. We should explore having it as an option for the API. +""" + +""" +Functional collectives are asynchronous only and we perform implicit stream synchronization +on behalf of the user. + +We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness +first usage of the tensor and insert cross stream sync at the right place. + +The above are the easy bits, the hard one is how we match the Work object returned by +c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective +op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the +dispatcher which might call other implementations that are allowed to change the returned +tensor - even return a tensor with a different shape (see ``torch.vmap``). + +This means the caller of our ops receives a Tensor that is not guaranteed to be the same +allocated by our implementations and that makes pairing The AsyncTensor to the original +tensor a lot harder. This pairing is needed so we can lookup the Work object to use. 
+ +Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's +identity is not stable across dispatch, the op caller would end up with a different Tensor +instance that would not match any in the dictionary. + +With Tensor identity out of the question, we decided use the tensor data pointer, which +should be stable across all the Tensor changes done during dispatch. + +We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d. + +We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait() + +Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we +can clean up stale entries in the dictionary. + +To eliminate the possibility of races we have a global version counter that is used by the finalizer. + +As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo) + +""" + +""" +Functional collectives can accept any of these types to describe the ranks participating in collectives. + +The different types will be desugared to a canonical format +""" +RANK_TYPES = Union[ + list[int], + list[list[int]], + dist.ProcessGroup, + DeviceMesh, + tuple["dist.tensor.DeviceMesh", int], + str, +] + + +""" +User facing APIs for functional collectives +------------------------------------------- + +These apis are called by user code and expected to work both in eager execution and compilation, +but there are significant differences to how the two modes are implemented underneath. + +Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op) +just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization, +and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified +if sufficient subclass support is added in dynamo. + +Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern. + +Here's how it works under torch.compile/dynamo: +all_reduce(...) + |--> _expand_group(...) - desugars processgroup into canonical/traceable format + |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper + |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed + +And under eager execution: +all_reduce(...) + |--> _expand_group(...) - same as above, but less critical for eager + |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace + |--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor, + which issues wait_tensor() at the time of first use +""" + + +def wait_tensor(tensor): + """ + Wait on a tensor returned by the collectives ops. + + Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA. + """ + return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] + + +def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""): + """ + Broadcasts the tensor to all processes in the given process group. + + Args: + src (int): Source rank + group (ProcessGroup or List[int]): The process group to work on. + tag (str, optional): A unique identifier for the collective. 
Default: empty string + """ + group_name = _resolve_group_name(group, tag) + tensor = torch.ops._c10d_functional.broadcast(self, src, group_name) + return _maybe_wrap_tensor(tensor) + + +def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""): + """ + Reduces the tensor data across all machines in such a way that all get + the final result. + + The input tensor is left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) + return _maybe_wrap_tensor(tensor) + + +def all_gather_tensor( + self: torch.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + assert self.is_contiguous() + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor = torch.ops._c10d_functional.all_gather_into_tensor( + self, group_size, group_name + ) + res = _maybe_wrap_tensor(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # torch.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def all_gather_tensor_autograd( + self: torch.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + This function is the same as all_gather_tensor but will propagate the + backwards gradient across workers. + + See all_gather_tensor for more details on usage. 
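+
+    A minimal sketch (assumes an initialized process group; ``dist.group.WORLD``
+    is used purely for illustration)::
+
+        >>> x = torch.randn(4, 8, requires_grad=True)
+        >>> out = all_gather_tensor_autograd(x, 0, dist.group.WORLD)
+        >>> out.sum().backward()  # gradient flows back to ``x`` on each rank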
+ """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor( + self, group_size, group_name + ) + res = _FromTorchTensor.apply(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # torch.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def reduce_scatter_tensor( + self: torch.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert self.size(scatter_dim) % group_size == 0, ( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" + ) + if scatter_dim != 0: + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) + self = torch.cat(tensor_list) + + tensor = torch.ops._c10d_functional.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _maybe_wrap_tensor(tensor) + return res + + +def reduce_scatter_tensor_autograd( + self: torch.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + This function is the same as reduce_scatter_tensor but will propagate the + backwards gradient across workers. + + Currently only the "sum" reduceOp is supported. + + See reduce_scatter_tensor for more details on usage. + """ + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert self.size(scatter_dim) % group_size == 0, ( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + ) + if scatter_dim != 0: + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) + self = torch.cat(tensor_list) + + tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _FromTorchTensor.apply(tensor) + return res + + +def all_reduce_coalesced( + self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" +) -> list[torch.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result. 
+ + The all tensors in the input list are left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] + self, + reduceOp.lower(), + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def all_gather_into_tensor_coalesced( + self: list[torch.Tensor], group: RANK_TYPES, tag: str = "" +) -> list[torch.Tensor]: + """ + Gather a list of tensors across from all machines. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] + self, + group_size, + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def reduce_scatter_tensor_coalesced( + inputs: list[torch.Tensor], + reduceOp: str, + scatter_dim: list[int], + group: RANK_TYPES, + tag: str = "", +) -> list[torch.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + The input tensors are left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
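+
+    A minimal sketch (assumes a group of 2 ranks and dims divisible by 2)::
+
+        >>> a, b = torch.ones(4, 4), torch.ones(8, 2)
+        >>> outs = reduce_scatter_tensor_coalesced([a, b], "sum", [0, 0], dist.group.WORLD)
+        >>> [tuple(o.shape) for o in outs]  # each rank keeps only its shard
+        [(2, 4), (4, 2)]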
+ """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert len(scatter_dim) == len(inputs) + for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): + assert tensor.size(dim) % group_size == 0, ( + f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + ) + if dim != 0: + tensor_list = torch.chunk(tensor, group_size, dim=dim) + inputs[idx] = torch.cat(tensor_list) + + tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined] + inputs, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + + return list(map(_maybe_wrap_tensor, tensor_list)) + + +# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. +# Today, this maps 1:1 with "aten ops that are views". +def _is_view_op(tgt): + assert isinstance(tgt, torch._ops.OpOverload) + # Don't apply the view optimization to any `CompositeImplicitAutograd` ops. + # See issue: https://github.com/pytorch/pytorch/issues/133421 + if torch._C._dispatch_has_kernel_for_dispatch_key( + tgt.name(), torch.DispatchKey.CompositeImplicitAutograd + ): + return False + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + + +def all_to_all_single( + self: torch.Tensor, + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Each process splits input tensor and then scatters the split list + to all processes in a group. Then concatenate the received tensors from all + the processes in the group and return single output tensor. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
+ """ + if output_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in output_split_sizes + ), output_split_sizes + if input_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in input_split_sizes + ), input_split_sizes + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _maybe_wrap_tensor(tensor) + + +def all_to_all_single_autograd( + self: torch.Tensor, + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Same as all_to_all_single but supports autograd. + """ + if output_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in output_split_sizes + ), output_split_sizes + if input_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in input_split_sizes + ), input_split_sizes + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _FromTorchTensor.apply(tensor) + + +def permute_tensor( + self: torch.Tensor, + src_dst: list[int], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should + be defined such that src_dst[m] == n means m sends to n. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one + """ + t, rankset, group_size = _expand_group(group, tag) + local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size) + + output_split_sizes = [0] * group_size + input_split_sizes = [0] * group_size + for src, dst in enumerate(src_dst): + if src == dist.get_rank(local_pg): + input_split_sizes[dst] = self.numel() + if dst == dist.get_rank(local_pg): + output_split_sizes[src] = self.numel() + + return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag) + + +class AsyncCollectiveTensor(torch.Tensor): + r""" + A Tensor wrapper subclass that is used to trigger a call to wait + prior to first use of the underlying tensor. 
+ Use it inside functional collective pytorch wrappers like the following: + def functional_collective(self, group, tag): + tag, rankset, group_size = _expand_group(group, tag) + tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size) + return _maybe_wrap_tensor(tensor) + """ + + elem: torch.Tensor + completed: bool + + __slots__ = ["elem", "completed"] + + @staticmethod + def __new__(cls, elem: torch.Tensor): + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad, + ) + r.elem = elem + r.completed = False + return r + + def __tensor_flatten__(self): + return ["elem"], None + + def tolist(self): + return self.trigger_wait().tolist() + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + elem = inner_tensors["elem"] + return AsyncCollectiveTensor(elem) + + def __coerce_same_metadata_as_tangent__( + self, expected_metadata: Any, expected_type: Optional[type] = None + ): + if expected_type is not torch.Tensor: + return None + + return self.trigger_wait() + + def __repr__(self) -> str: # type: ignore[override] + return f"AsyncCollectiveTensor({self.trigger_wait()})" + + def trigger_wait(self): + if not self.completed: + out = wait_tensor(self.elem) + self.completed = True + return out + else: + return self.elem + + def wait(self) -> torch.Tensor: + return wait_tensor(self.elem) + + def _get_acs_underlying_tensor(self): + """This method enables _functional_collectives_impl to test if a tensor is an ACS""" + return self.elem + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + if func == torch.ops.aten.view.default: + # Fast handle aten.view as a lot of view related op goes to aten.view + # eventually, this avoids pytree slowdown + res = func(args[0].elem, args[1]) + wrapper_res = AsyncCollectiveTensor(res) + return wrapper_res + + is_view_op = _is_view_op(func) + + def unwrap(e: AsyncCollectiveTensor): + # wait_tensor is idepotent and will do stream sync only once + if not is_view_op: + return e.trigger_wait() + return e.elem + + def wrap(e: torch.Tensor): + # wait_tensor is idepotent and will do stream sync only once + assert not isinstance(e, AsyncCollectiveTensor) + res = AsyncCollectiveTensor(e) + return res + + unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args) + unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs) + + # we don't wrap the result as it doesn't need to be waited on. + out = func(*unwrapped_args, **unwrapped_kwargs) + + # View ops dont require a sync, so we should re-wrap the outputs. + if is_view_op: + out = tree_map_only(torch.Tensor, wrap, out) + + return out + + def numpy(self): # type: ignore[override] + return self.wait().numpy() + + +""" +Utils and infrastructure for tracing support +""" + + +def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]: + """ + _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable. + + By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside + torchdynamo and can still interoperate with processgroup objects or other untraceable forms. 
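+
+    For illustration, the plain list-of-ranks form desugars as follows (hypothetical ranks)::
+
+        >>> _expand_group([0, 1, 2, 3])
+        ('', [0, 1, 2, 3], 4)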
+ """ + # had to define this hack _inside_ expand_group to avoid + # graph_break [('torch.* op returned non-Tensor int + # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc) + if TYPE_CHECKING: + + def cast_listlistint(x): + return cast(list[list[int]], x) + + def cast_listint(x): + return cast(list[int], x) + + else: + # fake cast op for use at runtime since dynamo doesn't support real cast + # also, dynamo didn't like encountering 'typing' objects () + # NotImplementedError: argument of type: + def cast_listlistint(x): + return x + + def cast_listint(x): + return x + + rankset: list[int] + if isinstance(group, list): + if isinstance(group[0], list): + nested_list = cast_listlistint(group) + rankset = [] + group_size = -1 + for rs in nested_list: + rankset.extend(rs) + if group_size != -1 and group_size != len(rs): + raise ValueError( + f"group sizes must be identical found {group_size} and {len(rs)}" + ) + group_size = len(rs) + else: + rankset = cast_listint(group) + group_size = len(rankset) + elif isinstance(group, dist.ProcessGroup): + rankset = dist.get_process_group_ranks(group) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(group) + elif isinstance(group, DeviceMesh): + assert group.ndim == 1, ( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) + # TODO: it should run collective in the whole mesh instead of dim 0 + pg = group.get_group() + rankset = dist.get_process_group_ranks(pg) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(pg) + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + pg = dmesh.get_group(dim) + rankset = dist.get_process_group_ranks(pg) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(pg) + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + else: + raise ValueError( + "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)." + ) + + return (tag, rankset, group_size) + + +def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: + """ + Given group in RANK_TYPES, return the group name. + """ + # `tag` will be deprecated. See details in: + # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 + if isinstance(group, dist.ProcessGroup): + return group.group_name + elif isinstance(group, str): + return group + elif isinstance(group, DeviceMesh): + assert group.ndim == 1, ( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) + return group._dim_group_names[0] + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + return dmesh._dim_group_names[dim] + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + elif isinstance(group, list): + if not is_torchdynamo_compiling(): + warnings.warn( + "The combination of ranks + tag as process group " + "identifier has been deprecated. Please switch to " + "using ProcessGroup, DeviceMesh, or group name instead.", + FutureWarning, + stacklevel=3, + ) + return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag) + else: + raise ValueError(f"Unsupported group type: {type(group)}, {group}") + + +class _FromTorchTensor(torch.autograd.Function): + """ + _FromTorchTensor allows autograd to propagate from a normal Tensor to an + AsyncCollectiveTensor. 
+ """ + + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + ) -> torch.Tensor: + return _maybe_wrap_tensor(input) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override] + return grad_output + + +def _are_we_tracing() -> bool: + if is_torchdynamo_compiling(): + return True + # If fake mode is turned on, we are almost definitely compiling/tracing. + if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None: + return True + return get_proxy_mode() is not None + + +def _maybe_wrap_tensor(self) -> torch.Tensor: + if _are_we_tracing(): + return wait_tensor(self) + res = AsyncCollectiveTensor(self) + return cast(torch.Tensor, res) + + +@contextlib.contextmanager +def allow_inflight_collective_as_graph_input_ctx(value: bool = True): + """ + Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs. + Common use case is when the collective is issued in eager (with `async_op=True`) but waited in compiled region: + ``` + def all_reduce_eager(x): + y = x * x + req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) + return y + + + @torch.compile(fullgraph=True) + def all_reduce_wait_compiled(y): + torch.ops.c10d_functional.wait_tensor(y) + return y * y + + + x = torch.ones(1280, 1280, device="cuda") + self.rank + # the context manager ensures that `wait_tensor(y)` will wait on the correct work object + with allow_inflight_collective_as_graph_input_ctx(): + y = all_reduce_eager(x) + z = all_reduce_wait_compiled(y) + ``` + With this context manager, when a collective is called, under the hood the work object of the collective + will be registered in the work registry, and the wait_tensor() in compiled region called on + the output tensor of the collective will wait on the correct work object. 
+ """ + previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input() + + try: + torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value) + yield + finally: + torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input( + previous + ) + + +def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): + def mk_out_tensor(shard): + out_size = list(shard.size()) + out_size[0] *= group_size + out_tensor = shard.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in self] + + +# We now register meta kernels to deal with tracing +def _broadcast_meta(self, *args): + return torch.empty_like(self) + + +def _all_reduce_meta(self, *args): + return torch.empty_like(self) + + +def _wait_tensor_meta(self, *args): + return torch.empty_like(self) + + +def _all_gather_into_tensor_meta(shard, tag, rankset, group_size): + out_size = list(shard.size()) + out_size[0] *= group_size + return shard.new_empty(out_size) + + +def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size): + out_size = list(input.size()) + out_size[0] //= group_size + return input.new_empty(out_size) + + +def _all_reduce_coalesced_meta(self, *args): + return [torch.empty_like(t) for t in self] + + +def _all_reduce__meta(inp, *args): + return inp + + +def _broadcast__meta(inp, *args): + return inp + + +def _all_reduce_coalesced__meta(inputs, *args): + return inputs + + +def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size): + def mk_out_tensor(input): + out_size = list(input.size()) + out_size[0] //= group_size + out_tensor = input.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in inputs] + + +# NB: We often say all_to_all has dynamic output size, but this is not +# technically true: instead, what typically happens is you manually +# communicate the output_split_sizes ahead of time (which is dynamic), +# but then you pass those sizes explicitly, and the all to all itself +# isn't dynamic, it just follows the specified output splits +def _all_to_all_single_meta( + input, output_split_sizes, input_split_sizes, *args, **kwargs +): + if output_split_sizes is None: + return input.new_empty(input.size()) + else: + for s in output_split_sizes: + torch._check_is_size(s) + out_size = list(input.size()) + out_size[0] = sum(output_split_sizes) + return input.new_empty(out_size) + + +def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_native_meta(input, group_size, group_name): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name): + return [ + _all_gather_into_tensor_native_meta(input, group_size, group_name) + for input in inputs + ] + + +def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name): + shape = list(inp.size()) + shape[0] //= group_size + return inp.new_empty(shape) + + +def _reduce_scatter_tensor_coalesced_native_meta( + inputs, reduce_op, group_size, group_name +): + return [ + _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name) + for inp in inputs + ] + + +if not torch._running_with_deploy(): + # Library MUST be defined at module scope or it doesn't work + # Creating a "DEF" Library always crashes torch::deploy so we create our + # Library instances here 
guarded against running inside it + lib_impl = torch.library.Library("_c10d_functional", "IMPL") + lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") + lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") + lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") + lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") + lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" + ) + lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") + lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") + lib_impl.impl("broadcast", _broadcast_meta, "Meta") + lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + + # mark these ops has side effect so that they won't be removed by DCE + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) + + # Register legacy ops for backward compatibility + # TODO(yifu): remove these in functional collective beta release + legacy_lib = torch.library.Library("c10d_functional", "DEF") + legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") + ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 + ] + + my_module = sys.modules[__name__] + for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") + +else: + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." + ) + + +""" +Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into +functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph. + +We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via +the mapping dict below. 
+ +These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from +""" + + +def all_gather_tensor_inplace( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group=None, # TODO add a type, + async_op: bool = False, + tag: str = "", + gather_dim: int = 0, +): + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + assert group is not None + + return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) + + +def reduce_scatter_tensor_inplace( + output: torch.Tensor, + input: torch.Tensor, + op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok? + group=None, # TODO add a type + async_op: bool = False, + scatter_dim: int = 0, + tag: str = "", +): + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + assert group is not None + + return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) + + +REDUCE_OP_TO_STR = { + dist.ReduceOp.SUM: "sum", + dist.ReduceOp.AVG: "avg", + dist.ReduceOp.PRODUCT: "product", + dist.ReduceOp.MIN: "min", + dist.ReduceOp.MAX: "max", + dist.ReduceOp.BAND: "band", + dist.ReduceOp.BOR: "bor", + dist.ReduceOp.BXOR: "bxor", +} + + +def all_reduce_inplace( + tensor: torch.Tensor, + op: str = "sum", + group=None, + async_op: bool = False, + tag: str = "", +): + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + assert group is not None + + return tensor.copy_(all_reduce(tensor, op, group, tag)) + + +def all_to_all_inplace( + output: torch.Tensor, + input: torch.Tensor, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, + tag: str = "", +): + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + assert group is not None + + return output.copy_( + all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group, + tag, + ) + ) + + +def all_gather_inplace( + tensor_list: list[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op=False, + tag: str = "", +): + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) + assert all(t.size(0) == tensor.size(0) for t in tensor_list), ( + "Remapping variable size all_gather is not yet supported" + ) + + group = group or dist.group.WORLD + assert group is not None + + output = all_gather_tensor(tensor, 0, group, tag) + + # Use aten.slice instead of aten.split because the latter causes + # tensor.shape(0) to be unnecessarily baked in when it's a SymInt. + output_splits = [] + offset = 0 + for t in tensor_list: + output_splits.append(output[offset : offset + t.size(0)]) + offset += t.size(0) + for dst, src in zip(tensor_list, output_splits): + dst.copy_(src) + return tensor_list + + +from torch.distributed.distributed_c10d import ( + _all_gather_base as legacy_all_gather_base, + _reduce_scatter_base as legacy_reduce_scatter_base, + all_gather as legacy_all_gather, + all_gather_into_tensor as legacy_allgather, + all_reduce as legacy_allreduce, + all_to_all_single as legacy_all_to_all_single, + reduce_scatter_tensor as legacy_reducescatter, +) + + +# This dict should contain sets of functions that dynamo is allowed to remap. +# Functions in this set should accept the same args/kwargs 1:1 as their mapping. 
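+# Illustrative flow: when dynamo traces a call such as ``dist.all_reduce(t)``, the
+# table below redirects it to ``all_reduce_inplace`` defined above, which lowers to
+# the functional ``all_reduce`` followed by an inplace ``copy_`` back into ``t``.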
+traceable_collective_remaps = { + legacy_allgather: all_gather_tensor_inplace, + legacy_reducescatter: reduce_scatter_tensor_inplace, + legacy_allreduce: all_reduce_inplace, + legacy_all_to_all_single: all_to_all_inplace, + legacy_all_gather: all_gather_inplace, + legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, + legacy_all_gather_base: all_gather_tensor_inplace, +} diff --git a/phivenv/Lib/site-packages/torch/distributed/_functional_collectives_impl.py b/phivenv/Lib/site-packages/torch/distributed/_functional_collectives_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..35f3d1d684eff75603b7fea39409a5258cb0771a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_functional_collectives_impl.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.distributed.distributed_c10d as c10d + + +""" +This file contains the op impls for the legacy (c10d_functional) functional collectives. +These impls simply call into the native (_c10d_functional) functional collectives. +""" + + +def _broadcast(input, src, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.broadcast( + input, + src, + group_name, + ) + + +def _all_reduce(input, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_reduce( + input, + reduce_op, + group_name, + ) + + +def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_reduce_coalesced( + inputs, + reduce_op, + group_name, + ) + + +def _all_gather_into_tensor(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_gather_into_tensor( + input, + group_size, + group_name, + ) + + +def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_gather_into_tensor_coalesced( + input, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor( + input: torch.Tensor, + reduce_op: str, + tag: str, + ranks: list[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduce_op, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor_coalesced( + inputs: list[torch.Tensor], + reduce_op: str, + tag: str, + ranks: list[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( + inputs, + reduce_op, + group_size, + group_name, + ) + + +def _all_to_all_single( + input: torch.Tensor, + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], + tag: str, + ranks: list[int], + group_size: int, +): + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [input.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return 
torch.ops._c10d_functional.all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group_name, + ) + + +def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor(tensor) diff --git a/phivenv/Lib/site-packages/torch/distributed/_serialization.py b/phivenv/Lib/site-packages/torch/distributed/_serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..425b80b5460c54ecf5ab2e3751673f9a5ff5344b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_serialization.py @@ -0,0 +1,155 @@ +import pickle +from dataclasses import dataclass +from io import BufferedIOBase +from typing import Any + +import torch +import torch._weights_only_unpickler as _weights_only_unpickler +from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION + + +__all__: list[str] = [] + + +@dataclass +class _Entry: + key: str + is_storage: bool + length: int + + +_weights_only_unpickler._add_safe_globals([_Entry]) + + +class _PseudoZipFile: + def __init__(self) -> None: + self.records: dict[str, tuple[object, int]] = {} + + def write_record(self, key: str, data: object, length: int) -> None: + self.records[key] = (data, length) + + def write_to(self, f: BufferedIOBase) -> None: + entries = [] + for key, (data, length) in self.records.items(): + entries.append( + _Entry( + key=key, + is_storage=isinstance(data, torch.UntypedStorage), + length=length, + ) + ) + + pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL) + + for key, (data, length) in self.records.items(): + if isinstance(data, bytes): + f.write(data) + elif isinstance(data, str): + f.write(data.encode("utf-8")) + elif isinstance(data, torch.UntypedStorage): + data._write_file(f, False, False, 1) + else: + raise TypeError(f"unknown type: {type(data)}") + + def read_from(self, f: BufferedIOBase) -> None: + entries = _weights_only_unpickler.load(f) + + for entry in entries: + data = f.read(entry.length) + if entry.is_storage: + storage = torch.frombuffer( + data, + dtype=torch.uint8, + ).untyped_storage() + + self.records[entry.key] = ( + storage, + entry.length, + ) + else: + self.records[entry.key] = (data, entry.length) + + def has_record(self, key: str) -> bool: + return key in self.records + + def get_record(self, key: str) -> object: + return self.records[key][0] + + def get_storage_from_record( + self, key: str, _length: int, _type: int + ) -> torch.Tensor: + return torch.tensor(self.records[key][0], dtype=torch.uint8) + + def serialization_id(self) -> str: + return "torchft" + + +def _streaming_save( + obj: object, + f: BufferedIOBase, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, +) -> None: + """ + Save the object to a file-like object in a streaming fashion compatible with + network sockets. + + This behaves similarly to :func:`torch.save` with a few notable differences: + + * A non-seekable file like object can be used when loading. + * No forwards/backwards compatibility is provided for the serialization + format. This is only intended to be used with a single version of PyTorch + with transient storage (i.e. sockets or temp files). + * mmap is not supported + + See :func:`torch.save` for more details on specific arguments. 
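    Example (editorial sketch; ``io.BytesIO`` stands in for a non-seekable
    socket file object)::

        buf = io.BytesIO()
        _streaming_save({"step": 3, "weights": torch.randn(4)}, buf)
        buf.seek(0)
        state = _streaming_load(buf)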
+ """ + + zip_file = _PseudoZipFile() + _save( + obj, + zip_file=zip_file, + pickle_module=pickle_module, + pickle_protocol=pickle_protocol, + _disable_byteorder_record=False, + ) + zip_file.write_to(f) + + +def _streaming_load( + f: BufferedIOBase, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool = True, + **pickle_load_args: Any, +) -> object: + """ + Load the object from a file-like object in a streaming fashion compatible with + network sockets. + + See :func:`_streaming_save` for more details about the streaming behavior. + + See :func:`torch.load` for more details on specific arguments. + """ + if weights_only: + if pickle_module is not None: + raise RuntimeError( + "Can not safely load weights when explicit pickle_module is specified" + ) + pickle_module = _weights_only_unpickler + else: + if pickle_module is None: + pickle_module = pickle + + if "encoding" not in pickle_load_args.keys(): + pickle_load_args["encoding"] = "utf-8" + + zip_file = _PseudoZipFile() + zip_file.read_from(f) + return _load( + zip_file=zip_file, + map_location=map_location, + pickle_module=pickle_module, + **pickle_load_args, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_state_dict_utils.py b/phivenv/Lib/site-packages/torch/distributed/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2e88528f49cccfe14e25a8a050705d88ea6b5b8e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_state_dict_utils.py @@ -0,0 +1,819 @@ +# mypy: allow-untyped-defs +import copy +import io +import math +import weakref +from collections.abc import Mapping, MutableMapping +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union + +import torch +import torch.cuda._pin_memory_utils as pin_memory_utils +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._functional_collectives import AsyncCollectiveTensor + + +if dist.is_available() or TYPE_CHECKING: + from torch.distributed import distributed_c10d + from torch.distributed._shard.sharded_tensor import ShardedTensor + from torch.distributed.tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _identity_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, +) -> torch.Tensor: + return obj + + +def _all_gather_sharded_tensor( + sharded_tensor: "ShardedTensor", + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + pg_device = ( + distributed_c10d._get_pg_default_device(pg) if device is None else device + ) + if shards: + local_tensor = shards[0].tensor.flatten() + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=sharded_tensor.dtype, device=pg_device + ) + + tensor = torch.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + 
device=pg_device, + ) + dist.all_gather_into_tensor(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +class CompanionMismatch(Exception): + pass + + +def _iterate_state_dict( + iter_object: Any, + sharded_tensor_func: Callable, + dtensor_func: Callable, + tensor_func: Callable, + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + companion_obj: Any = None, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, + non_blocking: bool = True, +) -> dict[str, Any]: + """Iterate through the state dict, applying the given functions to each tensor type. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + pg (Optional[dist.ProcessGroup]): process group passed to tensor functions + device (Optional[torch.device]): device passed to tensor functions + cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored + if a companion_obj is supplied. + companion_obj (Any): A companion object to the state dict. If this object + is supplied, we attempt to copy the tensor to the companion object. + ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + non_blocking (bool): whether to use non-blocking copy when copying to the companion object. + """ + # TODO: should we use pytree? 
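    # Editorial overview of the dispatch below: ShardedTensor values go through
    # sharded_tensor_func, DTensor values through dtensor_func, plain tensors
    # through tensor_func; dicts/lists/tuples recurse with the matching entry of
    # companion_obj, and int/float/str/bytes/BytesIO/None pass through unchanged.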
+ cpu_device = torch.device("cpu") + if isinstance(iter_object, ShardedTensor): + ret = sharded_tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, DTensor): + ret = dtensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, torch.Tensor): + ret = tensor_func(iter_object, pg, device, companion_obj) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + ret = iter_object + elif isinstance(iter_object, dict): + if companion_obj is not None and ( + not isinstance(companion_obj, dict) + or set(companion_obj.keys()) != set(iter_object.keys()) + ): + msg = ( + "" + if isinstance(companion_obj, dict) + else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}" + ) + raise CompanionMismatch(msg) + + ret = { + key: _iterate_state_dict( + value, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[key] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for key, value in iter_object.items() + } + elif isinstance(iter_object, (list, tuple)): + if companion_obj is not None and ( + not isinstance(companion_obj, (list, tuple)) + or len(companion_obj) != len(iter_object) + ): + raise CompanionMismatch + + ret = [ + _iterate_state_dict( + v, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[idx] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for idx, v in enumerate(iter_object) + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) + elif not type_check: + ret = copy.deepcopy(iter_object) + else: + raise ValueError(f"Unexpected value type {type(iter_object)}") + + if not ranks_only or dist.get_rank(pg) in ranks_only: + if isinstance(ret, torch.Tensor): + if cpu_offload and companion_obj is None: + ret = ret.to(cpu_device) + + if companion_obj is not None: + if isinstance(companion_obj, DTensor): + assert isinstance(ret, DTensor) + companion_obj._local_tensor.copy_( + ret._local_tensor, non_blocking=non_blocking + ) + elif isinstance(companion_obj, ShardedTensor): + assert isinstance(ret, ShardedTensor) + for idx, shard in enumerate(companion_obj.local_shards()): + shard.tensor.copy_( + ret.local_shards()[idx].tensor, non_blocking=non_blocking + ) + else: + companion_obj.copy_(ret, non_blocking=non_blocking) + ret = companion_obj + else: + ret = {} if isinstance(ret, dict) else None + + return ret + + +def _gather_state_dict( + state_dict: dict[str, Any], + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors or DTensors in + the state_dict. + + + Args: + state_dict (Dict[str, Any]): the target sharded state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + device: (Optional[torch.device]): the device that is used to + perform allgather for ShardedTensor. Note that gathering a DTensor + will use the DeviceMesh. So this argument will be ignored when + gathering a DTensor. 
+ cpu_offload (bool): whether to offload the tensors to CPU memory. The + default value is False. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + def sharded_tensor_func(value, pg, device, companion_obj): + # ShardedTensor does not seem to record the original device type. + # So if the tensor is moved to CPU, we won't know the original type. + # As a result, we have to rely on the user to tell us the correct one. + cpu_device = torch.device("cpu") + output_tensor = _all_gather_sharded_tensor(value, pg, device) + local_shard_device = ( + value.local_shards()[0].tensor.device + if value.local_shards() + else cpu_device + ) + if output_tensor.device != local_shard_device: + value = output_tensor.to(local_shard_device) + else: + value = output_tensor + return value + + def dtensor_func(value, pg, device, companion_obj): + if value.device != value.device_mesh.device_type: + value = value.to(value.device_mesh.device_type) + # FSDP all_gather: [Shard(0)] -> [Replicate()] + # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] + value = value.redistribute( + device_mesh=value.device_mesh, + placements=placements, + ) + # Call `wait()` to force the tensor to be synchronous with respect + # to the main stream. + # See the discussion in https://github.com/pytorch/pytorch/pull/117799. + value = value.to_local() + if isinstance(value, AsyncCollectiveTensor): + value = value.wait() + return value + + return _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + _identity_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + ranks_only=ranks_only, + type_check=type_check, + ) + + +def _offload_state_dict_to_cpu( + state_dict: dict[str, Any], + *, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API offload all the tensors to CPU memory. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. 
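    Example (editorial sketch; ``model`` is a placeholder module)::

        cpu_state_dict = _offload_state_dict_to_cpu(model.state_dict())
        # plain tensor values in the returned dict are moved to CPU memory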
+ """ + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=True, + ranks_only=ranks_only, + type_check=type_check, + ) + return ret + + +@torch.no_grad() +def _copy_state_dict( + state_dict: dict[str, Any], + copy_state_dict: dict[str, Any], + non_blocking: bool = False, + type_check: bool = True, +) -> dict[str, Any]: + """ + Copies all tensors in a given state dict into a different state_dict with the + same structure. Additionally, a copied state dict with the same value references + is returned. Editing the keys on this state dict will not affect the + passed in copy_state_dict (but the value references are the same). + + .. warning:: + It is expected by this function that state_dict and copy_state_dict share + the same structure and data types. + + .. warning:: + The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + copy_state_dict (Dict[str, Any]): + The state dict we are copying into. This state_dict must have exactly + the same structure as the source `state_dict`. + non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + State Dict copy + """ + + return _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=copy_state_dict, + type_check=type_check, + non_blocking=non_blocking, + ) + + +@torch.no_grad() +def _create_cpu_state_dict( + state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> dict[str, Any]: + """ + Given a state_dict, create another state_dict with the same structure and elements. + However, all tensors in the returned state_dict are new tensors on CPU. These + tensors can be placed on pin_memory or share_memory based on the provided arguments. + + .. warning:: + Setting both `pin_memory` and `share_memory` to True significantly increases the + latency of this method because of the nuances which require us to register memory + as pinned directly as opposed to relying on the pin_memory cache allocator. This + option should only be used for long lived tensors which are required to be shared. + This is not the case as long as at least one of `pin_memory` or `share_memory` is + set to False. + + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> torch.Tensor: + if len(obj.size()) == 0: + return torch.tensor(0, dtype=obj.dtype) + + # sometimes, a tensor might have non-zero size and 0 numel. In this case, pinning memory will fail + # so we take a best guess at how to replicate the tensor below to maintain symmetry in the returned + # state dict. 
+ if obj.numel() == 0 or obj.data_ptr() == 0: + t = torch.zeros_like(obj, device="cpu") + if share_memory: + t = t.share_memory_() + return t + + if share_memory: + t = torch.empty(*tuple(obj.size()), dtype=obj.dtype) + t = t.share_memory_() + if pin_memory: + pin_memory_utils.pin_memory(t.data_ptr(), t.numel() * t.element_size()) + weakref.finalize(t, pin_memory_utils.unpin_memory, t) + + return t + elif pin_memory: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() + else: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype) + + def dtensor_func( + obj: DTensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> DTensor: + if len(obj.size()) == 0: + return obj + + if obj.device != torch.device("cpu"): + ret = cast(DTensor, obj.to(device="cpu")) + else: + ret = copy.deepcopy(obj) + ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None) + return ret + + def sharded_tensor_func( + obj: ShardedTensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> ShardedTensor: + if not obj.local_shards(): + return obj + + if obj.device != torch.device("cpu"): + ret = obj.to(device="cpu") + else: + ret = copy.deepcopy(obj) + + for shards in ret.local_shards(): + shards.tensor = tensor_func(shards.tensor, pg, device, None) + + return ret + + ret = _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + type_check=False, + ) + return ret + + +def _check_state_dict_similarity( + state_dict: dict[str, Any], + compared_state_dict: dict[str, Any], +) -> bool: + """ + Given two state_dicts, check if the structures are the same. And + if a [key, tensor] pair exist in one state_dict there must be + the a corresponding pait, [key, other_tensor], in the other state_dict, + where tensor and other_tensor have the same size and dtype. + + Return the check result. 
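    Example (editorial sketch of the CPU-staging pattern; ``gpu_sd`` is a
    placeholder CUDA state_dict)::

        staging_sd = _create_cpu_state_dict(gpu_sd, pin_memory=True)
        assert _check_state_dict_similarity(gpu_sd, staging_sd)
        _copy_state_dict(gpu_sd, staging_sd, non_blocking=True)
        torch.cuda.synchronize()  # wait for the device-to-host copies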
+ """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, + ) -> torch.Tensor: + if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): + raise CompanionMismatch + return obj + + try: + _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=compared_state_dict, + type_check=False, + ) + except CompanionMismatch: + return False + + return True + + +class _TensorInfo(NamedTuple): + size: torch.Size + dtype: torch.dtype + + +def _broadcast_tensors( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + pg_device = ( + device + if device.type in {pg_device.type for pg_device in pg._device_types} + else pg._device_types[0] + ) + + tensors: list[torch.Tensor] = [] + for key in keys: + if dist.get_rank() == 0: + full_state = full_state_dict[key] + assert isinstance(full_state, torch.Tensor) + full_tensor = full_state.detach().to(pg_device) + else: + tensor_info = full_state_dict[key] + full_tensor = torch.empty( + size=tensor_info.size, + device=pg_device, + dtype=tensor_info.dtype, + ) + + tensors.append(full_tensor) + + if (local_state := local_state_dict.get(key)) is None: + continue + + local_state_dict[key] = ( + (local_state, full_tensor) + if isinstance(local_state, DTensor) + else full_tensor + ) + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) + + if pg_device != device: + for key, full_tensor in zip(keys, tensors): + if (local_state := local_state_dict.get(key)) is not None: + local_state_dict[key] = ( + (local_state[0], full_tensor.to(device)) + if ( + isinstance(local_state, tuple) + and isinstance(local_state[0], DTensor) + ) + else full_tensor.to(device) + ) + + _distribute_tensors(local_state_dict, keys, device, pg) + + +def _distribute_tensors( + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + for key in keys: + _local_state = local_state_dict.get(key, None) + if _local_state is None or torch.is_tensor(_local_state): + continue + + local_state = _local_state[0] + full_tensor = _local_state[1] + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + if local_state.is_meta: + # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost. + local_tensor = full_tensor[tuple(slices)].detach().clone() + # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, + # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). + ret = DTensor.from_local( + local_tensor, + local_state.device_mesh, + local_state.placements, + shape=local_state.shape, + stride=local_state.stride(), + ) + else: + ret = local_state + # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint. 
+ ret.to_local().copy_(full_tensor[tuple(slices)]) + local_state_dict[key] = ret + + +def _broadcast_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, + cpu_offload: bool = False, +) -> None: + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. + ret = {} + if dist.get_rank() == 0: + for key, value in full_state_dict.items(): + if not torch.is_tensor(value): + ret[key] = value + elif value.dim() == 0: + ret[key] = value.cpu() + else: + ret[key] = _TensorInfo(value.size(), value.dtype) + + broadcast_list = [ret] + dist.broadcast_object_list(broadcast_list, src=0, group=pg) + ret = broadcast_list[0] + # Gather values + keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() + for key, value in ret.items(): + global_keys.add(key) + if not isinstance(value, _TensorInfo): + if key in local_state_dict: + local_state_dict[key] = value + continue + + if dist.get_rank() == 0: + ret[key] = full_state_dict[key] + + keys.append(key) + # Broadcast every tensor to avoid OOM for now. + if len(keys) >= 1: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + keys.clear() + + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + + if keys: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + + +def _distribute_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has + # full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and + # distribute tensors in each rank + for key, value in full_state_dict.items(): + if key not in full_state_dict: + continue + if not torch.is_tensor(value): + local_state_dict[key] = value + elif value.dim() == 0: + local_state_dict[key] = value.cpu() + else: + assert isinstance(value, torch.Tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) + else: + local_state_dict[key] = value.detach().to(device) + + +# These APIs are from torch.distributed.checkpoint. +# TODO: We should consolidate the code here as some not all modules can depend on +# DCP. +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] +FLATTEN_MAPPING = dict[str, OBJ_PATH] +STATE_DICT_TYPE = dict[str, Any] +CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] + + +def _traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, Any], None], +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping, list, and tuple will be flattened and other value types are treated + as the terminal values and will invoke ``visitor``. 
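    Example (editorial sketch)::

        paths = []
        _traverse_state_dict(
            {"opt": {"lr": 0.1}, "steps": [1, 2]},
            lambda path, value: paths.append(path),
        )
        # paths == [("opt", "lr"), ("steps", 0), ("steps", 1)]

    ``_flatten_state_dict`` below builds on this traversal to produce dotted
    keys such as ``"opt.lr"`` and ``"steps.0"``.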
+ """ + + def _traverse_obj(path: OBJ_PATH, value: Any) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + else: + visitor(path, value) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def _flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + _traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[Any], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else [] + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(list[Any], cur_container), key) + + cur_container[key] = value + + +def _unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + _set_element(nested, mapping[key], value) + return nested diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6650ed30d66c4cecc94e56f20b910feff582f5d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/__init__.py @@ -0,0 +1 @@ +from .join import Join, Joinable, JoinHook diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8763dc628bf412a97e388765c4f22fdb95a4068e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ffa65568bd5263ca29750ca694c1ab402220c096 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d889c876aa9c5c71ebcf44aca71230b71a97eedb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f88b1a47e8d501ef0b5a30a2951fc0f8bf9cb12 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9997f8df731582f3b31f96714be2f13ced322cdb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -0,0 +1,324 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterator +from enum import auto, Enum +from functools import partial +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn +from torch.autograd.graph import save_on_cpu +from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs +from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint + + +_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module" +_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "." + + +class CheckpointImpl(Enum): + REENTRANT = auto() + NO_REENTRANT = auto() + + +class ActivationWrapper(torch.nn.Module, ABC): + """ + Base class for Activation Checkpoint and Activation Offload. + + Not meant to be instantiated directly. + """ + + def __init__(self, mod): + super().__init__() + self._checkpoint_wrapped_module = mod + # state_dict post hook to remove prefix to allow loading into a + # non-checkpoint wrapped module. + self._register_state_dict_hook(self._post_state_dict_hook) + # load_state_dict pre-hook to allow loading back into + # checkpoint-wrapped module. 
+ self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook) + + @abstractmethod + def forward(self, *args, **kwargs): + raise ValueError("Subclasses should implement forward().") + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._checkpoint_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator] + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.nn.Parameter]]: + """ + Override :meth:`named_parameters()` to intercept parameter names. + + remove all occurrences of ``_CHECKPOINT_PREFIX``. + """ + for param_name, param in super().named_parameters(*args, **kwargs): + yield param_name.replace(_CHECKPOINT_PREFIX, ""), param + + @staticmethod + def _post_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, + ) -> dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed. + + For ``checkpoint_wrapper``, it will strip checkpoint-wrapped module prefix, + so that this module can be loaded into non-checkpointed modules. + It would still be able to be loaded into checkpoint-wrapped modules as this class, + adds the prefix back before loading the state_dict. + """ + _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix) + return state_dict + + @staticmethod + def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` is called. + + For ``checkpoint_wrapper``, it will add back the module + prefix so that non-checkpointed modules can be loaded into + checkpoint_wrapper modules properly. + """ + _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}") + + +class OffloadWrapper(ActivationWrapper): + def __init__(self, mod): + super().__init__(mod) + + def forward(self, *args, **kwargs): + with save_on_cpu(pin_memory=True): + return self._checkpoint_wrapped_module(*args, **kwargs) + + +class CheckpointWrapper(ActivationWrapper): + """ + An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. + + Note that this module is not meant to be used directly but instead, + it is to be used through the ``checkpoint_wrapper`` function. + """ + + def __init__( + self, + mod: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, + checkpoint_fn=None, + **checkpoint_fn_kwargs, + ): + super().__init__(mod) + self.checkpoint_impl = checkpoint_impl + if checkpoint_fn is None: + # use torch.utils.checkpoint + self.checkpoint_fn = partial( + torch_utils_checkpoint, + use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT), + **checkpoint_fn_kwargs, + ) + else: + # Construct user-specified checkpoint function. + self.checkpoint_fn = partial( + checkpoint_fn, + **checkpoint_fn_kwargs, + ) + + def forward(self, *args, **kwargs): + # Support keyword arguments for reentrant checkpoint. Note that this + # only works if user has specified self.checkpoint_impl and is not + # using their own custom checkpoint_fn. 
+ if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}: + # Pack the args and kwargs + flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs) + + # Function that only takes (packed) args, but can unpack them + # into the original args and kwargs for the checkpointed + # function, and runs that function. + def my_function(*inputs): + # unpack back into args and kwargs + unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys) + # run original module + return self._checkpoint_wrapped_module( + *unpacked_args, **unpacked_kwargs + ) + + # Pass the function that only takes packed args into reentrant + # checkpoint API. + return self.checkpoint_fn( # type: ignore[misc] + my_function, + *flat_args, + ) + else: + return self.checkpoint_fn( # type: ignore[misc] + self._checkpoint_wrapped_module, *args, **kwargs + ) + + +def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module: + """ + Wrap a module for activation offloading to CPU. + + Offloads intermediate activations to the CPU for modules wrapped with this function. + Wrappers with activation offload can be composed with ones that do recomputation-based + checkpoint to trade off increased compute versus increased CPU + memory usage and additional H2D transfers. + + Usage:: + offloaded_module = offload_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + Returns: + (nn.Module): + Wrapped module + """ + return OffloadWrapper(module) + + +def checkpoint_wrapper( + module: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, + checkpoint_fn=None, + **checkpoint_fn_kwargs, +) -> torch.nn.Module: + """ + Wrap a module for activation checkpointing. + + If the module is wrapped with this function, all subsequent calls to the module will, + automatically perform checkpointing without the user having to explicitly call ``checkpoint`` function. + + Usage:: + checkpointed_module = checkpoint_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + checkpoint_impl (Optional[CheckpointImpl]): + The checkpointing implementation to use. Note that this will only + be passed into the ``torch.utils.checkpoint.checkpoint`` + implementation, and is ignored if a custom ``checkpoint_fn`` is + specified. Note that for implementations using reentrant checkpoint + from ``torch.utils.checkpoint``, keyword arguments will only be + supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`. + checkpoint_fn (Optional[Callable]): + Functional checkpoint implementation to use. If this is specified, + it will be used over the default ``torch.utils.checkpoint.checkpoint`` + implementation and the `checkpoint_impl` argument will be ignored. + **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`. 
+ + Returns: + (nn.Module): + Wrapped module + """ + + if checkpoint_impl == CheckpointImpl.REENTRANT: + warnings.warn( + f"Please specify {CheckpointImpl.NO_REENTRANT} as " + f"{CheckpointImpl.REENTRANT} will soon be removed as " + "the default and eventually deprecated.", + FutureWarning, + stacklevel=2, + ) + return CheckpointWrapper( + module, + checkpoint_impl, + checkpoint_fn, + **checkpoint_fn_kwargs, + ) + + +def apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=checkpoint_wrapper, + check_fn=lambda _: True, + auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None, +): + """ + Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration. + + For each module within `model`, the `check_fn` is used to decide + whether `module` should be wrapped with :func:`checkpoint_wrapper` or not. + + Note:: + This function modifies `model` in place and replaces appropriate layers with + their checkpoint-wrapped modules. + Note:: + This function will not wrap the overall root module. If this is needed, please directly use + :func:`checkpoint_wrapper` or :func:`offload_wrapper`. + Usage:: + model = nn.Sequential( + nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10) + ) + check_fn = lambda l: isinstance(l, nn.Linear) + # checkpoint activations + apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) + # Or offload activations to CPU + apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn) + Args: + model (nn.Module): + The model whose submodules should be wrapped with activation checkpointing. + checkpoint_wrapper_fn (Optional[Callable[nn.Module]]) + A ``Callable`` which will wrap modules + check_fn (Optional[Callable[nn.Module, nn.Module]]) + A lambda function which will be passed each child submodule of ``model`` and returns + ``True`` or ``False`` depending on whether the submodule should be wrapped. + auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's + submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``. + Returns: None (`model` is modified inplace) + """ + # TODO: Importing inside function to avoid circular import issue between FSDP and + # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. 
+ from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply + from torch.distributed.fsdp.wrap import ( + _Policy, + _recursive_wrap, + lambda_auto_wrap_policy, + ) + + policy = ( + auto_wrap_policy + if auto_wrap_policy is not None + else partial(lambda_auto_wrap_policy, lambda_fn=check_fn) + ) + if not callable(policy): + if not isinstance(policy, _Policy): + raise ValueError( + f"Expected {policy} to be callable or be a pre-defined wrap policy" + ) + target_module_to_kwargs = policy._run_policy( + model, ignored_modules=set(), root_kwargs={} + ) + wrap_fn = _construct_wrap_fn( + model, target_module_to_kwargs, checkpoint_wrapper_fn + ) + _post_order_apply(model, wrap_fn) + return + + _recursive_wrap( + module=model, + auto_wrap_policy=policy, # type: ignore[arg-type] + wrapper_cls=checkpoint_wrapper_fn, + ignored_modules=set(), + ignored_params=set(), + only_wrap_children=True, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca9f2168b524aa1ff39cfd4cf66eb276d6a35d5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py @@ -0,0 +1,7 @@ +from . import default_hooks as default + + +LOW_PRECISION_HOOKS = [ + default.fp16_compress_hook, + default.bf16_compress_hook, +] diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..794d6d3d81127bb1fab8e93d91ea7b9e3f207f76 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1e57a1291c51fa5cd3d71ed75d5faf888f1889a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..addefc2de74f33791c91d925f5201793245337ad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import functools +from typing import Optional + +import torch +import torch.distributed as dist + + +class DefaultState: + r""" + Stores state needed to perform the default communication algorithm within a communication hook. + + Args: + process_group (ProcessGroup): The process group to be used. 
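    Note (editorial): with ``world_size = 8`` the helper below picks
    ``gradient_predivide_factor = 4.0`` and ``gradient_postdivide_factor = 2.0``,
    so gradients are divided by 4 before the collective and by 2 after it, an
    overall division by ``world_size`` split in two to limit underflow/overflow.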
+ """ + + __slots__ = [ + "process_group", + "world_size", + "gradient_predivide_factor", + "gradient_postdivide_factor", + ] + + def __init__(self, process_group: dist.ProcessGroup): + if process_group is None: + raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.") + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + # Setting two factors `self.gradient_predivide_factor` + # and `self.gradient_postdivide_factor` to avoid underflow and overflow + self.gradient_predivide_factor = self._get_gradient_predivide_factor( + self.world_size + ) + self.gradient_postdivide_factor = ( + self.world_size / self.gradient_predivide_factor + ) + + @staticmethod + def _get_gradient_predivide_factor(world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + + +class LowPrecisionState(DefaultState): + r""" + Stores state needed to perform gradient communication in a lower precision within a communication hook. + + Communication hook will cast gradients back to the original + parameter precision specified by ``parameter_type`` (default: torch.float32). + Builds on top of the :class:`DefaultState`. + + Args: + parameter_type (torch.dtype): The precision of model's parameters. + Required for a hook to cast gradients back to a parameter's precision. + """ + + __slots__ = [ + "parameter_type", + ] + + def __init__( + self, + process_group, + parameter_type=torch.float32, + ): + super().__init__(process_group) + self.parameter_type = parameter_type + + +def _decompress(state: LowPrecisionState, grad: torch.Tensor): + """ + Casts gradients back to full parameter precision so that further computation happens in full precision. + """ + orig_grad_data = grad.data + grad.data = grad.data.to(state.parameter_type) + device_type = "" + try: + if grad.device.type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + else: + device_type = grad.device.type + backend = getattr(torch, device_type) + except AttributeError as e: + raise AttributeError( + f"Device {grad.device} does not have a \ + corresponding backend registered as 'torch.device_type'." + ) from e + + # Don't let this memory get reused until after the transfer. + orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type] + + +def allreduce_hook(state: DefaultState, grad: torch.Tensor): + r""" + Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks. + """ + # Average grad by pre-division factor. Together pre- and post-division factors + # lead to an overall averaging by world_size, required for consistency with PyTorch DDP. + # This is a two-step process to avoid potential underflow and overflow. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.all_reduce(grad, group=state.process_group) + # Average grad by post-division factor. + if state.gradient_postdivide_factor > 1: + grad.div_(state.gradient_postdivide_factor) + + +def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor): + r""" + Implement the FSDP communication hook for ``reduce_scatter`` algorithm. 
+ + For sharded FSDP strategies and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (torch.Tensor): An unsharded gradient for the local batch that needs to be + communicated across ranks. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + # Average grad by pre-division factor. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.reduce_scatter_tensor(output, grad, group=state.process_group) + # Average grad's shard by post-division factor. + if state.gradient_postdivide_factor > 1: + output.div_(state.gradient_postdivide_factor) + + +def _low_precision_hook( + prec: torch.dtype, + state: LowPrecisionState, + grad: torch.Tensor, + output: Optional[torch.Tensor], +): + if grad.dtype != prec: + grad.data = grad.data.to(prec) + if output is not None: + if output.dtype != prec: + output.data = output.data.to(prec) + reduce_scatter_hook(state, grad, output) + _decompress(state, output) + else: + allreduce_hook(state, grad) + _decompress(state, grad) + + +def fp16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None +): + r""" + Implement FSDP communication hook for a simple gradient compression approach. + Casts ``grad`` to half-precision floating-point format (``torch.float16``). + + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + fp16_hook = functools.partial(_low_precision_hook, torch.float16) + return fp16_hook(state, grad, output) + + +def bf16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None +): + r""" + Implement FSDP communication hook for a simple gradient compression approach . + Casts ``grad`` to half-precision floating-point format. + + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. 
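    Example (editorial sketch; assumes an initialized default process group and
    an FSDP-wrapped ``model``)::

        state = LowPrecisionState(process_group=dist.group.WORLD)
        model.register_comm_hook(state, bf16_compress_hook)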
+ """ + bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16) + return bf16_hook(state, grad, output) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9460c12ce8abc076e6d22570e65a773039aa145f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py @@ -0,0 +1 @@ +from .optimizer_overlap import _as_overlapped_optim diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8f05b252199bf2c6d9acb7ba7999a8e797ab7fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c413efb3606e7c3660ad6a1036667cc5f524346 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py new file mode 100644 index 0000000000000000000000000000000000000000..eb29f7f0b599a37df18b61544f0e7d1977912a70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import inspect +from abc import ABC, abstractmethod + +from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook +from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _hook_then_optimizer, + _OptimizerHookState, +) +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.optim import as_functional_optim +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer + + +# Contains the mappings between the regular and overlapped optimizer types. +_registered_overlapped_optims: dict[type, type] = {} + + +def register_overlapped(optim_cls): + def decorator(target_overlapped_optim_cls): + if target_overlapped_optim_cls in _registered_overlapped_optims: + raise ValueError( + f"{target_overlapped_optim_cls} already registered with optim_cls " + f"{_registered_overlapped_optims[optim_cls]} {optim_cls}, trying to" + f"re-register it for {optim_cls} is not supported." + ) + _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls + return target_overlapped_optim_cls + + return decorator + + +class OverlappedOptimizer(ABC): + def __init__(self, optim_cls: type) -> None: + """ + Initialize the OverlappedOptimizer. + + Overlappedoptimizer is a base class that child classes can implement to + specify how different optimizers will register themselves with DDP. 
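        Example of the intended entry point (editorial sketch; ``params`` and
        ``ddp_model`` are placeholders, and ``_as_overlapped_optim`` is defined
        at the bottom of this file)::

            overlapped = _as_overlapped_optim(torch.optim.SGD, params, lr=0.01)
            overlapped.register_ddp(ddp_model)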
+ """ + self.optim_cls = optim_cls + + @abstractmethod + def register_ddp(self, ddp: DistributedDataParallel) -> None: + """Registers the overlapped optimizer with DDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped DDP." + ) + + @abstractmethod + def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: + """Registers the overlapped optimizer with FSDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped FSDP." + ) + + +@register_overlapped(Optimizer) +class _OverlappedStandardOptimizer(OverlappedOptimizer): + """Overlaps a regular ``Optimizer``.""" + + def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None: + super().__init__(optim_cls) + f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs) + self._opt_hook_state = _OptimizerHookState(f_optim, params) + + def register_ddp(self, ddp_inst: DistributedDataParallel): + # NOTE: using a custom communication hook and fused optimizer is not + # yet supported. + ddp_inst.register_comm_hook( # type: ignore[operator] + None, # wrapped hook state + _hook_then_optimizer(allreduce_hook, self._opt_hook_state), + ) + + # TODO: register_fsdp once FSDP supports communication hook. + def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: + """Register the overlapped optimizer with FSDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped FSDP." + ) + + +def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs): + """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.""" + for clz in inspect.getmro(optim_cls): + try: + return _registered_overlapped_optims[clz]( + optim_cls, params, *args, **kwargs + ) + except KeyError: + pass + + # Fallback to standard overlapped optimizer, which will raise errors if user + # is attempting to use an unsupported optimizer. 
+ return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..229505e64e35234fa5667ebb49085d44e385a92b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c55457e0165d715d9fad2a2c3d69f1944a7e18b2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/quantization.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..d87e5729adf71516fb366e877df90403901b8fd0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/_quantization/quantization.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +import functools +from enum import Enum + +import torch +import torch.distributed as dist + + +TORCH_HALF_MIN = torch.finfo(torch.float16).min +TORCH_HALF_MAX = torch.finfo(torch.float16).max + + +class DQuantType(Enum): + """ + Different quantization methods for auto_quantize API are identified here. + + auto_quantize API currently supports fp16 and bfp16 methods. 
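A small sketch of the FP16 path implemented by the helpers below: out-of-range fp32 values are clamped before the cast, so they saturate instead of overflowing to ``inf`` (the tensor values are arbitrary examples):

    import torch

    x = torch.tensor([1.5, -2.0, 7.0e4])      # 7.0e4 exceeds the fp16 max (~65504)
    half = torch.clamp(
        x, torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
    ).half()                                   # what _quantize_tensor does for DQuantType.FP16
    restored = half.float()                    # what _dequantize_tensor does when quant_loss is None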
+ """ + + FP16 = ("fp16",) + BFP16 = "bfp16" + + def __str__(self) -> str: + return self.value + + +def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor: + return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half() + + +def _quantize_tensor(tensor, qtype): + if not isinstance(tensor, torch.Tensor): + raise RuntimeError( + f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}" + ) + if qtype == DQuantType.FP16: + return _fp32_to_fp16_with_clamp(tensor) + elif qtype == DQuantType.BFP16: + return torch.ops.quantization._FloatToBfloat16Quantized(tensor) + else: + raise RuntimeError(f"Quantization type {qtype} is not supported") + + +def _quantize_tensor_list(tensor_list, qtype): + if not isinstance(tensor_list, list) or not all( + isinstance(p, torch.Tensor) for p in tensor_list + ): + raise RuntimeError( + f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}" + ) + quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] + return quantized_tensor_list + + +def _dequantize_tensor(tensor, qtype, quant_loss=None): + if not isinstance(tensor, torch.Tensor): + raise RuntimeError( + f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}" + ) + if qtype == DQuantType.FP16: + if tensor.dtype != torch.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + elif tensor.dtype == torch.float16 and quant_loss is None: + return tensor.float() + else: + return tensor.float() / quant_loss + elif qtype == DQuantType.BFP16: + if tensor.dtype != torch.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + else: + return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor) + else: + raise RuntimeError(f"Quantization type {qtype} is not supported") + + +def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None): + if not isinstance(tensor_list, list) or not all( + isinstance(p, torch.Tensor) for p in tensor_list + ): + raise RuntimeError( + f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}" + ) + dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list] + return dequantized_tensor_list + + +def auto_quantize(func, qtype, quant_loss=None): + """ + Quantize the input tensors, choose the precision types, and pass other necessary arguments and then dequantizes the output. + + Currently it only supports: + . FP16 and BFP16 quantization method supported for gloo and nccl backends + . all_gather, all_to_all collective ops + Note: BFP16 only supports 2D tensors. + Args: + func (Callable): A function representing collective operations. + qtype (QuantType): Quantization method + quant_loss (float, optional): This can be used to improve accuracy in the dequantization. + Returns: + (Callable): the same collective as func but enables automatic quantization/dequantization. 
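A usage sketch for this wrapper, assuming ``torch.distributed`` is already initialized with a gloo or nccl backend and every rank runs the same code:

    import torch
    import torch.distributed as dist

    def fp16_gather_example(world_size: int):
        quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
        gathered = [torch.zeros(4) for _ in range(world_size)]
        local = torch.randn(4)
        # Communicates in fp16 and writes dequantized fp32 results back into `gathered`.
        quantized_all_gather(gathered, local)
        return gathered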
+ """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + group = kwargs.get("group", None) + async_op = kwargs.get("async_op", False) + if async_op is True: + raise RuntimeError("The async_op=True mode is not supported yet.") + if func == dist.all_gather: + tensors = args[0] + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor_list(tensors, qtype) + dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op) + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + + elif func == dist.all_to_all: + tensors = args[0] + input_tensors = _quantize_tensor_list(args[1], qtype) + out_tensors = _quantize_tensor_list(tensors, qtype) + dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op) + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + + elif func == dist.all_to_all_single: + tensors = args[0] + out_splits = kwargs.get("out_splits", None) + in_splits = kwargs.get("in_splits", None) + # Quantizing the input/output tensor + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor(tensors, qtype) + dist.all_to_all_single( + out_tensors, input_tensors, out_splits, in_splits, group=group + ) + for i, t in enumerate( + _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + else: + raise RuntimeError(f"The collective op {func} is not supported yet") + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41a434828ea8d156c3586bd7d886a3a295b2ccdb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs +from enum import Enum +from functools import partial + +import torch.distributed as dist + +from . import ( + debugging_hooks as debugging, + default_hooks as default, + optimizer_overlap_hooks as optimizer_overlap, + powerSGD_hook as powerSGD, + quantization_hooks as quantization, +) + + +__all__ = ["DDPCommHookType", "register_ddp_comm_hook"] + + +def _ddp_comm_hook_wrapper(comm_hook, model, state): + model.register_comm_hook(state, comm_hook) + + +def _powerSGD_comm_hook_wrapper( + comm_hook, + model, + state, + matrix_approximation_rank, + start_powerSGD_iter=1_000, +): + """ + Wrap PowerSGD communication hook. + + To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group, + which will be wrapped up with other state info. + """ + powerSGD_state = powerSGD.PowerSGDState( + process_group=state, + matrix_approximation_rank=matrix_approximation_rank, + start_powerSGD_iter=start_powerSGD_iter, + ) + model.register_comm_hook(powerSGD_state, comm_hook) + + +class DDPCommHookType(Enum): + """ + Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communucation hook types. + + DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example, + you can register allreduce hook by + ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``. 
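The PowerSGD variants listed below are registered the same way; the wrapper builds a ``PowerSGDState`` from the given process group, so ``state`` is just a process group. A brief sketch, using the default group as an illustrative choice:

    import torch.distributed as dist

    def register_power_sgd(ddp_model):
        # Equivalent to register_ddp_comm_hook(DDPCommHookType.POWER_SGD, ddp_model, dist.group.WORLD)
        DDPCommHookType.POWER_SGD.value(model=ddp_model, state=dist.group.WORLD)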
+ """ + + ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook) + FP16_COMPRESS = partial( + _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook + ) + BF16_COMPRESS = partial( + _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook + ) + QUANTIZE_PER_TENSOR = partial( + _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook + ) + QUANTIZE_PER_CHANNEL = partial( + _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook + ) + POWER_SGD = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=1, + ) + # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version, + # but it runs slower and consumes more memory. + POWER_SGD_RANK2 = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=2, + ) + # Batching can lead to a faster training at the cost of accuracy. + BATCHED_POWER_SGD = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=1, + ) + BATCHED_POWER_SGD_RANK2 = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=2, + ) + NOOP = partial( + _ddp_comm_hook_wrapper, + comm_hook=debugging.noop_hook, + ) + + +def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None): + """ + Register ``ddp_comm_hooks`` to DDP model. + + Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + to the DDP model. User can specify the type of hook as an enum + ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will + be passed to the model. + Uses Python comm hook implementations. + + Example:: + >>> # xdoctest: +SKIP + >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) + """ + comm_hook_type.value(model=model, state=state) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..155ca2564c698f2d8abfee1671fa464b5056e480 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c0463a620a0858873116d6fa7156ca5aa15bcdc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6f17876dd7d6cc03aa3feaffce03d3e523873a4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a43e633720bc79f1edcc781531ab6b360bf9d9d6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0befeb1ef59bb91aecad37d39db4ac4d82247351 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..874c01e7c513994beae1a9fc6f96a370b08b6b3c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ae8277bcb6eb3e0691af2262c5ef6864d436555 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5906025a49bea17c277e476f41a397c9c800c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69b884ec433a4f5ae41aee58d8e8e80029f40cbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..18ce616e5d609b739220f2a901cfc252ed242200 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -0,0 +1,460 @@ +# mypy: allow-untyped-defs +import weakref +from typing import Any, Callable, Optional + +import torch +import torch.distributed as dist +from torch.distributed.optim import ZeroRedundancyOptimizer +from 
torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus +from torch.nn.parallel.distributed import DistributedDataParallel + + +__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"] + +# Functional optimizers require passing a list of gradients to their `step()` +# method, and ZeRO requires a functional optimizer to overlap with DDP +# Passing a `None` instead of an actual gradient indicates to the optimizer +# to not update the corresponding parameter +_NO_PARAM_UPDATE: None = None + + +def _perform_local_step( + bucket: dist.GradBucket, + zero: ZeroRedundancyOptimizer, + rank: int, +): + r""" + Perform a local optimizer step using the gradients provided by ``bucket``. + + Arguments: + bucket (dist.GradBucket): the bucket providing the gradients. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to perform the :meth:`_local_step`. + rank (int): the calling process's rank. + + .. warning:: + This function assumes that appropriate synchronization has taken place + so that the bucket's gradients can be used. + """ + overlap_info = zero._overlap_info + bucket_index = bucket.index() + assert len(zero.optim.param_groups) == 1, ( + "Overlapping DDP with ZeRO only supports a single parameter group" + ) + + # Construct the `gradients` input for the local optimizer step, which + # expects `None` in a list position to indicate that the corresponding + # parameter should not be updated + num_local_optim_params = len(zero.optim.param_groups[0]["params"]) + gradients: list[Optional[torch.Tensor]] = [ + _NO_PARAM_UPDATE for _ in range(num_local_optim_params) + ] + assert bucket_index in overlap_info.offsets, ( + f"Bucket index {bucket_index} was not assigned to rank {rank}" + ) + gradients_offset = overlap_info.offsets[bucket_index] + bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] + bucket_offset = bucket_assignment.offset + length = len(bucket_assignment.parameters) + bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length] + for i, grad in enumerate(bucket_gradients): + gradients[gradients_offset + i] = grad + + zero._local_step(gradients) + + +def _broadcast_bucket( + bucket_index: int, + zero: ZeroRedundancyOptimizer, +): + r""" + Broadcasts a bucket's parameters. + + Arguments: + bucket_index (int): the index of the bucket corresponding to the + parameters to broadcast. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. 
+ """ + overlap_info = zero._overlap_info + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, ( + "`assigned_ranks_per_bucket` is not fully constructed" + ) + # Sort to ensure the same ordering across ranks + assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) + assert len(assigned_ranks) > 0, ( + f"Bucket {bucket_index} should be assigned to at least one rank" + ) + for assigned_rank in assigned_ranks: + bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] + if bucket_index in bucket_assignments: + send_tensor = bucket_assignments[bucket_index].tensor + assert send_tensor is not None + overlap_info.broadcast_handles.append( + dist.broadcast( + send_tensor, + src=dist.get_global_rank(zero.process_group, assigned_rank), + group=zero.process_group, + async_op=True, + ) + ) + + +def _save_ddp_bucket_info( + bucket: dist.GradBucket, + zero: ZeroRedundancyOptimizer, +): + r""" + Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``. + + In particular, this function is meant to be called upon seeing each + gradient bucket to use when overlapping, meaning it does not save or compute any global + information. + + Arguments: + bucket (dist.GradBucket): the current gradient bucket. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + """ + overlap_info = zero._overlap_info + bucket_params = bucket.parameters() + assert len(bucket_params) > 0, "Empty bucket" + + # Save the parameters in the bucket + overlap_info.params_per_bucket.append(bucket_params) + if overlap_info.shard_buckets: + # Additionally save the bucket size for the assignment heuristic to use + bucket_size = 0 + for param in bucket_params: + bucket_size += param.numel() + assert overlap_info.total_size is not None + overlap_info.total_size += bucket_size + + +def _hook_with_zero_step_setup( + ddp_ref: weakref.ReferenceType, + zero: ZeroRedundancyOptimizer, + bucket: dist.GradBucket, +): + r""" + Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`. + + This means the logic to run in the + hook before the backward pass and optimizer step can actually be + overlapped. This is factored out since it is common to both + :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`. + + Arguments: + ddp_ref (weakref.ReferenceType): weak reference to the process's + :class:`DistributedDataParallel` instance. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + bucket (dist.GradBucket): the current gradient bucket. 
+ """ + # Proceed as normal until the DDP buckets have been rebuilt + if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr] + assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED + return + + bucket_index = bucket.index() + overlap_info = zero._overlap_info + if overlap_info.status == _OverlapStatus.UNINITIALIZED: + overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS + + if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS: + if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0: + # This corresponds to the first bucket of the backward pass + # immediately after all information has been saved, so we + # can perform the delayed ZeRO initialization + zero._init_zero_for_overlap() + else: + # Once DDP buckets have been rebuilt but ZeRO has not been + # properly initialized yet, save the information needed + _save_ddp_bucket_info(bucket, zero) + + +def hook_with_zero_step( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future], + ddp: DistributedDataParallel, + zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r""" + Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass. + + This approach overlaps the optimizer computation and communication with the + backward communication. In particular, the backward computation proceeds + contiguously, and the optimizer computation follows, overlapping with + outstanding backward communication (i.e. all-reduces) and possibly other + optimizer communication (i.e. broadcasts). + The optimizer step computation begins after the last gradient bucket computation has finished. + + This approach may be preferred over :meth:`hook_with_zero_step_interleaved` + if communication is relatively slow compared to computation. + + Arguments: + hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook + to modify. + ddp (DistributedDataParallel): the :class:`DistributedDataParallel` + instance to use. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). + + Returns: + The modified hook. + + Raises: + ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. + RuntimeError: if using any backend other than NCCL/HCCL since currently + Gloo may hang. + + .. warning:: + Given the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. 
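A usage sketch under the constraints above (NCCL backend, ``overlap_with_ddp=True``); the model, learning rate, and device id are placeholders:

    import torch
    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel

    def setup_overlapped_zero(model: torch.nn.Module, rank: int):
        ddp = DistributedDataParallel(model.to(rank), device_ids=[rank])
        zero = ZeroRedundancyOptimizer(
            ddp.parameters(),
            optimizer_class=torch.optim.SGD,
            overlap_with_ddp=True,
            lr=0.01,
        )
        # The optimizer step then runs inside the comm hook during backward.
        ddp.register_comm_hook(None, hook_with_zero_step(allreduce_hook, ddp, zero))
        return ddp, zero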
+ """ + if not zero._overlap_with_ddp: + raise ValueError( + "ZeroRedundancyOptimizer must be constructed with " + "`overlap_with_ddp=True` to use this hook properly" + ) + ddp_ref = weakref.ref(ddp) + + # NOTE: Gloo may hang with this overlapping approach, so we require + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if (pg != dist.Backend.NCCL) and (pg != "hccl"): + raise RuntimeError( + "Overlapping DDP with ZeRO using this approach currently requires " + "NCCL/HCCL backend to avoid hangs" + ) + + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + + def hook_with_zero_fn( + state: Any, + bucket: dist.GradBucket, + ) -> torch.futures.Future[torch.Tensor]: + r""" + Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket. + + Perform equivalent of :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is last gradient bucket. + The function gives a gradient bucket tensor and + performs additional computation on the iteration that + the :class:`DistributedDataParallel` buckets are rebuilt to collect + information used to implement the modified hook. + + Arguments: + state (Any): any state for the hook. + bucket (dist.GradBucket): the :class:`DistributedDataParallel` + gradient bucket. + """ + fut = hook(state, bucket) + _hook_with_zero_step_setup(ddp_ref, zero, bucket) + if zero._overlap_info.status != _OverlapStatus.INITIALIZED: + return fut + + overlap_info = zero._overlap_info + bucket_index = bucket.index() + rank = zero.global_rank + + assert overlap_info.status == _OverlapStatus.INITIALIZED + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, ( + "`assigned_ranks_per_bucket` is not fully constructed" + ) + assigned_to_bucket = ( + rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + ) + + # Save the bucket reference and all-reduce future for the final bucket + if assigned_to_bucket: + overlap_info.bucket_index_to_bucket[bucket_index] = bucket + overlap_info.bucket_index_to_future[bucket_index] = fut + + # Check that buckets are indexed incrementally starting from 0 in the + # order of their autograd hooks firing + if len(overlap_info.bucket_indices_seen) > 0: + assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, ( + "Bucket indices are not in incremental order" + ) + else: + assert bucket_index == 0, "Bucket indices do not start from 0" + overlap_info.bucket_indices_seen.append(bucket_index) + + # Directly return the future without any optimizer computation if this + # is not the last bucket + num_buckets = len(overlap_info.params_per_bucket) + is_last_bucket = bucket_index == num_buckets - 1 + if not is_last_bucket: + return fut + + # Perform partial optimizer step on all buckets after the final + # bucket has been computed + # NOTE: This should not be chained as a callback to the last bucket's + # all-reduce future since that would add synchronization that delays + # all optimizer computation to wait for that last all-reduce + for bucket_index in range(num_buckets): + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + if rank in assigned_ranks: + # Wait on the bucket's all-reduce future to ensure correct + # gradients + assert bucket_index in overlap_info.bucket_index_to_future, ( + f"All-reduce future for bucket {bucket_index} not saved " + f"on rank {rank}" + ) + allreduce_future = 
overlap_info.bucket_index_to_future[bucket_index] + allreduce_future.wait() + + # Perform the partial optimizer step + curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index] + _perform_local_step(curr_bucket, zero, rank) + + _broadcast_bucket(bucket_index, zero) + + # Ensure that all parameter updates are finished before the + # next forward pass + overlap_info.wait_for_broadcasts() + overlap_info.clear_per_iter_info() + + return fut + + return hook_with_zero_fn + + +def hook_with_zero_step_interleaved( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future], + ddp: DistributedDataParallel, + zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r""" + Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass + + This approach overlaps the optimizer computation and communication with the + backward computation and communication. In particular, once a bucket's + gradients have been computed, the optimizer computation using those + gradients is launched (though the actual computation must wait for the + bucket's all-reduce to complete). This yields an interleaving of all- + reduces and broadcasts in the communication stream. + + This approach may be preferred over :meth:`hook_with_zero_step` if + communication is relatively fast compared to computation. + + Arguments: + hook (Any * dist.GradBucket -> torch.futures.Future): the hook to + modify. + ddp (DistributedDataParallel): the :class:`DistributedDataParallel` + instance to use. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). + + Returns: + The modified hook. + + Raises: + ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. + RuntimeError: if using any backend other than NCCL since currently + Gloo may hang. + + .. warning:: + Given the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. 
+ """ + if not zero._overlap_with_ddp: + raise ValueError( + "ZeroRedundancyOptimizer must be constructed with " + "`overlap_with_ddp=True` to use this hook properly" + ) + ddp_ref = weakref.ref(ddp) + + # NOTE: Gloo may hang with this overlapping approach, so we require + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if (pg != dist.Backend.NCCL) and (pg != "hccl"): + raise RuntimeError( + "Overlapping DDP with ZeRO using this approach currently requires " + "NCCL/HCCL backend to avoid hangs" + ) + + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + + def hook_with_zero_interleaved_fn( + state, + bucket: dist.GradBucket, + ) -> torch.futures.Future[torch.Tensor]: + r""" + Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`. + + This function uses the gradients in gradient in given bucket to perform a partial + :class:`ZeroRedundancyOptimizer` :meth:`step` + + Arguments: + state: any state for the hook. + bucket (dist.GradBucket): the :class:`DistributedDataParallel` + gradient bucket. + """ + fut = hook(state, bucket) + _hook_with_zero_step_setup(ddp_ref, zero, bucket) + if zero._overlap_info.status != _OverlapStatus.INITIALIZED: + return fut + + def zero_step(fut: torch.futures.Future) -> torch.Tensor: + r""" + Perform partial :class:`ZeroRedundancyOptimizer` :meth:`step` using gradients in the :class:`DistributedDataParallel`. + + Returns: + A :class:`torch.Tensor` representing the contents of the + gradient bucket. + """ + overlap_info = zero._overlap_info + bucket_index = bucket.index() + rank = zero.global_rank + + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + overlap_info.bucket_indices_seen.append(bucket_index) + if rank in assigned_ranks: + _perform_local_step(bucket, zero, rank) + + _broadcast_bucket(bucket_index, zero) + + num_buckets = len(overlap_info.params_per_bucket) + if len(overlap_info.bucket_indices_seen) == num_buckets: + # Ensure that all parameter updates are finished before the + # next forward pass + overlap_info.wait_for_broadcasts() + overlap_info.clear_per_iter_info() + + return bucket.buffer() + + return fut.then(zero_step) + + return hook_with_zero_interleaved_fn diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd2762d5bab6c2e548723b2192aed759feef26b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py @@ -0,0 +1,29 @@ +from typing import Any + +import torch +from torch.distributed import GradBucket + + +__all__ = ["noop_hook"] + + +def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]: + """ + Return a future that wraps the input, so it is a no-op that does not incur any communication overheads. + + This hook should **only** be used for headroom analysis of allreduce optimization, + instead of the normal gradient synchronization. + For example, if only less than 10% speedup of training time can be observed after this hook is registered, + it usually implies that allreduce is not a performance bottleneck for this case. 
+ Such instrumentation can be particularly useful + if GPU traces cannot be easily retrieved or the trace analysis is complicated + some factors such as the overlap between allreduce and computation or the desynchronization across ranks. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(None, noop_hook) + """ + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(bucket.buffer()) + + return fut diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..97bb643884258b77af853168fa7d8552d58def7b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -0,0 +1,205 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, cast + +import torch +import torch.distributed as dist + + +__all__ = [ + "allreduce_hook", + "fp16_compress_hook", + "bf16_compress_hook", + "fp16_compress_wrapper", + "bf16_compress_wrapper", +] + + +def _allreduce_fut( + process_group: dist.ProcessGroup, tensor: torch.Tensor +) -> torch.futures.Future[torch.Tensor]: + """Average the input gradient tensor by allreduce and returns a future.""" + group_to_use = process_group if process_group is not None else dist.group.WORLD + + # Apply the division first to avoid overflow, especially for FP16. + tensor.div_(group_to_use.size()) + + return ( + dist.all_reduce(tensor, group=group_to_use, async_op=True) + .get_future() + .then(lambda fut: fut.value()[0]) + ) + + +def allreduce_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Call ``allreduce`` using ``GradBucket`` tensors. + + Once gradient tensors are aggregated across all workers, its ``then`` + callback takes the mean and returns the result. + + If user registers this DDP communication hook, + DDP results is expected to be same as the case where no hook was registered. + Hence, this won't change behavior of DDP and user can use this as a reference + or modify this hook to log useful information or any other purposes while + unaffecting DDP behavior. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, allreduce_hook) + """ + return _allreduce_fut(process_group, bucket.buffer()) + + +def _compress_hook( + dtype: torch.dtype, + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = group_to_use.size() + + buffer = ( + cast(tuple[torch.Tensor, ...], bucket)[0] + if isinstance(bucket, tuple) + else bucket.buffer() + ) + compressed_tensor = buffer.to(dtype).div_(world_size) + + def decompress(fut): + decompressed_tensor = buffer + # Decompress in place to reduce the peak memory. 
+ # See: https://github.com/pytorch/pytorch/issues/45968 + value = fut if isinstance(fut, torch.Tensor) else fut.value()[0] + decompressed_tensor.copy_(value) + return decompressed_tensor + + if torch.compiler.is_compiling(): + grad = dist._functional_collectives.all_reduce( + compressed_tensor, "sum", group_to_use + ) + return decompress(grad) + else: + fut = dist.all_reduce( + compressed_tensor, group=group_to_use, async_op=True + ).get_future() + return fut.then(decompress) + + +def fp16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + """ + Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``) + and then divides it by the process group size. + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) + """ + return _compress_hook(torch.float16, process_group, bucket) + + +def bf16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision + `Brain floating point format `_ (``torch.bfloat16``) + and then divides it by the process group size. + It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook) + """ + return _compress_hook(torch.bfloat16, process_group, bucket) + + +def fp16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + """ + Cast input tensor to ``torch.float16``, cast result of hook back to input dtype. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + floating point format (``torch.float16``), and casts the resulting tensor of the given hook back to + the input data type, such as ``float32``. + Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook)) + """ + + def fp16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Cast bucket tensor to FP16. + bucket.set_buffer(bucket.buffer().to(torch.float16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. 
+ return fut.then(decompress) + + return fp16_compress_wrapper_hook + + +def bf16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + `Brain floating point format `_ (``torch.bfloat16``), + and casts the resulting tensor of the given hook back to the input data type, such as ``float32``. + + Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook)) + """ + + def bf16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Cast bucket tensor to BF16. + bucket.set_buffer(bucket.buffer().to(torch.bfloat16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. + return fut.then(decompress) + + return bf16_compress_wrapper_hook diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..abe476e92f46dd3499198883d37096ceb035a0df --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from typing import Any, no_type_check + +import torch +import torch.distributed as dist +from torch.autograd import Variable +from torch.distributed.utils import _free_storage + + +@dataclass +class _AllreduceUpcastHookState: + """ + State to manage DDP mixed precision in backward / gradient communication. + + This contains a weakref to the DDP module for access to reducer and process + group, and a stream to run parameter and gradient upcasts. + """ + + ddp_weakref: Any + upcast_stream: torch.Stream + wait_for_stream_enqueued: bool = False + + +@no_type_check +def _reducer_allreduce_and_upcast_hook( + hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Perform allreduce in precision ``reduce_dtype``, upcast to prepare for optimizer. + + Performs allreduce in the reduced precision given by DDP's mixed precision + reduce_dtype, and upcasts parameters and gradients to fp32 in preparation + to run the optimizer. + """ + ddp_weakref = hook_state.ddp_weakref + reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group + # Cast bucket if different than param_dtype. 
+ if ( + ddp_weakref().mixed_precision.param_dtype + != ddp_weakref().mixed_precision.reduce_dtype + ): + # Cast bucket tensor to reduce_dtype + bucket.set_buffer( + bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype) + ) + fut = reducer._run_allreduce_hook(bucket) + ret_fut = torch.futures.Future() + stream = hook_state.upcast_stream + with stream: + fut.wait() + bucket.buffer().div_(process_group.size()) + ret_fut.set_result(bucket.buffer()) + + # Upcast parameters and gradients so optimizer step can run in fp32. + for p in bucket.parameters(): + p.data = p._fp_param + # free storage for mp param as it will be allocated again in next + # forward pass. + _free_storage(p._mp_param) + p.grad.data = p.grad.to(p.data.dtype) + + # enqueue a callback to wait for this stream at end of backward + def wait_for_stream_cb(): + torch.accelerator.current_stream().wait_stream(stream) + # Remove post-backward hooks since they are re-installed in next + # iteration, similar to FSDP. + # Parameters that don't require grad still needed to be casted since + # they may participate in computation. However, they would not be recast + # by hook above as they don't have a grad hook installed, so cast them + # back here. + for _, p in ddp_weakref().module.named_parameters(): + if hasattr(p, "_ddp_mp_hook_state"): + p._ddp_mp_hook_state[1].remove() + delattr(p, "_ddp_mp_hook_state") + if not p.requires_grad and not hasattr(p, "_ddp_ignored"): + p.data = p._fp_param + + # reset for next backward pass + hook_state.wait_for_stream_enqueued = False + + if not hook_state.wait_for_stream_enqueued: + Variable._execution_engine.queue_callback(wait_for_stream_cb) + # mark that the callback is enqueued + hook_state.wait_for_stream_enqueued = True + + return ret_fut diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..235ac2cb073ccfa2196fea8c4ace403017f66370 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, no_type_check + +import torch +import torch.distributed as dist +from torch.autograd import Variable + + +__all__: list[str] = [] + +_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param" + + +class _OptimizerHookState: + """ + Holds state for running optimizer in-line after DDP communication hook. + + Currently contains only optimizer class which must have a method `step_param`. + """ + + __slots__ = ["functional_optimizer", "params_to_optimize"] + + def __init__(self, functional_optim, params=None): + self.functional_optimizer = functional_optim + self._check_valid_functional_optim() + self._set_params_to_optimize(params) + + def _set_params_to_optimize(self, params): + if params is not None: + self.params_to_optimize = set(params) + + def _check_valid_functional_optim(self): + if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME): + raise ValueError( + f"Class {type(self.functional_optimizer)} must implement method " + f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}." 
+ ) + + +@dataclass +class _OptimInBackwardHookState: + optim_stream: torch.Stream + wait_for_optim_stream_enqueued: bool + + +@no_type_check +def _apply_optim_in_backward_hook( + gradient_is_bucket_view: bool, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r""" + Register hook to apply the optimizer in backward. + + If torch.distributed.optim._apply_optimizer_in_backward is used to overlap + optimizer with backward pass, DDP will run the below hook to run optimizer + step for parameters after gradient communication has taken place. + """ + optim_in_bwd_state = _OptimInBackwardHookState( + optim_stream=torch.Stream(), + wait_for_optim_stream_enqueued=False, + ) + + def apply_optim_in_backward_hook( + hook_state: Any, + bucket: dist.GradBucket, + optim_stream_state, + ) -> torch.futures.Future[torch.Tensor]: + # Run original hook + ddp_weakref = hook_state + ddp_inst = ddp_weakref() + reducer, process_group = ddp_inst.reducer, ddp_inst.process_group + fut = reducer._run_allreduce_hook(bucket) + optimizer_stream = optim_stream_state.optim_stream + with optimizer_stream: + fut.wait() + # Apply gradient division since C++ side only allreduces and does + # not average. TODO: (rohan-varma) the div factor may be different + # when running with join hook + bucket.buffer().div_(process_group.size()) + model_params = bucket.parameters() + grads = bucket.gradients() + # TODO (rohan-varma): upcast as needed for DDP mixed precision, + # once optimizer in backward + DDP mixed precision is supported. + for p, g in zip(model_params, grads): + if hasattr(p, "_in_backward_optimizers"): + # Note: need to set grad to the bucket's grad, because + # running allreduce results in the bucket's grad being + # reduced, but not grad field. + if not gradient_is_bucket_view: + p.grad = g + for optim in p._in_backward_optimizers: + optim.step() + + # Need to return a Future[Tensor] to obey comm hook API contract. + ret_fut = torch.futures.Future() + ret_fut.set_result(bucket.buffer()) + + # enqueue a callback to wait for this optimizer stream at the end of + # backward and set all DDP managed grads to None. 
+ def wait_for_optim_stream_callback(): + torch.accelerator.current_stream().wait_stream( + optim_stream_state.optim_stream + ) + # Set DDP managed grads to None + for param in ddp_inst._get_data_parallel_params(ddp_inst.module): + if hasattr(param, "_in_backward_optimizers"): + param.grad = None + + # reset for the next backwards pass + optim_stream_state.wait_for_optim_stream_enqueued = False + + if not optim_stream_state.wait_for_optim_stream_enqueued: + Variable._execution_engine.queue_callback(wait_for_optim_stream_callback) + # mark that the callback is enqueued + optim_stream_state.wait_for_optim_stream_enqueued = True + + return ret_fut + + comm_hook = partial( + apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state + ) + # These are needed for DDP's logging of comm hooks + comm_hook.__name__ = apply_optim_in_backward_hook.__name__ + comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__ + + return comm_hook + + +def _hook_then_optimizer( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], + optimizer_state: _OptimizerHookState, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r"""Run optimizer in a functional fashion after DDP communication hook.""" + has_set_params = ( + hasattr(optimizer_state, "params_to_optimize") + and optimizer_state.params_to_optimize is not None + ) + + def hook_then_optimizer_wrapper( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Run original hook + fut = hook(hook_state, bucket) + + def optimizer_step(fut): + gradient_tensors = bucket.gradients() + model_params = bucket.parameters() + for grad_tensor, model_param in zip(gradient_tensors, model_params): + if ( + not has_set_params + or model_param in optimizer_state.params_to_optimize + ): + optimizer_state.functional_optimizer.step_param( + model_param, + grad_tensor, + ) + return bucket.buffer() + + return fut.then(optimizer_step) + + return hook_then_optimizer_wrapper diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..f193ca51364062ce6894e4ee3d113b2e69686796 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +import logging + +import torch +import torch.distributed as dist + +from . import default_hooks as default + + +logger = logging.getLogger(__name__) + + +class PostLocalSGDState: + r""" + Store state for all-reducing gradients globally until given step, then locally after. + + Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``, + and all-reducing gradients locally using ``subgroup`` afterwards. + + If ``process_group`` is ``None``, the global process group will be used. + If ``subgroup`` is ``None``, the intra-node process group on each machine will be used. + + Additionally, ``post_local_gradient_allreduce`` may be worth tuning, + because both true and false may give a faster convergence. 
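A usage sketch combining this state with the hook below and a periodic model averager; the iteration counts are example values and an initialized process group is assumed:

    import torch.distributed as dist
    from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager

    def train_with_post_local_sgd(ddp_model, optimizer, loss_fn, data_loader):
        subgroup, _ = dist.new_subgroups()  # one subgroup per node by default
        state = PostLocalSGDState(
            process_group=None, subgroup=subgroup, start_localSGD_iter=100
        )
        ddp_model.register_comm_hook(state, post_localSGD_hook)
        averager = PeriodicModelAverager(period=4, warmup_steps=100)
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            loss_fn(ddp_model(inputs), targets).backward()
            optimizer.step()
            # Once local SGD has started, this periodically all-reduces the parameters.
            averager.average_parameters(ddp_model.parameters())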
+ """ + + __slots__ = [ + "process_group", + "subgroup", + "start_localSGD_iter", + "post_local_gradient_allreduce", + "iter", + ] + + def __init__( + self, + process_group, + subgroup, + start_localSGD_iter, + post_local_gradient_allreduce=True, + ): + """Initialize state object with given parameters and log when localSGD start.""" + logger.info( + "Local SGD will be started after %s iterations", start_localSGD_iter + ) + + # The group used for all-reducing gradients globally. + self.process_group = process_group + # The group used for all-reducing gradients locally. + self.subgroup = subgroup + self.start_localSGD_iter = start_localSGD_iter + # Allreduce gradients locally since iteration `start_localSGD_iter`. + # This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication. + self.post_local_gradient_allreduce = post_local_gradient_allreduce + # Iteration/step in the training loop. + self.iter = 0 + + def maybe_increase_iter(self, bucket): + """Track iterations and trigger log message at start of local SGD.""" + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.is_last(): + self.iter += 1 + + if self.iter == self.start_localSGD_iter: + logger.info("Start to apply local SGD after %s iterations.", self.iter) + + +def post_localSGD_hook( + state: PostLocalSGDState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Run post-localSGD algorithm. + + This DDP communication hook is used for running post-localSGD algorithm, + by combining with a model averaging component (e.g., + :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`) + that runs after the optimizer step. + + Args: + state (PostLocalSGDState): State information to run post-localSGD. + Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup, + start_localSGD_iter=10) + >>> ddp_model.register_comm_hook(state, post_localSGD_hook) + >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``. + >>> # Please refer to the examples in ``torch.distributed.algorithms.model_averaging.averagers`` module. + """ + global_group_to_use = ( + state.process_group if state.process_group is not None else dist.group.WORLD + ) + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations. + if state.iter < state.start_localSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(global_group_to_use, input_tensor) # type: ignore[arg-type] + + # If `post_local_gradient_allreduce` is not set, + # then no gradient synchronization after the first `start_localSGD_iter` iterations. 
+ if not state.post_local_gradient_allreduce: + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(input_tensor) + return fut + + # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations. + # Note that by default, a separate subgroup for each node is created which + # causes an intra-node allreduce to be done at each training step. + # From this moment, model averaging should run after the optimizer step, + # to globally allreduce all the parameters. + if state.subgroup is None: + state.subgroup, _ = dist.new_subgroups() + return default._allreduce_fut(state.subgroup, input_tensor) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9e1b4abb6dd0d635110a7b9bdef7ecaffc1d6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -0,0 +1,860 @@ +# mypy: allow-untyped-defs +import logging +import math +from collections import defaultdict + +import torch +import torch.distributed as dist +from torch.distributed import distributed_c10d +from torch.utils._typing_utils import not_none + +from . import default_hooks as default + + +__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"] + +logger = logging.getLogger(__name__) + + +def _orthogonalize(matrices, epsilon=0): + """ + Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices. + + QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2. + """ + assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1] + + num_matrices = matrices.shape[0] + rank = matrices.shape[2] + dtype = matrices.dtype + if rank <= 2 or dtype in [torch.float16, torch.bfloat16]: + _orthogonalize_gram_schmidt(matrices, epsilon=epsilon) + else: + torch.linalg.qr( + matrices, + out=( + matrices, + torch.empty( + num_matrices, rank, rank, device=matrices.device, dtype=dtype + ), + ), + ) + + +def _orthogonalize_gram_schmidt(matrices, epsilon=0): + """ + Apply Gram-Schmidt procedure to orthogonalize a batch of matrices. + + If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`, + """ + num_cols = matrices.shape[2] + for i in range(num_cols): + # Normalize the i'th column. + col = matrices[:, :, i : i + 1] + # If no epsilon is added here, division by zero may be caused by vanishing gradients. + # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer + # in the neural network. + if epsilon == 0: + # Note that col ** 2 can underflow/overflow if we use FP16. + # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead. + try: + col /= torch.norm(col, dim=1, keepdim=True) + except ZeroDivisionError: + logger.error( + "The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 " + "as `orthogonalization_epsilon` in PowerSGD state." + ) + # Recover the values from NaNs to 0s. + col.fill_(0.0) + else: + col /= torch.norm(col, dim=1, keepdim=True) + epsilon + # Project it on the rest and remove it. 
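# A self-contained sketch of the batched Gram-Schmidt pass implemented here:
# after normalizing each column and subtracting its projection from the
# remaining columns, every matrix in the batch should satisfy Q^T Q ~= I.
import torch


def gram_schmidt(matrices: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    # matrices: (num_matrices, n, rank) with rank <= n
    out = matrices.clone()
    num_cols = out.shape[2]
    for i in range(num_cols):
        col = out[:, :, i : i + 1]
        col /= torch.norm(col, dim=1, keepdim=True) + epsilon
        if i + 1 < num_cols:
            rest = out[:, :, i + 1 :]
            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col
    return out


q = gram_schmidt(torch.randn(4, 64, 3))
identity = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(q.transpose(1, 2) @ q, identity, atol=1e-5)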
+ if i + 1 < num_cols: + rest = matrices[:, :, i + 1 :] + rest -= torch.sum(col * rest, dim=1, keepdim=True) * col + + +def _should_compress( + num_rows, num_cols, matrix_approximation_rank, min_compression_rate +): + """ + Recommend if tensor given is worth compressing. + + Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing, + including statistics describing the expected savings from compression. We consider a tensor worth + compressing when ``min_compression_rate`` < uncompressed size / compressed size, where + uncompressed size = ``num_rows`` * ``num_cols``, + and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. + + The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where: + + compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above); + + uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and, + + compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. + """ # noqa: B950 + uncompressed_size = num_rows * num_cols + compressed_size = (num_rows + num_cols) * matrix_approximation_rank + return ( + compressed_size * min_compression_rate < uncompressed_size, + uncompressed_size, + compressed_size, + ) + + +def _report_compression_stats(bucket, state): + """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state.""" + if bucket.is_last() and state.iter >= state.next_stats_report: + stats = state.compression_stats() + logger.info( + "Compression stats: iter %s, total before compression %s, total after compression %s, " + "rate %s", + state.iter, + stats[1], + stats[2], + stats[0], + ) + state.next_stats_report = state.iter + state.compression_stats_logging_frequency + + +class PowerSGDState: + r""" + Store both the algorithm's hyperparameters and internal state for all gradients during training. + + Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user. + For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on. + + 1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression. + + 1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy. + + 1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold. + + To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32. + + 2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. 
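# A worked sketch of the criterion implemented in ``_should_compress`` above:
# compression happens only when
# (num_rows + num_cols) * rank * min_compression_rate < num_rows * num_cols.
rank, min_compression_rate = 4, 2
for num_rows, num_cols in [(1024, 1024), (1024, 1), (64, 32)]:
    uncompressed = num_rows * num_cols
    compressed = (num_rows + num_cols) * rank
    worth_compressing = compressed * min_compression_rate < uncompressed
    # (1024, 1024): 1_048_576 vs 8_192 -> compress (~128x smaller)
    # (1024, 1):    1_024     vs 4_100 -> allreduce uncompressed (bias-like vector)
    # (64, 32):     2_048     vs 384   -> compress (~5.3x smaller)
    print(num_rows, num_cols, worth_compressing)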
This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy. + + To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps. + + 3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression. + + Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts. + + 4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy. + + 5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck. + + .. warning :: + If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2. + This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, + and this can conflict with any tensor memorized before the rebuild process. + """ # noqa: B950 + + __slots__ = [ + "process_group", + # The fields below are the hyperparameters that often need to be tuned by the user. + "matrix_approximation_rank", + "start_powerSGD_iter", + # The fields below are the hyperparameters that seldom need be tuned by the user. + "min_compression_rate", + "orthogonalization_epsilon", + # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy. + "use_error_feedback", + "warm_start", + "batch_tensors_with_same_shape", + # The fields below are internal state. + "rng", + "error_dict", + "p_memory_dict", + "q_memory_dict", + "iter", + # The fields below are for recording compression stats. 
+ "total_numel_before_compression", + "total_numel_after_compression", + "compression_stats_logging_frequency", + "next_stats_report", + ] + + def __init__( + self, + process_group, + matrix_approximation_rank=1, + start_powerSGD_iter=1_000, + min_compression_rate=2, + use_error_feedback=True, + warm_start=True, + orthogonalization_epsilon=0, + random_seed=0, + compression_stats_logging_frequency=10_000, + batch_tensors_with_same_shape: bool = False, + ): + logger.info( + "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; " + "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; " + "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s", + matrix_approximation_rank, + start_powerSGD_iter, + min_compression_rate, + orthogonalization_epsilon, + use_error_feedback, + warm_start, + random_seed, + compression_stats_logging_frequency, + batch_tensors_with_same_shape, + ) + + self.process_group = process_group + self.matrix_approximation_rank = matrix_approximation_rank + # Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages: + # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss, + # even if the matrix approximation rank is increased to a large value. + # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce + # (or a more conservative compression such as FP16 compression) with PowerSGD. + # 2) There is an internal optimization of rebuilding buckets process in DDP, + # in order to save the memory space. + # This step takes place after the first iteration. + # However, this means that the shape of input bucketized tensors is subject to change, + # which will complicate the implementations of error feedback and warm-up. + # Running vanilla allreduce in the first few iterations can avoid this complexity. + if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1: + raise ValueError( + "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " + "because PowerSGD can only be applied after the first two iterations in DDP." + ) + self.start_powerSGD_iter = start_powerSGD_iter + self.min_compression_rate = min_compression_rate + # Error feedback is usually crucial for both for convergence and generalization, + # because PowerSGD is a biased compressor, + # i.e., compressing and decompressing a random gradient does not yield the original in expectation. + # This mechanism requires a temporary copy of the input gradients, + # so it increases the peak memory consumption by the size of the gradient tensor. + # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank), + # sometimes it is possible to converge to the optima without error feedback. + # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf + self.use_error_feedback = use_error_feedback + # Warm-start reuses P(s) and Q(s) from the previous iteration. + # This can improve the approximation quality and hence improve the accuracy. + # Additionally, by avoiding the initialization of these low-rank tensors at every step, + # this can also accelerate training. + # However, this is at the cost of extra memory. + self.warm_start = warm_start + # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients. 
+ self.orthogonalization_epsilon = orthogonalization_epsilon + # The purpose of this RNG is to generate different random seeds for initializing Q across iterations, + # but in the same order for all the DDP replicas. + # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps. + # If the same random projection is used, + # there will be differences between the gradients that are never synchronized. + import numpy as np + + self.rng = np.random.RandomState(random_seed) + # Since there is only a single state instance for all the input buckets, + # need to maintain a dictionary that maps each bucket index to the local error. + self.error_dict: dict[int, torch.Tensor] = {} + self.p_memory_dict: dict[int, torch.Tensor] = {} + self.q_memory_dict: dict[int, torch.Tensor] = {} + # Iteration/step in the training loop. + self.iter = 0 + # Compression stats accumulators + self.total_numel_before_compression = 0 + self.total_numel_after_compression = 0 + # We'll report compression stats every 'compression_stats_logging_frequency' iterations + # Note that we always report compression stats at least once. + self.compression_stats_logging_frequency = max( + 1, compression_stats_logging_frequency + ) + self.next_stats_report = 0 + # Batching tensors with same shape can increase parallelism in compression / decompression computation. + # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however + # this may reduce the overlap between computation and communication, and increase the memory footprint + # due to stacking tensors. + # Turn on if compression / decompression computation is a bottleneck. + self.batch_tensors_with_same_shape = batch_tensors_with_same_shape + + def __getstate__(self): + r""" + Return a ``Dict[str, Any]`` which will be pickled and saved. + + ``process_group`` is not serializable and excluded from + a returned state. + """ + logger.warning( + "NOTE: Process group is not serializable and excluded from a saved state." + ) + return { + slot: getattr(self, slot) + for slot in self.__slots__ + if slot != "process_group" + } + + def __setstate__(self, state): + r""" + Take a provided ``state`` and set to this ``PowerSGDState`` instance. + + ``process_group`` is set to default. + """ + self.process_group = distributed_c10d._get_default_group() + logger.warning( + "NOTE: Process group will be set to a default group (i.e. the world size).\ + If a different group is desired, please set `self.process_group` after PowerSGD state is loaded." + ) + for slot, value in state.items(): + setattr(self, slot, value) + + def maybe_increase_iter(self, bucket): + """Track iterations and trigger log message at start of local SGD.""" + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.is_last(): + self.iter += 1 + + if self.iter == self.start_powerSGD_iter: + logger.info("Start to apply PowerSGD after %s iterations.", self.iter) + + def compression_stats(self): + r""" + Return latest compression statistics as tuple. + + Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where: + + compress_rate is the effective compression rate i.e. 
(number of elements before compression) / (number of elements after compression); + + numel_before_compression is the total number of elements before compression was applied; and, + + numel_after_compression is the total number of elements after compression was applied. + """ # noqa: B950 + compress_rate = ( + self.total_numel_before_compression / self.total_numel_after_compression + if self.total_numel_after_compression > 0 + else 0 + ) + return ( + compress_rate, + self.total_numel_before_compression, + self.total_numel_after_compression, + ) + + +def powerSGD_hook( + state: PowerSGDState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + r""" + Implement PowerSGD algorithm. + + This DDP communication hook implements PowerSGD gradient compression + algorithm described in the `paper `_. + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + + 1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups: + + 1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth. + + 1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases). + + 2. Handles uncompressed tensors: + + 2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression; + + 2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor. + + 3. Handles the tensors that should be compressed by PowerSGD compression: + + 3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, + such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + + 3.2. Computes each P in Ps, which is equal to MQ; + + 3.3. Allreduces Ps as a batch; + + 3.4. Orthogonalizes each P in Ps; + + 3.5. Computes each Q in Qs, which is approximately equal to M^TP; + + 3.6. Allreduces Qs as a batch; + + 3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T. + + Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. + This not only gives the user more control over the tradeoff between speedup and accuracy, + but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter`` + and ``min_compression_rate``. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. 
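# A checkpointing sketch based on the ``__getstate__``/``__setstate__`` methods
# above (the file name is an assumption): the process group is dropped when the
# state is pickled and reset to the default group on load, so the default
# process group must already be initialized before unpickling.
import pickle
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=2)
with open("powersgd_state.pkl", "wb") as f:
    pickle.dump(state, f)  # ``process_group`` is excluded from the pickled slots

# ... later, after dist.init_process_group() on the resuming job ...
with open("powersgd_state.pkl", "rb") as f:
    restored_state = pickle.load(f)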
+ + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, + start_powerSGD_iter=10, min_compression_rate=0.5) + >>> ddp_model.register_comm_hook(state, powerSGD_hook) + """ # noqa: B950 + process_group = state.process_group + group_to_use = ( + process_group if process_group is not None else not_none(dist.group.WORLD) + ) + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(group_to_use, input_tensor) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. + device = input_tensor.device + dtype = input_tensor.dtype + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.index() + input_tensor_cp = None + total_length = input_tensor.shape[0] + if state.use_error_feedback: + if bucket_index in state.error_dict: + input_tensor.add_(state.error_dict[bucket_index]) + else: + logger.info( + "A zero tensor of length %s that represents local error is created.", + total_length, + ) + state.error_dict[bucket_index] = torch.zeros( + total_length, device=device, dtype=dtype + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = torch.clone(input_tensor).detach() + + # Unflatten the input tensor into per-parameter tensors, for layer-wise compression. + tensors = bucket.gradients() + + # Step I: Divide all the tensors into two groups, + # one will be compressed before allreduce and the other will be directly allreduced without compression. + tensors_to_compress, uncompressed_tensors = [], [] + total_Ps_size = 0 + total_Qs_size = 0 + for tensor in tensors: + matrix = tensor.view(tensor.shape[0], -1) + n, m = matrix.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + compress_test = _should_compress( + n, m, matrix_approximation_rank, state.min_compression_rate + ) + state.total_numel_before_compression += compress_test[1] + if compress_test[0]: + tensors_to_compress.append(matrix) + total_Ps_size += n * matrix_approximation_rank + total_Qs_size += m * matrix_approximation_rank + state.total_numel_after_compression += compress_test[2] + else: + uncompressed_tensors.append(tensor) + state.total_numel_after_compression += compress_test[1] + + _report_compression_stats(bucket, state) + + # Step II: Handle uncompressed tensors. + # Allocate contiguous memory for these tensors to allreduce efficiently. + uncompressed_tensors_memory = ( + torch.cat([tensor.view(-1) for tensor in uncompressed_tensors]) + if uncompressed_tensors + else torch.tensor([], device=device, dtype=dtype) + ) + + # Step III: Handle the tensors that should be compressed. + # Allocate contiguous memory for Ps and Qs to allreduce efficiently. + # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible. + # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied. + need_randomize_qs = False + if not state.warm_start or bucket_index not in state.p_memory_dict: + need_randomize_qs = True + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. 
+ if state.warm_start: + logger.info( + "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.", + total_Ps_size, + total_Qs_size, + ) + state.p_memory_dict[bucket_index] = torch.empty( + total_Ps_size, device=device, dtype=dtype + ) + state.q_memory_dict[bucket_index] = torch.empty( + total_Qs_size, device=device, dtype=dtype + ) + + # Batch tensors to compress by shape. + shape_to_tensors = defaultdict(list) + for tensor in tensors_to_compress: + shape_to_tensors[tensor.shape].append(tensor) + + # This function decides whether to batch tensors with same shape or not according to the argument, + # so the following process could share the same code. + def maybe_batched_tensors_to_compress(): + for tensors in shape_to_tensors.values(): + if state.batch_tensors_with_same_shape: + batch_size = len(tensors) + if batch_size == 1: + # Use the original tensor to avoid copy. + yield tensors[0].unsqueeze(0) + else: + yield torch.stack(tensors) + else: + for tensor in tensors: + yield tensor.unsqueeze(0) + + # Create Ps and Qs that point to the allocated memory. + tensors_to_compress = [] + ps = [] + qs = [] + p_idx = 0 + q_idx = 0 + for tensor in maybe_batched_tensors_to_compress(): + batch_size, n, m = tensor.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + tensors_to_compress.append(tensor) + ps.append( + state.p_memory_dict[bucket_index][ + p_idx : p_idx + batch_size * n * matrix_approximation_rank + ].view(batch_size, n, matrix_approximation_rank) + ) + qs.append( + state.q_memory_dict[bucket_index][ + q_idx : q_idx + batch_size * m * matrix_approximation_rank + ].view(batch_size, m, matrix_approximation_rank) + ) + p_idx += batch_size * n * matrix_approximation_rank + q_idx += batch_size * m * matrix_approximation_rank + + # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values. + # The exception is the first iteration when PowerSGD is applied. + if not need_randomize_qs: + for q in qs: + _orthogonalize(q, state.orthogonalization_epsilon) + else: + with torch.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # This seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q). + torch.manual_seed(state.rng.randint(1_000_000_000)) + for q in qs: + q.copy_( + torch.randn( + *q.shape, + device="cpu", + dtype=dtype, + ) + ) + _orthogonalize(q, state.orthogonalization_epsilon) + + # Compute Ps. + for tensor, q, p in zip(tensors_to_compress, qs, ps): + torch.bmm(tensor, q, out=p) + + # This allreduce is only applied to uncompressed tensors, + # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs. + # However, this somehow requires a separate future chain at this time. 
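# A single-matrix sketch of the compression math above (no communication):
# P = M Q, orthogonalize P, Q = M^T P, and the decompressed gradient is the
# rank-r approximation P Q^T (steps 3.2-3.7 of the docstring).
import torch

torch.manual_seed(0)
n, m, r = 128, 64, 4
M = torch.randn(n, m)

Q, _ = torch.linalg.qr(torch.randn(m, r))  # random orthonormal projection (step 3.1)
P = M @ Q                                  # step 3.2
P, _ = torch.linalg.qr(P)                  # step 3.4: orthogonalize P
Q = M.t() @ P                              # step 3.5
M_hat = P @ Q.t()                          # step 3.7

rel_err = torch.norm(M - M_hat) / torch.norm(M)
# Dense random matrices have high stable rank, so the rank-4 error is large;
# gradients with low stable rank reconstruct much more accurately.
print(f"relative reconstruction error: {rel_err:.3f}")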
+ allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce( + uncompressed_tensors_memory, group=group_to_use, async_op=True + ).get_future() + + def unpack_uncompressed_tensors_and_allreduce_ps(fut): + uncompressed_tensors_memory = fut.value()[0].div_(world_size) + idx = 0 + for tensor in uncompressed_tensors: + tensor.copy_( + uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor) + ) + idx += tensor.numel() + + # Since these Ps will be orthogonalized later, no need to divide them by world size. + return ( + dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def compute_qs(fut): + state.p_memory_dict[bucket_index] = fut.value() + for p in ps: + _orthogonalize(p, state.orthogonalization_epsilon) + + # Compute Qs. + for tensor, p, q in zip(tensors_to_compress, ps, qs): + torch.bmm(tensor.transpose(1, 2), p, out=q) + + # TODO: The above procedure does two matmul+allreduce steps per iteration -- + # one left multiplication and one right multiplication. + # For warm-start, can take one such step at a time, and alternate between them. + + # Allreduce Qs. + return ( + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value().div_(world_size) + + for p, q, tensor in zip(ps, qs, tensors_to_compress): + torch.bmm(p, q.transpose(1, 2), out=tensor) + + # Copy batched tensors back to original buffer. + if state.batch_tensors_with_same_shape: + for tensor in tensors_to_compress: + if tensor.shape[0] == 1: + # Skip tensor with batch_size == 1 since itself is the original tensor. + continue + original_tensors = shape_to_tensors[tensor.shape[1:]] + for i, original_tensor in enumerate(original_tensors): + original_tensor.copy_(tensor[i]) + + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + + if state.use_error_feedback: + # Memorize the local errors. + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + + state.maybe_increase_iter(bucket) + + return input_tensor + + return ( + allreduce_contiguous_uncompressed_tensors_fut.then( + unpack_uncompressed_tensors_and_allreduce_ps + ) + .then(compute_qs) + .then(decompress) + ) + + +def batched_powerSGD_hook( + state: PowerSGDState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + r""" + Implement simplified PowerSGD algorithm. + + This DDP communication hook implements a simplified PowerSGD gradient compression + algorithm described in the `paper `_. + This variant does not compress the gradients layer by layer, + but instead compresses the flattened input tensor that batches all the gradients. + Therefore, it is **faster** than :meth:`powerSGD_hook`, + but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1. + + .. warning :: + Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy, + because batching per-parameter tensors without column/row alignment can destroy low-rank structure. + Therefore, the user should always consider :meth:`powerSGD_hook` first, + and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1. + + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + + 1. 
Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; + + 2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + + 3. Computes P, which is equal to MQ; + + 4. Allreduces P; + + 5. Orthogonalizes P; + + 6. Computes Q, which is approximately equal to M^TP; + + 7. Allreduces Q; + + 8. Computes M, which is approximately equal to PQ^T. + + 9. Truncates the input tensor to the original length. + + Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. + This not only gives the user more control over the tradeoff between speedup and accuracy, + but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) + >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook) + """ # noqa: B950 + process_group = state.process_group + group_to_use = ( + process_group if process_group is not None else not_none(dist.group.WORLD) + ) + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(group_to_use, input_tensor) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. + device = input_tensor.device + total_length = input_tensor.shape[0] + state.total_numel_before_compression += total_length + + # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary. + square_side_length = math.ceil(math.sqrt(total_length)) + state.total_numel_after_compression += ( + square_side_length * state.matrix_approximation_rank * 2 + ) + padded_total_length = square_side_length**2 + input_tensor.resize_(padded_total_length) + input_tensor[total_length:padded_total_length].fill_(0) + + _report_compression_stats(bucket, state) + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.index() + input_tensor_cp = None + if state.use_error_feedback: + if bucket_index in state.error_dict: + input_tensor.add_(state.error_dict[bucket_index]) + else: + logger.info( + "A zero tensor of length %s that represents local error is created.", + padded_total_length, + ) + state.error_dict[bucket_index] = torch.zeros( + padded_total_length, device=device, dtype=input_tensor.dtype + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. 
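# A padding-arithmetic sketch for the batched variant above (the bucket length
# is an assumed example value): the flattened bucket is zero-padded to the next
# perfect square, and with rank 1 only 2 * square_side_length elements are
# exchanged for P and Q.
import math
import torch

total_length = 10_000_000                                  # assumed bucket size
square_side_length = math.ceil(math.sqrt(total_length))    # 3163
padded_total_length = square_side_length ** 2               # 10_004_569
communicated_elements = 2 * square_side_length * 1          # 6_326 for rank 1

flat = torch.randn(total_length)
flat.resize_(padded_total_length)
flat[total_length:padded_total_length].fill_(0)
matrix = flat.view(square_side_length, square_side_length)
print(padded_total_length, communicated_elements, matrix.shape)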
+ input_tensor_cp = torch.clone(input_tensor).detach() + matrix = input_tensor.view(square_side_length, square_side_length) + + # Reuse P and Q from the previous iteration if possible. + # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied. + if not state.warm_start or bucket_index not in state.p_memory_dict: + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logger.info( + "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.", + square_side_length, + state.matrix_approximation_rank, + ) + + def create_low_rank_tensor(fill_random_values, rng): + """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank.""" + if fill_random_values: + with torch.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling + # anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # This seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device. + torch.manual_seed(rng.randint(1_000_000_000)) + return torch.randn( + square_side_length, + state.matrix_approximation_rank, + device="cpu", + dtype=input_tensor.dtype, + ).to(device) + else: + return torch.empty( + square_side_length, + state.matrix_approximation_rank, + device=device, + dtype=input_tensor.dtype, + ) + + state.p_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=False, rng=state.rng + ) + state.q_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=True, rng=state.rng + ) + _orthogonalize(state.q_memory_dict[bucket_index]) + + torch.matmul( + matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index] + ) + allreduce_p_fut = dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ).get_future() + + def compute_q(fut): + state.p_memory_dict[bucket_index] = fut.value()[0] + _orthogonalize(state.p_memory_dict[bucket_index]) + + torch.matmul( + matrix.t(), + state.p_memory_dict[bucket_index], + out=state.q_memory_dict[bucket_index], + ) + + # TODO: The above procedure does two matmul+allreduce steps per iteration -- + # one left multiplication and one right multiplication. + # For warm-start, can take one such step at a time, and alternate between them. + + return ( + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value().div_(world_size) + torch.matmul( + state.p_memory_dict[bucket_index], + state.q_memory_dict[bucket_index].t(), + out=matrix, + ) + + if state.use_error_feedback: + # Memorize the local errors. + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + # Removing this seemingly unnecessary sync somehow may cause failures. 
+ # See: https://github.com/pytorch/pytorch/pull/54838 + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + ret = input_tensor.resize_(total_length) + + state.maybe_increase_iter(bucket) + + return ret + + return allreduce_p_fut.then(compute_q).then(decompress) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..cd88471effa8fc30598e7474bb7a2a2fed1e9a17 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -0,0 +1,218 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch import nn + + +def _quantize_per_tensor_backend(x, scale, zero_point): + y = torch.round(x / scale) + zero_point + y = torch.clamp(y, 0, 255).to(torch.uint8) + return y + + +def _dequantize_per_tensor_backend(y, scale, zero_point): + x = scale * (y.to(torch.float32) - zero_point) + return x + + +def _quantize_per_channel_backend(x, scale, zero_point): + y = torch.zeros(x.size(), device=x.device) + for i in range(x.size()[0]): + y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i] + y = torch.clamp(y, 0, 255).to(torch.uint8) + return y + + +def _dequantize_per_channel_backend(y, scale, zero_point): + y = y.to(torch.float32).to(y.device) + x = torch.zeros_like(y, device=y.device) + for i in range(x.size()[0]): + x[i, :] = scale[i] * (y[i, :] - zero_point[i]) + return x + + +def _get_allgather_out_list(all_gather_in_list, world_size): + out_list = [ + torch.zeros_like( + all_gather_in_list, + device=all_gather_in_list.device, + dtype=all_gather_in_list.dtype, + ) + for _ in range(world_size) + ] + return out_list + + +def quantization_pertensor_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Apply ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` protocol. + + Workers first allgather the scale and zero point of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensor, and uses ``allgather`` to communicate these across all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes and + aggregates each quantized gradient tensor locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = group_to_use.size() + + tensor = bucket.buffer() + + myObserver = torch.ao.quantization.MinMaxObserver().to(tensor.device) + myObserver(tensor) + + s, z = myObserver.calculate_qparams() + s_and_z = torch.FloatTensor([s, z]).to(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + + # First, allgather scale and zeros. 
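# A round-trip sketch of the per-tensor affine mapping above: derive the scale
# and zero point from a ``MinMaxObserver``, quantize, dequantize, and check
# that the error stays within one quantization step.
import torch

x = torch.randn(1024) * 3.0
observer = torch.ao.quantization.MinMaxObserver()
observer(x)
scale, zero_point = observer.calculate_qparams()

q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)
x_hat = scale * (q.to(torch.float32) - zero_point)

assert torch.max(torch.abs(x - x_hat)) <= scale  # worst case is ~ scale / 2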
+ fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros across all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their own ``GradBucket`` tensors. + quantized_tensor = _quantize_per_tensor_backend( + tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1] + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = torch.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_tensor_backend( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return aggregated_dequantized_tensor / world_size + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) + + +def quantization_perchannel_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512 +) -> torch.futures.Future[torch.Tensor]: + """ + Apply``torch.quantize_per_channel`` logic to DDP using ``allgather`` protocol. + + Compared to per-tensor, the main motivation of per-channel is + for considerably large tensors such as a tensor that contains 6 million + elements quantizing per a bucket size of 512 (or 128) elements may significantly + increase the resolution. + + It first splits ``GradBucket`` tensor into multiple chunks (channels) of ``bucket_size`` + elements. Then, workers allgather the scales and zero points of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensor, and uses ``allgather`` to communicate these across all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes, flattens, and + aggregates each quantized gradient tensor locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = group_to_use.size() + + tensor = bucket.buffer() + + tensor_in_channels = ( + nn.functional.pad( + input=tensor, + pad=(0, bucket_size - len(tensor) % bucket_size), + mode="constant", + value=0, + ) + .view(-1, bucket_size) + .to(tensor.device) + ) + + myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().to( + tensor.device + ) + myPerChannelObserver(tensor_in_channels) + + s_ch, z_ch = myPerChannelObserver.calculate_qparams() + s_and_z = torch.stack((s_ch, z_ch)).to(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + # First, allgather scale and zeros. 
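# A chunking sketch for the per-channel variant above (the bucket length is an
# assumed example value): the flat bucket is padded to a multiple of
# ``bucket_size`` and viewed as one row per chunk, so every 512-element chunk
# gets its own scale and zero point.
import torch
from torch import nn

bucket_size = 512
flat = torch.randn(10_000)                          # assumed flattened bucket
pad = bucket_size - len(flat) % bucket_size          # 240 zero elements appended
channels = nn.functional.pad(flat, (0, pad)).view(-1, bucket_size)  # (20, 512)

observer = torch.ao.quantization.PerChannelMinMaxObserver()
observer(channels)
scales, zero_points = observer.calculate_qparams()
print(channels.shape, scales.shape)                  # 20 chunks -> 20 scales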
+ fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros across all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their corresponding ``GradBucket`` tensors. + quantized_tensor = _quantize_per_channel_backend( + tensor_in_channels, + all_ranks_s_and_z[rank, 0, :], + all_ranks_s_and_z[rank, 1, :], + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = torch.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_channel_backend( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return ( + torch.flatten(aggregated_dequantized_tensor).to(tensor.device)[ + : tensor.size()[0] + ] + / world_size + ) + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/join.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/join.py new file mode 100644 index 0000000000000000000000000000000000000000..b04349d06efa625f6362158243f16746e3b482b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/join.py @@ -0,0 +1,348 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, NamedTuple, Optional + +import torch +import torch.distributed as dist + + +__all__ = ["JoinHook", "Joinable", "Join"] + + +class JoinHook: + r""" + This defines a join hook, which provides two entry points in the join context manager. + + Entry points : a main hook, which is called repeatedly while there exists a non-joined + process, and a post-hook, which is called once all processes have joined. + + To implement a join hook for the generic join context manager, define a + class that inherits from :class:`JoinHook` and override ``main_hook()`` and + ``post_hook()`` as appropriate. + """ + + def main_hook(self) -> None: + r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration. + + Training iteration i.e., in one forward pass, backward pass, and optimizer step. + """ + + def post_hook(self, is_last_joiner: bool) -> None: + r""" + Call hook after all processes have joined. + + It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join. + + Arguments: + is_last_joiner (bool): ``True`` if the rank is one of the last to + join; ``False`` otherwise. + """ + + +class Joinable(ABC): + r""" + This defines an abstract base class for joinable classes. + + A joinable class + (inheriting from :class:`Joinable`) should implement :meth:`join_hook`, + which returns a :class:`JoinHook` instance, in addition to + :meth:`join_device` and :meth:`join_process_group` that return device and + process group information, respectively. 
+ """ + + @abstractmethod + def __init__(self) -> None: + super().__init__() + self._join_config = _JoinConfig.construct_disabled_join_config() + + @abstractmethod + def join_hook(self, **kwargs) -> JoinHook: + r""" + Return a :class:`JoinHook` instance for the given :class:`Joinable`. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + """ + ... + + @property + @abstractmethod + def join_device(self) -> torch.device: + r"""Return the device from which to perform collective communications needed by the join context manager.""" + ... + + @property + @abstractmethod + def join_process_group(self) -> Any: + r"""Returns the process group for the collective communications needed by the join context manager itself.""" + ... + + +class _JoinConfig(NamedTuple): + r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side.""" + + enable: bool + throw_on_early_termination: bool + is_first_joinable: bool + + @staticmethod + def construct_disabled_join_config(): + r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled. + + e.g. if the caller is not in a join context manager. + """ + return _JoinConfig( + enable=False, throw_on_early_termination=False, is_first_joinable=False + ) + + +class Join: + r""" + This class defines the generic join context manager, which allows custom hooks to be called after a process joins. + + These hooks should shadow the + collective communications of non-joined processes to prevent hanging and + erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook` + for details about the hook definition. + + .. warning:: + The context manager requires each participating :class:`Joinable` to + call the method :meth:`notify_join_context()` before its own per- + iteration collective communications to ensure correctness. + + .. warning:: + The context manager requires that all ``process_group`` attributes in + the :class:`JoinHook` objects are the same. If there are multiple + :class:`JoinHook` objects, then the ``device`` of the first is used. + The process group and device information is used for checking for non- + joined processes and for notifying processes to throw an exception if + ``throw_on_early_termination`` is enabled, both of which using an all- + reduce. + + Arguments: + joinables (List[Joinable]): a list of the participating + :class:`Joinable` s; their hooks are iterated over in the given + order. + + enable (bool): a flag enabling uneven input detection; setting to + ``False`` disables the context manager's functionality and should + only be set when the user knows the inputs will not be uneven + (default: ``True``). + + throw_on_early_termination (bool): a flag controlling whether to throw an + exception upon detecting uneven inputs (default: ``False``). 
+ + Example:: + + >>> import os + >>> import torch + >>> import torch.distributed as dist + >>> import torch.multiprocessing as mp + >>> # xdoctest: +SKIP + >>> import torch.nn.parallel.DistributedDataParallel as DDP + >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO + >>> from torch.distributed.algorithms.join import Join + >>> + >>> # On each spawned worker + >>> def worker(rank): + >>> dist.init_process_group("nccl", rank=rank, world_size=2) + >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) + >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) + >>> # Rank 1 gets one more input than rank 0 + >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] + >>> with Join([model, optim]): + >>> for input in inputs: + >>> loss = model(input).sum() + >>> loss.backward() + >>> optim.step() + >>> # All ranks reach here without hanging/erroring + """ + + def __init__( + self, + joinables: list[Joinable], + enable: bool = True, + throw_on_early_termination: bool = False, + **kwargs, + ): + if len(joinables) == 0: + raise ValueError("The join context manager requires at least one joinable") + self._joinables = joinables + self._join_hooks = [ + joinable.join_hook(**kwargs) for joinable in self._joinables + ] + self._enable = enable + self._throw_on_early_termination = throw_on_early_termination + self._set_joinable_configs() + self._extract_dist_info() + + def _set_joinable_configs(self) -> None: + r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`.""" + assert len(self._joinables) > 0 + is_first_joinable = True + for joinable in self._joinables: + joinable._join_config = _JoinConfig( + enable=self._enable, + throw_on_early_termination=self._throw_on_early_termination, + is_first_joinable=is_first_joinable, + ) + is_first_joinable = False + + def _extract_dist_info(self) -> None: + r""" + Extract the process group and device information from the joinables. + + If there are multiple joinables, then the context manager uses the + first specified device. + + Preconditions: + ``self._joinables`` is not ``None`` and is non-empty. + + Raises: + ValueError + If there are multiple conflicting ``process_group`` attributes + among the ``Joinable`` objects. + """ + process_group = None + device = None + for joinable in self._joinables: + if process_group is None: + process_group = joinable.join_process_group + elif process_group != joinable.join_process_group: + raise ValueError( + "Using join context manager with multiple process groups" + ) + if device is None: + device = joinable.join_device + self._process_group = process_group + self._rank = dist.get_rank(self._process_group) + self._device = device + + def __enter__(self): ... + + def __exit__( + self, + type: Optional[type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ): + r""" + Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. + + Raises: + RuntimeError + If ``throw_on_early_termination=True``. + """ + if not self._enable or type: + return # propagate the exception directly if one was raised + + all_procs_joined = False + is_last_joiner = True + + i = 0 + WARN_THRESHOLD = 1000 + warnings.simplefilter("once") + + while not all_procs_joined: + if i > WARN_THRESHOLD: + warnings.warn( + "Detected uneven input skew of greater than " + f"{WARN_THRESHOLD}. This means that rank " + f"{self._rank} has at least {WARN_THRESHOLD} " + f"fewer inputs than other currently-active ranks. 
" + "This level of skew could lead to performance " + "degradation during training." + ) + # Shadow the all-reduce in non-joined processes + num_nonjoined_procs = self._get_num_nonjoined_procs() + if num_nonjoined_procs == 0: + all_procs_joined = True + else: + if self._throw_on_early_termination: + self._notify_procs_to_terminate() + + # Run main hooks + for join_hook in self._join_hooks: + join_hook.main_hook() + + is_last_joiner = False + i += 1 + + # Run post-hooks + for join_hook in self._join_hooks: + join_hook.post_hook(is_last_joiner) + + def _get_num_nonjoined_procs(self): + r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes.""" + num_nonjoined_procs = torch.zeros(1, device=self._device) + dist.all_reduce(num_nonjoined_procs, group=self._process_group) + return num_nonjoined_procs.item() + + def _notify_procs_to_terminate(self): + r"""Schedule an all-reduce to notify non-joined processes to terminate. + + Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs. + """ + ones = torch.ones(1, device=self._device) + dist.all_reduce(ones, group=self._process_group) + raise RuntimeError(f"Rank {self._rank} exhausted all inputs.") + + @staticmethod + def notify_join_context(joinable: Joinable): + r""" + Notifies the join context manager that the calling process has not yet joined. + + Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected + (i.e. if one process has already joined) and throws an exception if so. + + This method should be called from a :class:`Joinable` object before + its per-iteration collective communications. For example, this should + be called at the beginning of the forward pass in + :class:`DistributedDataParallel`. + + Only the first :class:`Joinable` object passed into the context + manager performs the collective communications in this method, and + for the others, this method is vacuous. + + Arguments: + joinable (Joinable): the :class:`Joinable` object calling this + method. + + Returns: + An async work handle for the all-reduce meant to notify the context + manager that the process has not yet joined if ``joinable`` is the + first one passed into the context manager; ``None`` otherwise. + """ + assert hasattr(joinable, "_join_config"), ( + f"Check that the {type(joinable)} constructor calls the " + "``Joinable`` constructor" + ) + + join_config = joinable._join_config + # First joinable is responsible for the collective communications + if not join_config.is_first_joinable or not join_config.enable: + return None + + device = joinable.join_device + process_group = joinable.join_process_group + + # Schedule an all-reduce to indicate that the caller has not yet joined + ones = torch.ones(1, device=device) + work = dist.all_reduce(ones, group=process_group, async_op=True) + + if join_config.throw_on_early_termination: + # Check if uneven inputs have been detected + zeros = torch.zeros(1, device=device) + dist.all_reduce(zeros, group=process_group) + should_throw = zeros.item() + if should_throw: + raise RuntimeError( + "Detected at least one rank that exhausted inputs. " + "Throwing across all ranks." 
+ ) + return work diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__init__.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95fd33734ae5d791aafcc8e440e0b32564711102 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe26c25a477c9987a8417767c75abb1d3c849e0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..799cba32edc87a935e95c883306fe42410a768b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc6938e57240452af0368115c0cf73ffdd0770d2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/averagers.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/averagers.py new file mode 100644 index 0000000000000000000000000000000000000000..9841ac592f3ef185dcdce17963970b1577fbe242 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/averagers.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.algorithms.model_averaging.utils as utils +from torch.utils._typing_utils import not_none as _not_none + + +__all__ = ["ModelAverager", "PeriodicModelAverager"] + + +class ModelAverager(ABC): + r"""Base class for all model averagers. + + Args: + process_group: The process group to be used for all-reduce. + If ``None``, the default process group, which + is created by :func:`torch.distributed.init_process_group`, + will be used. 
(default: ``None``) + """ + + def __init__(self, process_group: Optional[dist.ProcessGroup] = None): + self.process_group = ( + process_group if process_group is not None else _not_none(dist.group.WORLD) + ) + self.step = 0 + + @abstractmethod + def average_parameters(self, params): + raise NotImplementedError + + +class PeriodicModelAverager(ModelAverager): + r""" + Averages parameters periodically after the warm-up stage. + + This can be used for running `post-local SGD `_, + by running :class:`~torch.nn.DistributedDataParallel` (DDP) + using the subgroups created by :meth:`~torch.distributed.new_subgroups`. + + Args: + period (int): The number of steps per model averaging. + Usually the period should be greater than ``1`` to reduce the communication cost. + Otherwise, only DDP needs to be used. + warmup_steps (int): The number of warm-up steps. During this stage, + model averaging is skipped. + process_group: The process group to be used for all-reduce. + If ``None``, the default process group, which + is created by :func:`torch.distributed.init_process_group`, + will be used. (default: ``None``) + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> import torch.distributed as dist + >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + >>> import torch.distributed.algorithms.model_averaging.averagers as averagers + >>> import torch.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> torch.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).cuda() + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging every 4 steps. + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Will average model parameters globally every 4 steps. Thus, + >>> # inter-node communication only occurs every 4 iterations after + >>> # the initial ``warmup_steps`` period. + >>> averager.average_parameters(model.parameters()) + """ + + def __init__( + self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None + ): + super().__init__(process_group) + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + if period < 1: + raise ValueError("Arg ``period`` must be a positive value.") + elif period == 1: + warnings.warn( + "When period is 1, no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case." 
+ ) + self.period = period + + def average_parameters( + self, + params: Union[ + Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] + ], + ): + """ + Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``. + + Can be divided by ``period``, where ``step`` is increased by 1 + at each iteration in the training loop. + Args: + params: The parameters of a model or parameter groups of an optimizer. + + """ + if ( + self.step >= self.warmup_steps + and (self.step - self.warmup_steps) % self.period == 0 + ): + utils.average_parameters_or_parameter_groups( + params, _not_none(self.process_group) + ) + self.step += 1 diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py new file mode 100644 index 0000000000000000000000000000000000000000..e9fcdb8eaf70a9e7d839234fde46ee27efb0417a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +# Copyright 2022 Cruise LLC +import logging +import warnings +from collections import OrderedDict +from collections.abc import Iterable +from typing import Union + +import torch +import torch.distributed as dist +import torch.distributed.algorithms.model_averaging.averagers as averagers +import torch.distributed.algorithms.model_averaging.utils as utils + + +logger = logging.getLogger(__name__) + + +class HierarchicalModelAverager(averagers.ModelAverager): + r""" + Runs hierarchical model averaging (`hierarchical SGD `_). + + Process groups of different sizes are organized in a hierarchy, and they average parameters + by using different periods concurrently after the warm-up stage. + This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager` + that supports `post-local SGD `_, which essentially only supports + a two-level hierarchy: the intra-machine level and the global level, where the intra-machine + level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`. + Similarly, the process groups within this class do not have such an intra-machine process + subgroup, which should be embedded by the post-local SGD communication hook instead. + + Args: + period_group_size_dict: An ordered dict mapping keys of model averaging period to + process group size, used for initializing process groups of + different sizes in a hierarchy to average parameters concurrently. + Particularly, at each iteration, there will be at most a single + process group that runs averaging -- the period of such group should + have the largest period which the current step can be divided by. + For example, if the dict has three keys: 2, 4, and 8, + then this means totally three process groups will be created to + average parameters every 2, 4, and 8 iterations, respectively. + At the 4th iteration, only the second process group will run + averaging, because the first process group should be a + subset of the second process group, and no need to execute the first + process group redundantly. + On the other hand, the third process group can only be triggered + every 8 iterations, so it will not be triggered at the 4th iteration. + warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped. 
+ process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging. + If ``None``, the default process group, which is created + by :func:`torch.distributed.init_process_group`, will be used. + (default: ``None``) + + Example:: + >>> # xdoctest: +SKIP('undefined rank') + >>> from collections import OrderedDict + >>> import torch + >>> import torch.distributed as dist + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD + >>> import torch.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> torch.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).to(rank) + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4. + >>> subgroup, _ = dist.new_subgroups() + >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Average parameters among each group of 8 processes every 4 iterations, and among all + >>> # the 16 processes every 16 iterations. + >>> averager = hierarchicalSGD.HierarchicalModelAverager( + >>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100) + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging at two levels. + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Average parameters after ``optimizer.step()``. + >>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``. + >>> averager.average_parameters(model.parameters()) + + .. warning :: + The last group size in the dict must be the size of the provided ``process_group``, + which indicates model averaging at the highest level of the hierarchy. + If ``process_group`` is not provided, then the last group size should be equal to the world size. + + .. warning :: + `HierarchicalModelAverager` is experimental and subject to change. + """ + + def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None): + super().__init__(process_group) + if not period_group_size_dict: + raise ValueError("Arg ``period_group_size_dict`` must not be empty.") + self._periods = list(period_group_size_dict.keys()) + if self._periods[0] <= 0: + raise ValueError( + "The minimum period in arg ``period_group_size_dict`` must be a positive value." + ) + elif self._periods[-1] == 1: + warnings.warn( + "When the maximum period in arg ``period_group_size_dict`` is 1, " + "no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case." 
+ ) + overall_group_size = dist.get_world_size(group=self.process_group) + if list(period_group_size_dict.values())[-1] != overall_group_size: + raise ValueError( + f"The last value in arg ``period_process_group_dict`` {list(period_group_size_dict.values())[-1]} " + f"must be equal to the size of arg ``process_group`` {overall_group_size}." + ) + + self.period_process_group_dict = OrderedDict() + logger.info("Model averaging hierarchy:") + for period, group_size in period_group_size_dict.items(): + logger.info( + "\tEach group that has %s processes average parameters every %s iterations, " + "if no higher-level averaging.", + group_size, + period, + ) + if group_size != overall_group_size: + self.period_process_group_dict[period], _ = dist.new_subgroups( + group_size=group_size, group=self.process_group + ) + else: + self.period_process_group_dict[period] = self.process_group + + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + + def _find_process_group(self): + """ + Return a process group as the value of an ``period_process_group_dict`` entry. + + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + then the returned process group is the one corresponding to the largest period, + since this process group will be used for averaging parameters at this ``step``. + Returns ``None`` if not found. + """ + for period in reversed(self._periods): + if self.step % period == 0: + return self.period_process_group_dict[period] + return None + + def average_parameters( + self, + params: Union[ + Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] + ], + ): + """ + Averages parameters or parameter groups of an optimizer. + + Averaging only occurs if ``step`` is no less than ``warmup_steps`` + and it can be divided by a period in the keys of ``period_process_group_dict``, + where ``step`` is increased by 1 at each iteration in the training loop. + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + only the largest period is used, and the corresponding process group is used for averaging parameters. + Args: + params: The parameters of a model or parameter groups of an optimizer. + """ + if self.step >= self.warmup_steps: + group = self._find_process_group() + if group is not None: + utils.average_parameters_or_parameter_groups(params, group) + self.step += 1 diff --git a/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/utils.py b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f955bfa9f535654ce6f855c6248e556f4cd33b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/algorithms/model_averaging/utils.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable, Iterator +from typing import Union + +import torch +import torch.distributed as dist + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from torch.distributed import group, ProcessGroup + + +__all__ = [ + "average_parameters", + "get_params_to_average", + "average_parameters_or_parameter_groups", +] + + +def average_parameters( + params: Iterator[torch.nn.Parameter], process_group: ProcessGroup +): + """ + Averages all the given parameters. 
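# --- Illustrative sketch (editor addition, not part of the diff above) ---
# The period-selection rule used by HierarchicalModelAverager._find_process_group
# above reduces to "the largest period that divides the current step wins".
# The period values below are made up for illustration.
periods = [2, 4, 8]  # ascending, like the keys of period_group_size_dict

def find_period(step):
    # Largest dividing period wins; None means no averaging at this step.
    for period in reversed(periods):
        if step % period == 0:
            return period
    return None

assert find_period(4) == 4 and find_period(8) == 8 and find_period(3) is None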
+ + For allreduce efficiency, all the parameters are flattened into a contiguous buffer. + Thus, it requires extra memory of the same size as the given parameters. + """ + group_to_use = process_group if process_group is not None else group.WORLD + # Do not update any parameter if not in the process group. + if dist._rank_not_in_group(group_to_use): + return + + params_it1, params_it2 = itertools.tee(params) + # If the input parameters have different data types, + # packing these parameters will trigger an implicit type up-casting. + # The original parameter data types will be restored during the subsequent unpacking. + flat_params = torch.cat([p.data.reshape(-1) for p in params_it1]) + flat_params /= dist.get_world_size(group_to_use) + # Make sure the allreduce will not conflict with any other ongoing process group. + if torch.accelerator.is_available(): + torch.accelerator.synchronize() + dist.all_reduce(flat_params, group=group_to_use) + + offset = 0 + for p in params_it2: + p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p) + offset += p.numel() + + +def get_params_to_average( + params: Union[ + Iterable[torch.nn.Parameter], + Iterable[dict[str, torch.nn.Parameter]], + ], +): + """ + Return a list of parameters that need to average. + + This filters out the parameters that do not contain any gradients. + Args: + params: The parameters of a model or parameter groups of an optimizer. + """ + filtered_params = [] + for param in params: + if isinstance(param, torch.nn.Parameter): + # model.parameters() input + param_data = param + if param_data.grad is not None: + filtered_params.append(param_data) + elif isinstance(param, dict): + # optimizer.param_groups input + for param_data in param["params"]: + if param_data.grad is not None: + filtered_params.append(param_data) + else: + raise NotImplementedError( + f"Parameter input of type {type(param)} is not supported" + ) + return filtered_params + + +def average_parameters_or_parameter_groups( + params: Union[ + Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] + ], + process_group: ProcessGroup, +): + """Averages parameters of a model or parameter groups of an optimizer.""" + average_parameters(iter(get_params_to_average(params)), process_group) diff --git a/phivenv/Lib/site-packages/torch/distributed/argparse_util.py b/phivenv/Lib/site-packages/torch/distributed/argparse_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4030479d301537998f30a7ca67a8dc69ccd2e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/argparse_util.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +from argparse import Action + + +class env(Action): + """ + Get argument values from ``PET_{dest}`` before defaulting to the given ``default`` value. + + For flags (e.g. ``--standalone``) + use ``check_env`` instead. + + .. note:: when multiple option strings are specified, ``dest`` is + the longest option string (e.g. 
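# --- Illustrative sketch (editor addition, not part of the diff above) ---
# The flatten -> average -> unpack bookkeeping used by average_parameters in
# utils.py, shown on a local list of parameters. In the real code the flattened
# buffer is also divided by the world size and all-reduced across the group.
import torch

params = [torch.nn.Parameter(torch.randn(3, 3)), torch.nn.Parameter(torch.randn(5))]
flat = torch.cat([p.data.reshape(-1) for p in params])
# dist.all_reduce(flat, group=...) would run here in the distributed case.
offset = 0
for p in params:
    p.data = flat[offset : offset + p.numel()].view_as(p).type_as(p)
    offset += p.numel()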
for ``"-f", "--foo"`` + the env var to set is ``PET_FOO`` not ``PET_F``) + + Example: + :: + + parser.add_argument("-f", "--foo", action=env, default="bar") + + ./program -> args.foo="bar" + ./program -f baz -> args.foo="baz" + ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + + parser.add_argument("-f", "--foo", action=env, required=True) + + ./program -> fails + ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + """ + + def __init__(self, dest, default=None, required=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = os.environ.get(env_name, default) + + # ``required`` means that it NEEDS to be present in the command-line args + # rather than "this option requires a value (either set explicitly or default" + # so if we found default then we don't "require" it to be in the command-line + # so set it to False + if default: + required = False + + super().__init__(dest=dest, default=default, required=required, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class check_env(Action): + """ + Check whether the env var ``PET_{dest}`` exists before defaulting to the given ``default`` value. + + Equivalent to + ``store_true`` argparse built-in action except that the argument can + be omitted from the commandline if the env var is present and has a + non-zero value. + + .. note:: it is redundant to pass ``default=True`` for arguments + that use this action because a flag should be ``True`` + when present and ``False`` otherwise. 
+ + Example: + :: + + parser.add_argument("--verbose", action=check_env) + + ./program -> args.verbose=False + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + PET_VERBOSE=0 ./program --verbose -> args.verbose=True + + Anti-pattern (don't do this): + + :: + + parser.add_argument("--verbose", action=check_env, default=True) + + ./program -> args.verbose=True + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + + """ + + def __init__(self, dest, default=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = bool(int(os.environ.get(env_name, "1" if default else "0"))) + super().__init__(dest=dest, const=True, default=default, nargs=0, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) diff --git a/phivenv/Lib/site-packages/torch/distributed/autograd/__init__.py b/phivenv/Lib/site-packages/torch/distributed/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb5109662b42a15720c013fa9ed178b939da433 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/autograd/__init__.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs + +import torch + + +def is_available(): + return hasattr(torch._C, "_dist_autograd_init") + + +if is_available() and not torch._C._dist_autograd_init(): + raise RuntimeError("Failed to initialize torch.distributed.autograd") + +if is_available(): + from torch._C._distributed_autograd import ( + _current_context, + _get_debug_info, + _get_max_id, + _init, + _is_valid_context, + _new_context, + _release_context, + _retrieve_context, + backward, + DistAutogradContext, + get_gradients, + ) + + +class context: + """ + Context object to wrap forward and backward passes when using + distributed autograd. The ``context_id`` generated in the ``with`` + statement is required to uniquely identify a distributed backward pass + on all workers. Each worker stores metadata associated with this + ``context_id``, which is required to correctly execute a distributed + autograd pass. 
+ + Example:: + >>> # xdoctest: +SKIP + >>> import torch.distributed.autograd as dist_autograd + >>> with dist_autograd.context() as context_id: + >>> t1 = torch.rand((3, 3), requires_grad=True) + >>> t2 = torch.rand((3, 3), requires_grad=True) + >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum() + >>> dist_autograd.backward(context_id, [loss]) + """ + + def __enter__(self): + self.autograd_context = _new_context() + return self.autograd_context._context_id() + + def __exit__(self, type, value, traceback): + _release_context(self.autograd_context._context_id()) diff --git a/phivenv/Lib/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b033a1ead6a663b6e616ecb661dbd842d5c29e53 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/c10d_logger.py b/phivenv/Lib/site-packages/torch/distributed/c10d_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d75031abd957c203e5610e04995af028e114aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/c10d_logger.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import logging +from typing import Any, Callable, TypeVar +from typing_extensions import ParamSpec + +import torch +import torch.distributed as dist +from torch.distributed.logging_handlers import _log_handlers +from torch.monitor import _WaitCounter + + +__all__: list[str] = [] + +_DEFAULT_DESTINATION = "default" + + +def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler(destination) + logger = logging.getLogger(f"c10d-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = f"{type(log_handler).__name__}-{destination}" + return (log_handler, log_handler_name) + + +global _c10d_logger +_c10d_logger = _get_or_create_logger() + + +def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]: + if dist.is_initialized(): + group = kwargs.get("group") or kwargs.get("process_group") + msg_dict = { + "func_name": f"{func_name}", + "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}", # type: ignore[arg-type] + "backend": f"{dist.get_backend(group)}", + "world_size": f"{dist.get_world_size()}", + "group_size": f"{dist.get_world_size(group)}", + "global_rank": f"{dist.get_rank()}", + "local_rank": f"{dist.get_rank(group)}", + } + if msg_dict["backend"] == "nccl": + nccl_version = torch.cuda.nccl.version() + msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version) + else: + msg_dict = { + "func_name": f"{func_name}", + } + return msg_dict + + +_T = 
TypeVar("_T") +_P = ParamSpec("_P") + + +def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + try: + return func(*args, **kwargs) + except Exception as error: + msg_dict = _get_msg_dict(func.__name__, *args, **kwargs) + msg_dict["error"] = f"{error}" + _c10d_logger.debug(msg_dict) + raise + + return wrapper + + +def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + with _WaitCounter(f"pytorch.wait_counter.c10d.{func.__name__}").guard(): + func_return = func(*args, **kwargs) + return func_return + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__init__.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..935fb5ecfed894dcd0a849a5592e23dafd7903c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__init__.py @@ -0,0 +1,16 @@ +from . import _extension +from .api import CheckpointException +from .default_planner import DefaultLoadPlanner, DefaultSavePlanner +from .filesystem import FileSystemReader, FileSystemWriter +from .hf_storage import HuggingFaceStorageReader, HuggingFaceStorageWriter +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + TensorStorageMetadata, +) +from .optimizer import load_sharded_optimizer_state_dict +from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem +from .state_dict_loader import load, load_state_dict +from .state_dict_saver import async_save, save, save_state_dict +from .storage import StorageReader, StorageWriter diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce0c593eaa2143b3613f5c8a72941b3e9fc4b7cd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2d32ff07061f2e9666bcc3eb686c789c1703fcd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04d98fa0682e8e20ebc90b6e5d8dadb0fe32ba9b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5c58a61d17cdf78736539a6185340a7c7d874cf Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff86e98109b95da1ef8fefb31093869e3a486972 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55b41c840eb6e58faacf547e2fa8eeb0169ed9b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df115d09c0bb76e64095de854dc36620c6cbcff Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c0942ad4fd4481335daac8c6a1d63f1f0e38a4e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05d7ed190dc7a8e0481bc6e85804e5a0095d3e4b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b58bb9e8750d4ca5fd4d59a9e0903d393d53dac6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f63a300425e9bf9ec34e17bc52037877c2c3ad43 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7a94264c0b768eeb1ed8cb697561f1dcc35f9c29 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd92503ca9f7395194d8f5a4ffeaf74d5432951c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac68ee1b4921da526d0babe8c4f1a21d032c93f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb1b31fbbc3ba6b278319c6d44b5c158e55f4b40 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_version.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b1914abfea1177ceb9e6bac6ffffe89f1a2a2af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0283aac445a0559d3e03a003d05e8fd52c03b5b2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df2d6ab796457651d9219797f569782b39a9462e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ca5fbc5af3890219ae8c54aa6f9ffb683b3e69 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..88bc58b893b4b1623e9de384927069ccc14b96df Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7504290267c6cb3719f85ba51c0f051e755c9fc0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd68655b7587b9e63a9e535788e3af724c989e8d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c2df1843b76adeb908d6818d42fd5cb1eaecbaa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20b862b2ee01257cd3e742e2d418489eef1c4a79 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d5e0b71bb2f433874f58c0b31214b45859cad98 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2de446ec37a01ad06527dd6a508faae1217a7ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee27553e9b179bf008a18cf93aeae1ac38a25946 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c4a9cd393f6f6f76053dc821f18838e7a1679e6c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835ec5c6c01a8222907b534edf3c828a8f089eb7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..083e677f96b5333b385beb59f93bba39b6881925 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1999069f005b161632af9de458fd1e8dade1ad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..376e1d11784c7e0f52392ff53e858b452a108e35 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdde54173834a55fc1a97f2764c9691c07d22590 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d234428b5fcf2e227519ed2925406cabd98a247 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6311b5ebc1125b204b46339b4eba1486b926e318 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_executor.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_executor.py new file mode 100644 index 
0000000000000000000000000000000000000000..b30395cea8db8adfcf776a415690b59e393b8077 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_executor.py @@ -0,0 +1,32 @@ +# pyre-strict +# mypy: allow-untyped-defs +import abc +import os +from concurrent.futures import Future +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +class _AsyncCheckpointExecutor(abc.ABC): + @abc.abstractmethod + def execute_save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> Future: + """ + Execute the checkpoint save request asynchronously. + + This method is intended to be used as an abstraction for + implementing async checkpointing. The actual checkpoint save + operation is executed in a separate thread or process depending + on the implementation of this interface. + """ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_process_executor.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_process_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..10432caa495075561605348ff780b72d6bdbc87a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_process_executor.py @@ -0,0 +1,330 @@ +# pyre-strict +# mypy: allow-untyped-defs +import logging +import os +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Union +from uuid import uuid4 + +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter +from torch.distributed.checkpoint.utils import _DistWrapper +from torch.distributed.elastic.agent.server.api import _get_fq_hostname +from torch.distributed.elastic.utils.distributed import get_free_port + + +logger = logging.getLogger() + + +class _CheckpointSaveProcessControlOpts(Enum): + INIT_COMPLETE = "init_complete" + TERMINATE = "terminate" + + +@dataclass(init=False, unsafe_hash=True) +class _CheckpointRequestIdentifier: + checkpoint_id: Union[str, os.PathLike, None] + uuid: str + + def __init__(self, checkpoint_id: Union[str, os.PathLike, None]): + self.checkpoint_id = checkpoint_id + self.uuid = str(uuid4()) + + +@dataclass +class _AsyncCheckpointRequest: + staged_state_dict: STATE_DICT_TYPE + checkpoint_request_id: _CheckpointRequestIdentifier + storage_writer: Optional[StorageWriter] = None + planner: Optional[SavePlanner] = None + + +@dataclass(init=False) +class _ProcessGroupInitInfo: + local_rank: int + global_rank: int + world_size: int + tcp_store_master_addr: str + tcp_store_master_port: int + + def __init__(self, process_group: Optional[dist.ProcessGroup] = None): + self.local_rank = dist.get_node_local_rank(fallback_rank=0) + self.global_rank = dist.get_rank(process_group) + self.world_size = 
dist.get_world_size(process_group) + + # Let coordinator rank find a free port on the localhost. + # Broadcast the (master_addr, free_port) to all ranks; each rank in the + # checkpoint daemon process will use TCPStore (master_addr, master_port) + # for collective communication. + dist_wrapper: _DistWrapper = _DistWrapper( + group=process_group, + use_dist=True, + coordinator_rank=0, + ) + + def get_master_addr_and_port() -> tuple[str, int]: + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + master_addr = _get_fq_hostname() + return master_addr, get_free_port() + + self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast( + step="get_master_addr_and_port", + map_fun=get_master_addr_and_port, + ) + + +class _AsyncCheckpointProcess: + def __init__( + self, + pg_init_info: _ProcessGroupInitInfo, + ): + self.ctx = mp.get_context("spawn") + self._process_pipe, child_end = self.ctx.Pipe() + + self._save_process = self.ctx.Process( + target=self._checkpointing_subprocess, + args=( + pg_init_info, + child_end, + ), + daemon=True, + ) + + self._save_process.start() + + # Close the parent's copy of child end after we pass it into the child, + # so the recv()s on it will fail-fast if the child process dies. + child_end.close() + + # Wait for the checkpoint background process to initialize. + # Using default GLOO init timeout. + response = self._wait_for_response(timeout=1800) + assert response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE + + def __del__(self) -> None: + if self._save_process.is_alive(): + logger.info("Terminating the checkpoint background process...") + self._send(_CheckpointSaveProcessControlOpts.TERMINATE) + self._save_process.join() + + def _send(self, data: Any) -> None: + self._process_pipe.send(data) + + def _wait_for_response(self, timeout: Optional[float] = None) -> Any: + if not self._save_process.is_alive(): + logger.info("Checkpoint background process is dead calling join()...") + self._save_process.join() + raise RuntimeError( + f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}" + ) + + if timeout is not None and not self._process_pipe.poll(timeout=timeout): + raise RuntimeError( + f"Timed out after {timeout}s while waiting for response from checkpointer process pid: {self._save_process.pid}" + ) + + try: + response = self._process_pipe.recv() + except EOFError: + raise RuntimeError( # noqa: B904 + f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}" + ) + + if isinstance(response, BaseException): + raise response + + return response + + def save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + ) -> Metadata: + # Create a unique identifier to locate requests/responses + # from the checkpoint daemon process. 
+ checkpoint_request_id = _CheckpointRequestIdentifier(checkpoint_id) + async_cp_request = _AsyncCheckpointRequest( + staged_state_dict=staged_state_dict, + checkpoint_request_id=checkpoint_request_id, + storage_writer=storage_writer, + planner=planner, + ) + self._send(async_cp_request) + result = self._wait_for_response() + assert isinstance(result, Metadata) + return result + + @staticmethod + def _execute_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_request_id: _CheckpointRequestIdentifier, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + ) -> Metadata: + from torch.distributed.checkpoint.state_dict_saver import save + + metadata = save( + state_dict, + checkpoint_id=checkpoint_request_id.checkpoint_id, + storage_writer=storage_writer, + planner=planner, + ) + return metadata + + @staticmethod + def _checkpointing_subprocess( + pg_init_info: _ProcessGroupInitInfo, + parent_conn, + ) -> None: + try: + _init_logger(pg_init_info.global_rank) + + # Setup environment variables for process group initialization. + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + os.environ["MASTER_ADDR"] = pg_init_info.tcp_store_master_addr + os.environ["MASTER_PORT"] = str(pg_init_info.tcp_store_master_port) + os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank) + os.environ["RANK"] = str(pg_init_info.global_rank) + os.environ["WORLD_SIZE"] = str(pg_init_info.world_size) + + logger.info( + "Initializing dist.ProcessGroup in checkpoint background process" + ) + # NOTE: GLOO backend is enforced here. + dist.init_process_group(backend=dist.Backend.GLOO) + dist.barrier() + + logger.info("Checkpoint background process is running...") + parent_conn.send(_CheckpointSaveProcessControlOpts.INIT_COMPLETE) + + # Serving loop. 
+ while True: + logger.info("Waiting for checkpoint save request...") + obj = parent_conn.recv() + if ( + isinstance(obj, _CheckpointSaveProcessControlOpts) + and obj == _CheckpointSaveProcessControlOpts.TERMINATE + ): + logger.info("Terminating the checkpoint background process.") + return + assert isinstance(obj, _AsyncCheckpointRequest) + logger.info( + f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}" # noqa: G004 + ) + + response = _AsyncCheckpointProcess._execute_save( + obj.staged_state_dict, + checkpoint_request_id=obj.checkpoint_request_id, + storage_writer=obj.storage_writer, + planner=obj.planner, + ) + parent_conn.send(response) + logger.info( + f"Submitted checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004 + ) + except BaseException as e: + logger.error( + f"Checkpoint background process encountered an exception: {e}" # noqa: G004 + ) + parent_conn.send(e) + raise + finally: + logger.info("Checkpoint background process is shutting down...") + dist.destroy_process_group() + parent_conn.close() + + +_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None + + +class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1) + + @staticmethod + def _execute_save_impl( + *, + pg_init_info: Optional[_ProcessGroupInitInfo], + staged_state_dict: STATE_DICT_TYPE, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> Metadata: + global _CHECKPOINT_PROCESS + if _CHECKPOINT_PROCESS is None: + assert pg_init_info is not None + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = process_group + + @_dcp_method_logger(**ckpt_kwargs) + def create_checkpoint_daemon_process() -> None: + global _CHECKPOINT_PROCESS + _CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info) + + create_checkpoint_daemon_process() + + assert _CHECKPOINT_PROCESS is not None + return _CHECKPOINT_PROCESS.save( + staged_state_dict=staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + ) + + def execute_save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> Future: + """ + NOTE: + + - Checkpoint process is implemented as a daemon process. + The AsyncCheckpointProcess' lifetime is tied to the lifetime of the + main process (e.g. trainer process). + + - The first call to execute_save_in_process() will initialize the checkpoint + daemon process. Subsequent async checkpoint requests will not need process + initialization. Therefore, the first async checkpoint request will take longer to complete. + + - Process initialization can have significant overhead, dominated by latency for all ranks to spawn + a background process + process group initialization in the background process. + """ + + global _CHECKPOINT_PROCESS + pg_init_info: Optional[_ProcessGroupInitInfo] = None + if _CHECKPOINT_PROCESS is None: + # Find a free port on coordinator rank and broadcast + # to all ranks. 
+ pg_init_info = _ProcessGroupInitInfo(process_group) + + f: Future = self._executor.submit( + self._execute_save_impl, + pg_init_info=pg_init_info, + staged_state_dict=staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_thread_executor.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_thread_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..84d3526f341033779bf014662ea5623db28de456 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_async_thread_executor.py @@ -0,0 +1,39 @@ +# pyre-strict +# mypy: allow-untyped-defs +import os +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1) + + def execute_save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> Future: + from torch.distributed.checkpoint.state_dict_saver import save + + f: Future = self._executor.submit( + save, + staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_checkpointer.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..66fbf84374655680d73ee6b276620554d602e14d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_checkpointer.py @@ -0,0 +1,100 @@ +from concurrent.futures import Future +from typing import Any, Optional + +import torch.distributed as dist +import torch.distributed.checkpoint.state_dict_loader as loader +import torch.distributed.checkpoint.state_dict_saver as saver +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.storage import ( + LoadPlanner, + SavePlanner, + StorageReader, + StorageWriter, +) + + +__all__: list[str] = [] + + +class _Checkpointer: + """This base class specefies a high level API for saving and loading + distributed `state_dict` 's. It provides an abstraction over the low-level APIs + provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling + :py:meth: `torch.distributed.state_dict_saver.save` and + :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage + readers and writers. + + .. warning:: + This feature is experimental and subject to removal/change. 
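# --- Illustrative usage sketch (editor addition, not part of the diff above) ---
# The executors above back the public torch.distributed.checkpoint.async_save
# entry point. The single-rank gloo setup and the checkpoint path are made up;
# in real training the process group comes from the launcher.
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

dist.init_process_group("gloo", rank=0, world_size=1, init_method="tcp://127.0.0.1:29500")
state_dict = {"model": torch.nn.Linear(4, 4).state_dict()}
future = dcp.async_save(state_dict, checkpoint_id="/tmp/ckpt_example")  # returns a Future
metadata = future.result()                   # block until the background save finishes
dcp.load(state_dict, checkpoint_id="/tmp/ckpt_example")
dist.destroy_process_group()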
+ + """ + + def __init__( + self, + storage_writer: StorageWriter, + storage_reader: StorageReader, + *, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + load_planner: Optional[LoadPlanner] = None, + save_planner: Optional[SavePlanner] = None, + ): + """Initializes the Checkpointer instance. + + Args: + storage_writer: Instance of StorageWrite use to perform writes. + storage_reader: StorageReader used to load data from. + process_group: ProcessGroup to be used for cross-rank synchronization. + coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default. + no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) + loader_planner: Instance of LoadPlanner to use when loading. + save_planner: Instance of SavePlanner to use when saving. + """ + self.storage_writer = storage_writer + self.storage_reader = storage_reader + self.process_group = process_group + self.coordinator_rank = coordinator_rank + self.no_dist = no_dist + self.load_planner = load_planner + self.save_planner = save_planner + + def save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Metadata: + """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization.""" + return saver.save( + state_dict, + self.storage_writer, + process_group=self.process_group, + coordinator_rank=self.coordinator_rank, + no_dist=self.no_dist, + planner=self.save_planner, + ) + + def async_save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Future: + """ + Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + """ + return saver.async_save( + state_dict, + storage_writer=self.storage_writer, + process_group=self.process_group, + planner=self.save_planner, + ) + + def load(self, state_dict: dict[str, Any]) -> None: + """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" + loader.load( + state_dict, + storage_reader=self.storage_reader, + process_group=self.process_group, + planner=self.load_planner, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py new file mode 100644 index 0000000000000000000000000000000000000000..b00b5f1666fe460a23e17e2646f5186b0df8f5e7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +from collections import defaultdict +from typing import TYPE_CHECKING + +from torch.distributed.checkpoint.planner import SavePlan, WriteItem + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_save_plans"] + + +def dedup_save_plans( + all_plans: list[SavePlan], + save_to_lowest_rank: bool = False, +) -> list[SavePlan]: + """ + Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across + a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. 
+ + Please note that this function does not modify the original SavePlans, but rather returns + new SavePlans with the duplicated entries removed. + """ + + # Map to query the plan indices that a write item is duplicated in + write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set) + # Map to query the write item from its index + write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {} + # Set of write item indices that are present in each plan + # After deduplication, this will be the set of write item indices that are present in the final plans + plan_to_item_indices: list[set[MetadataIndex]] = [ + {item.index for item in plan.items} for plan in all_plans + ] + + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + # map each write item to its plan + write_item_to_plan_indices[write_item.index].add(plan_idx) + write_item_idx_to_write_item[write_item.index] = write_item + plan_to_size = [0] * len(all_plans) + for write_item_idx, plan_indices in write_item_to_plan_indices.items(): + if save_to_lowest_rank: + select_plan_idx = min(plan_indices) + else: + select_plan_idx = min( + plan_indices, key=lambda plan_idx: plan_to_size[plan_idx] + ) + + write_item = write_item_idx_to_write_item[write_item_idx] + # Ignore the storage size of anything that is not a tensor, since + # we don't know how much storage they represent + plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1 + for plan_idx in plan_indices - {select_plan_idx}: + plan_to_item_indices[plan_idx].discard(write_item_idx) + # Sanity check + assert len(all_plans) == len(plan_to_item_indices) + # Create new plans with the updated write items post deduplication + return [ + dataclasses.replace( + plan, items=[item for item in plan.items if item.index in item_indexes] + ) + for plan, item_indexes in zip(all_plans, plan_to_item_indices) + ] diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_dedup_tensors.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_dedup_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..e5590f3a85c6a342dc5c44b7aa619f6bd411a490 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_dedup_tensors.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +import logging +from typing import TYPE_CHECKING + +from torch.distributed.checkpoint.planner import SavePlan + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_tensors"] + + +def init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + level = logging.INFO + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + console.setFormatter(formatter) + console.setLevel(level) + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = init_logger() + + +# TODO add docstring for dedup_tensors +def dedup_tensors(all_plans: list[SavePlan]) -> list[SavePlan]: + all_plans = list(all_plans) + key_to_plan: dict[MetadataIndex, list[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + + # Remove duplicates by always keeping the first entry. + # Compute the per-rank remove set.
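+    # For example (hypothetical): if the write item for key ``k`` shows up in plans
+    # [0, 2, 3], plan 0 keeps it and ``k`` is queued for removal from plans 2 and 3
+    # below.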
+ plan_to_keys: dict[int, list[MetadataIndex]] = {} + for key, plans in replicated_items.items(): + for plan_idx in plans[1:]: + plan_to_keys.setdefault(plan_idx, []).append(key) + if len(plan_to_keys) > 0: + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + key_set = set(keys) + # rewrite items and remove elements + new_items = [ + write_item + for write_item in all_plans[plan_idx].items + if write_item.index not in key_set + ] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_extension.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..5d112272d90130823f0a0f49d7c3b258bb464560 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_extension.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import abc +import io +from collections.abc import Sequence +from typing import cast, IO, Optional + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +from torch._utils import try_import + + +# NOTE: everything in this file is experimental, and subject to +# change. Feedback and bug fixes are always welcome. + +pyzstd_module_name = "pyzstd" +pyzstd = try_import(pyzstd_module_name) +zstandard_module_name = "zstandard" +zstandard = try_import(zstandard_module_name) + + +__all__ = [ + "Extension", + "StreamTransformExtension", + "ZStandard", + "ExtensionRegistry", +] + + +class Extension(abc.ABC): + """ + Extensions provide modular additions to functionality within distributed checkpointing, + which affect the layout or format of the written artifacts. Extensions may be + built into pytorch, or provided externally. + + When writing, the caller provides a list of extension instances of the appropriate + type. Each extension can output a descriptor which is used to reconstitute the + extension at read-time. + """ + + @staticmethod + @abc.abstractmethod + def registry_name() -> str: + """ + See ExtensionRegistry.from_descriptor_list + """ + + @staticmethod + @abc.abstractmethod + def from_descriptor(version: str) -> "Extension": + """ + See ExtensionRegistry.from_descriptor_list + """ + + @abc.abstractmethod + def get_descriptor(self) -> str: + """ + Return descriptor name to be included in metadata. The form should be + "extension_name[@local-domain][/version]". + """ + + +class StreamTransformExtension(Extension): + """ + An extension which performs transformation on a byte stream, such as compression + or encryption. + + Implementations should try to be memory friendly and performant. For example, don't + read the whole input, then transform it, and write it back. If at all possible, do it in + chunks. But, don't read/transform/write one byte at a time, either. + """ + + @abc.abstractmethod + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + """ + Takes a writeable output stream, and generates a new stream which implements the + output transform. Input data written to the returned stream will be transformed + and written to the `output` argument stream. + """ + + @abc.abstractmethod + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + """ + Takes a readable input stream, and generates a new stream which implements the + input transform. 
When the returned stream is read, data will be read from the + 'input' stream, transformed, and returned. + """ + + +class ZStandard(StreamTransformExtension): + @staticmethod + def is_available() -> bool: + return zstandard is not None or pyzstd is not None + + @staticmethod + def from_descriptor(version: str) -> "ZStandard": + if version.partition(".")[0] != "1": + raise ValueError(f"Unknown extension {version=}") + if not ZStandard.is_available(): + raise ValueError( + f"Stream with ZStandard compression cannot be processed because " + f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" + ) + return ZStandard() + + @staticmethod + def registry_name() -> str: + return "stream.zstd" + + def __init__(self) -> None: + super().__init__() + if not ZStandard.is_available(): + raise ValueError( + f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" + ) + + def get_descriptor(self) -> str: + return f"{self.registry_name()}/1" + + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + if zstandard is not None: + compressor = zstandard.ZstdCompressor() # type: ignore[union-attr] + return compressor.stream_writer(output) + + class Writer(io.RawIOBase): + def __init__(self, output: IO[bytes]) -> None: + self.output = output + self.compressor = pyzstd.ZstdCompressor() # type: ignore[union-attr] + + def writable(self) -> bool: + return True + + def write(self, b: Buffer) -> Optional[int]: + outdata = self.compressor.compress(b) + if outdata: + self.output.write(outdata) + return len(memoryview(b)) + + def flush(self) -> None: + outdata = self.compressor.flush() + if outdata: + self.output.write(outdata) + self.output.flush() + + return cast(IO[bytes], Writer(output)) + + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + if zstandard is not None: + decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr] + return decompressor.stream_reader(input) + + class Reader(io.RawIOBase): + def __init__(self, input: IO[bytes]) -> None: + self.input = input + self.decompressor = pyzstd.EndlessZstdDecompressor() # type: ignore[union-attr] + + def readable(self) -> bool: + return True + + def readinto(self, b: Buffer) -> Optional[int]: + # This needs to read enough so it can decompress + # something so the output doesn't look like EOF. This + # means reading at least one block. The max block + # size is 128KB, so we read that plus some + # overhead to be sure. + + if self.decompressor.needs_input: + indata = self.input.read((128 + 6) * 1024) + else: + indata = b"" + + bview = memoryview(b) + blen = len(bview) + outdata = self.decompressor.decompress(indata, blen) + if outdata is None: + return None + + count = len(outdata) + bview[:count] = outdata + return count + + def seekable(self) -> bool: + return False + + return cast(IO[bytes], Reader(input)) + + +class ExtensionRegistry: + def __init__(self) -> None: + # Populate default registry contents + self.extensions: dict[str, type[Extension]] = { + cls.registry_name(): cls for cls in (ZStandard,) + } + + def register(self, cls: type[Extension]) -> None: + self.extensions[cls.registry_name()] = cls + + def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]: + """ + Given a sequence of descriptor strings as returned by + Extension.get_descriptor at save time, creates a sequence of + Extension instances.
The name[@local-domain] preceding the + version number is used to look up an implementation class in + the registry, and the version is passed to the class's + from_descriptor static method. If the registry contains no + match, this will throw ValueError. If the from_descriptor + method raises an exception, that will pass through to the + caller. + """ + + def from_descriptor(desc: str) -> Extension: + name, _, version = desc.partition("/") + if version is None: + version = 0 + ext = self.extensions.get(name) + if not ext: + raise ValueError(f"Unknown extension {name=}") + return ext.from_descriptor(version) + + return [from_descriptor(desc) for desc in descriptors] diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..98fe84a1146a7f0c45588415bc930528b1ef96cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -0,0 +1,167 @@ +# Mypy will not try inferring the types of any 3rd party libraries installed. +# mypy: ignore-errors + +import io +import os +from collections.abc import Generator, Sequence +from contextlib import contextmanager +from pathlib import Path +from typing import Optional, TYPE_CHECKING, Union + +from fsspec.core import url_to_fs + +from torch.distributed.checkpoint._extension import StreamTransformExtension +from torch.distributed.checkpoint.filesystem import ( + FileSystemBase, + FileSystemReader, + FileSystemWriter, + SerializationFormat, +) + + +if TYPE_CHECKING: + from fsspec import AbstractFileSystem + + +__all__ = [ + "FsspecWriter", + "FsspecReader", +] + + +class FileSystem(FileSystemBase): + def __init__(self) -> None: + self.fs: Optional[AbstractFileSystem] = None + + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + assert self.fs is not None + path = os.fspath(path) + + # fsspec does not support concurrent transactions, and not all + # AbstractFileSystem have working rollback implementations, so + # just manually delete the file if necessary on errors. 
+ with self.fs.open(path, mode) as stream: + try: + yield stream + except: # noqa: B001,E722 + if any(ch in mode for ch in "w+a"): # cleanup file if not read-only + try: + self.rm_file(path) + except: # noqa: B001,E722 + pass + raise + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + return os.path.join(path, suffix) + + def init_path( + self, path: Union[str, os.PathLike], **kwargs + ) -> Union[str, os.PathLike]: + self.fs, _ = url_to_fs(path, **kwargs) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + self.fs.rename(path, new_path) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + self.fs.makedirs(path, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return False + + try: + url_to_fs(checkpoint_id) + except ValueError: + return False + + return True + + def exists(self, path: Union[str, os.PathLike]) -> bool: + return self.fs.exists(path) + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + self.fs.rm(path) + + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + # setting detail to False explicitly to keep the list[str] return type, + # instead of the list[Dict] return type when detail=True + return self.fs.ls(path, detail=False) + + +# TODO: add the dcp.async_save mixin +class FsspecWriter(FileSystemWriter): + """ + Basic implementation of StorageWriter using fsspec. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic. + + The checkpoint consists of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + **kwargs, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True. + sync_files: force files to be synced to permanent storage. Defaults to True. + thread_count: Number of IO threads to use to write. Defaults to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10 MB. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL). + + N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
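+
+        Example (an illustrative sketch; the URL, bucket name, and ``model`` object are
+        hypothetical)::
+
+            >>> # xdoctest: +SKIP
+            >>> import torch.distributed.checkpoint as dcp
+            >>> writer = FsspecWriter("s3://my-bucket/checkpoints/step-100")
+            >>> dcp.save(model.state_dict(), storage_writer=writer)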
+ """ + super().__init__( + path, + single_file_per_rank, + sync_files, + thread_count, + per_thread_copy_ahead, + overwrite=overwrite, + _extensions=_extensions, + serialization_format=serialization_format, + ) + self.fs = FileSystem() + self.path = self.fs.init_path(path, **kwargs) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FsspecReader(FileSystemReader): + def __init__(self, path: Union[str, os.PathLike], **kwargs) -> None: + super().__init__(path) + self.fs = FileSystem() + self.path = self.fs.init_path(path, **kwargs) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_hf_utils.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5275deeb1c7bd1b5df217ea770a78c4b6f7d5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_hf_utils.py @@ -0,0 +1,105 @@ +import io +import json +import struct +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +_metadata_fn: str = "model.safetensors.index.json" + +FILE_NAME = "model-{cpt_idx}-of-{num_files}" +SHARDED_FILE_NAME = "shard-{shard_idx}-model-{cpt_idx}-of-{num_files}" +SUFFIX = ".safetensors" + +# metadata keys +CUSTOM_METADATA_KEY = "DCP_SHARDING_INFO" +DEFAULT_EXTRA_METADATA_KEY = "__metadata__" +SAVED_OFFSETS_KEY = "saved_offsets" +SHAPE_KEY = "shape" +DATA_KEY = "data" +DTYPE_KEY = "dtype" +DATA_OFFSETS_KEY = "data_offsets" + +DTYPE_MAP = { + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + "I8": torch.int8, + "U8": torch.uint8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "BF16": torch.bfloat16, +} + +HF_DCP_VERSION: float = 1.0 +DCP_VERSION_KEY = "DCP_VERSION" +DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO" + +FORMAT_KEY = "format" +FORMAT_VALUE = "pt" + + +@dataclass +class _HFStorageInfo: + """This is the per entry storage info.""" + + relative_path: str + offset: int + length: int + shape: torch.Size + dtype: torch.dtype + + +def _gen_file_name( + index: int, largest_index: int, shard_index: Optional[int] = None +) -> str: + if shard_index is not None: + return ( + SHARDED_FILE_NAME.format( + shard_idx=f"{shard_index}".zfill(5), + cpt_idx=f"{index}".zfill(5), + num_files=f"{largest_index}".zfill(5), + ) + + SUFFIX + ) + else: + return ( + FILE_NAME.format( + cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5) + ) + + SUFFIX + ) + + +def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]: + # this uses the same logic that's done in HF code base + # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308 + # and follows their documentation on how their files are serialized + # https://huggingface.co/docs/safetensors/index#format + + num_bytes_for_header_len = 8 + header_len_bytes = file_bytes.read(num_bytes_for_header_len) + header_len = struct.unpack(" torch.dtype: + try: + dtype = DTYPE_MAP[dtype_str] + except KeyError: + dtype = torch.get_default_dtype() + + return dtype + + +def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]: + if DEFAULT_EXTRA_METADATA_KEY in metadata: + custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY] + if CUSTOM_METADATA_KEY in 
custom_metadata: + return json.loads(custom_metadata[CUSTOM_METADATA_KEY]) + return None diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_nested_dict.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_nested_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..63c02929906c9fef011bbcee5467c9acfead2ea5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_nested_dict.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + +from . import _version +from ._traverse import ( + OBJ_PATH, + set_element, + STATE_DICT_ITEM, + traverse_state_dict, + traverse_state_dict_v_2_3, +) + + +""" +TODO: +Need to add ability to handle tuple, OrderedDict, NamedTuple. +Update mappings from dict to a class. +Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple. +""" + + +FLATTEN_MAPPING = dict[str, OBJ_PATH] + + +# TODO: Update Docstring for nested_dict.py +def flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + # We started to flatten dictionary since v2.4. But in order to not break + # the checkpoints that were saved before v2.4, we need to keep the old + # traversal so that we can reconstruct those checkpoints. + use_v_2_3 = ( + _version._derived_version is not None and _version._derived_version == "2_3" + ) + if use_v_2_3: + traverse_state_dict_v_2_3(state_dict, flat_copy) + else: + traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + set_element(nested, mapping[key], value) + return nested diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a36d1f8f6a8474abf802d6ab3c994ffc0c9b5d8f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import copy +from typing import TYPE_CHECKING + +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.remote_device import _remote_device + +from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict +from .utils import _element_wise_add, _normalize_device_info + + +if TYPE_CHECKING: + from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata + + +# TODO: We need to refactor this code. +def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + r""" + Transform ``state_dict`` by flattening all nested ShardedTensor instances found. + + The resulting ShardedTensor instances are only correct regarding the local shard and + MUST not be used for any other purpose but checkpointing, as no operator will work with them. + + This function should be used in conjunction with a state_dict produced by FSDP's + StateDictType.SHARDED_STATE_DICT methods. + """ + new_state_dict: STATE_DICT_TYPE = {} + + def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if not isinstance(value, ShardedTensor): + set_element(new_state_dict, path, value) + return + shards = value.local_shards() + + if len(shards) == 0: + return + if len(shards) != 1: + set_element(new_state_dict, path, value) + return + + outer_shard = shards[0] + + inner_st = outer_shard.tensor + if not isinstance(inner_st, ShardedTensor): + set_element(new_state_dict, path, value) + return + + if len(inner_st.local_shards()) != 1: + raise ValueError("Cannot handle inner tensor with more than 1 shard") + inner_shard = inner_st.local_shards()[0] + + local_shards = [ + Shard( + tensor=inner_shard.tensor, + metadata=ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_shard.metadata.shard_offsets, + ), + shard_sizes=inner_shard.metadata.shard_sizes, + placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}", + ), + ) + ] + + st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata()) + other_rank = 0 if dist.get_rank() > 0 else 1 + device_info = _normalize_device_info(inner_shard.tensor.device.type, 0) + + # Remove the outer ST shard the inner ST covers + for i, shard_md in enumerate(st_meta.shards_metadata): + if shard_md.shard_offsets == outer_shard.metadata.shard_offsets: + st_meta.shards_metadata.pop(i) + break + + # Attribute other rank for the other shards + for shard_md in st_meta.shards_metadata: + shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}") + + # Add other inner shards from the inner tensor + for inner_md in inner_st.metadata().shards_metadata: + if inner_md.shard_offsets != inner_shard.metadata.shard_offsets: + st_meta.shards_metadata.append( + ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_md.shard_offsets, + ), + shard_sizes=inner_md.shard_sizes, + placement=f"rank:{other_rank}/{device_info}", + ) + ) + + # Finally add this shard + st_meta.shards_metadata.append(local_shards[0].metadata) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=st_meta, + ) + set_element(new_state_dict, path, st) + + traverse_state_dict(state_dict, rewrite_dict) + return new_state_dict diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_state_dict_stager.py 
b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_state_dict_stager.py new file mode 100644 index 0000000000000000000000000000000000000000..711fd73bb768c6a4295dba0096783cd8f93e3057 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_state_dict_stager.py @@ -0,0 +1,354 @@ +# mypy: allow-untyped-defs +import logging +import types +import weakref +from copyreg import dispatch_table +from logging import getLogger +from typing import Any + +import torch +import torch.cuda._pin_memory_utils as pin_memory_utils +from torch.storage import UntypedStorage +from torch.utils.weak import WeakIdKeyDictionary + + +logger = getLogger() +logger.setLevel(logging.INFO) + + +class StateDictStager: + """ + A class for optimizing storage objects during staging for async checkpointing. + + StateDictStager stages the state_dict to CPU DRAM while applying optimizations + like memory sharing and pinning to improve performance. It caches storage objects + to avoid redundant copies and can be configured to automatically share memory + (for multi-process usage) and pin memory (for faster CPU-GPU transfers). + + Attributes: + pin_memory (bool): Whether to pin CPU memory for faster CPU-GPU transfers + share_memory (bool): Whether to share memory across processes + _cached_storage_mapping (WeakIdKeyDictionary): Maps storage objects to optimized CPU storages using weak references + """ + + def __init__(self, pin_memory: bool = False, share_memory: bool = False): + if pin_memory and not torch.cuda.is_available(): + logger.warning( + "Ignoring pin_memory flag for checkpoint staging as pinning memory" + "requires CUDA, but CUDA is not available. " + ) + self.pin_memory = False + else: + self.pin_memory = pin_memory + self.share_memory = share_memory + # Mapping from original storage objects to CPU storages using weak references + self._cached_storage_mapping = WeakIdKeyDictionary() + + def _deepcopy_atomic(x, _): + return x + + def _deepcopy_list(x, memo): + y: list = [] + memo[id(x)] = y + append = y.append + for a in x: + append(self.deepcopy_with_tensor_offload(a, memo)) + return y + + def _deepcopy_tuple(x, memo): + y = [self.deepcopy_with_tensor_offload(a, memo) for a in x] + # We're not going to put the tuple in the memo, but it's still important we + # check for it, in case the tuple contains recursive mutable structures. 
+ try: + return memo[id(x)] + except KeyError: + pass + + # Check if any elements changed during deepcopy + for k, j in zip(x, y): + if k is not j: + # At least one element changed, create new tuple + return tuple(y) + + # No elements changed, return original tuple + return x + + def _deepcopy_dict(x, memo): + y: dict = {} + memo[id(x)] = y + for key, value in x.items(): + y[self.deepcopy_with_tensor_offload(key, memo)] = ( + self.deepcopy_with_tensor_offload(value, memo) + ) + return y + + def _deepcopy_method(x, memo): # Copy instance methods + return type(x)( + x.__func__, self.deepcopy_with_tensor_offload(x.__self__, memo) + ) + + d: dict[Any, Any] = {} + self._deepcopy_dispatch = d + d[type(None)] = _deepcopy_atomic + d[int] = _deepcopy_atomic + d[float] = _deepcopy_atomic + d[bool] = _deepcopy_atomic + d[complex] = _deepcopy_atomic + d[bytes] = _deepcopy_atomic + d[str] = _deepcopy_atomic + d[types.CodeType] = _deepcopy_atomic + d[type] = _deepcopy_atomic + d[range] = _deepcopy_atomic + d[types.BuiltinFunctionType] = _deepcopy_atomic + d[types.FunctionType] = _deepcopy_atomic + d[weakref.ref] = _deepcopy_atomic + d[property] = _deepcopy_atomic + d[types.MethodType] = _deepcopy_method + d[dict] = _deepcopy_dict + d[tuple] = _deepcopy_tuple + d[list] = _deepcopy_list + + def _stage_untyped_storage( + self, storage: UntypedStorage, non_blocking: bool = False + ): + """ + Called from the hooked storage_deepcopy function in torch.Tensor.__deepcopy__. + + This method handles the storage optimization logic for the StagingStateDict class. + It checks if the storage has already been cached, and if so, reuses it. + Otherwise, it creates a new CPU storage and applies memory optimizations. + + Args: + storage: The storage to optimize + + Returns: + The optimized storage + """ + # Check if we've already cached this storage + if storage in self._cached_storage_mapping: + cached_storage = self._cached_storage_mapping[storage] + assert cached_storage.size() == storage.size(), ( + "For async checkpointing, We cache storages in DRAM and reuse them." + "Cached storage size does not match original storage size." + "This should never happen as we track the original storage weakref " + "and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing." + ) + # Reuse cached storage but update with new data + cached_storage.copy_(storage, non_blocking=non_blocking) + return cached_storage + + # Create new CPU storage + if self.share_memory: + new_storage = type(storage)._new_shared(storage.size(), device="cpu") + else: + new_storage = type(storage)(storage.size(), device="cpu") + + if self.pin_memory and new_storage.nbytes() > 0: + pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes()) + # Set up a weak reference to unpin when cpu storage is garbage collected + f = weakref.finalize( + new_storage, pin_memory_utils.unpin_memory, new_storage.data_ptr() + ) + # This makes sure that the finalizer is not called after + # cuda context is destroyed. 
+ f.atexit = False + + new_storage.copy_(storage, non_blocking=non_blocking) + + # Cache the storage - WeakIdKeyDictionary will automatically clean up when storage is garbage collected + self._cached_storage_mapping[storage] = new_storage + return new_storage + + @torch.no_grad() + def stage( + self, + state_dict: dict[str, Any], + non_blocking: bool = False, + ) -> dict[str, Any]: + return self.deepcopy_with_tensor_offload(state_dict, non_blocking=non_blocking) + + def _offload_tensor(self, x, memo, non_blocking=False): + """ + Deep copy a PyTorch tensor with optimized storage handling. + + This method creates a CPU copy of a tensor while applying memory optimizations + like sharing and pinning based on the StateDictStager configuration. + + Args: + x: The tensor to copy + memo: Memo dictionary for tracking already copied objects + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A CPU copy of the tensor with optimized storage + """ + # Create a new empty tensor on CPU + y = x.new_empty([], device="cpu") + + # Store in memo dict early to handle recursive references + d = id(x) + memo[d] = y + + if type(x) is torch.Tensor or x.data_ptr() != 0: + # Try to get the untyped storage and optimize it + untyped_storage = x.untyped_storage() + copied_storage = self._stage_untyped_storage( + untyped_storage, non_blocking=non_blocking + ) + # Set the tensor data using the optimized storage + y.set_(copied_storage, x.storage_offset(), x.size(), x.stride()) + + # Copy any attributes the tensor might have + if hasattr(x, "__dict__"): + for attr_name, attr_value in x.__dict__.items(): + setattr( + y, + attr_name, + self.deepcopy_with_tensor_offload( + attr_value, memo, non_blocking=non_blocking + ), + ) + + if hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + setattr( + y, + slot, + self.deepcopy_with_tensor_offload( + getattr(x, slot), memo, non_blocking=non_blocking + ), + ) + + return y + + @torch.no_grad() + def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006 + """Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors. + + This implementation extends the standard deepcopy functionality to handle PyTorch tensors + and their storages in a way that optimizes memory usage and performance, similar to the + stage method. It applies memory sharing and pinning optimizations based on the StateDictStager + configuration. 
+ + Args: + x: The object to deep copy + memo: Memo dictionary for tracking already copied objects + _nil: Sentinel value for memo dictionary + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A deep copy of the input object with optimized tensor storage handling + """ + if memo is None: + memo = {} + + d = id(x) + y = memo.get(d, _nil) + if y is not _nil: + return y + + cls = type(x) + + # tensors and subclasses of tensors are handled separately + if isinstance(x, torch.Tensor): + y = self._offload_tensor(x, memo, non_blocking=non_blocking) + + # Use the dispatch table for standard types + copier = self._deepcopy_dispatch.get(cls) + if copier is not None: + y = copier(x, memo) + else: + if issubclass(cls, type): + y = self._deepcopy_dispatch[type](x, memo) + else: + copier = getattr(x, "__deepcopy__", None) + if copier is not None: + y = copier(memo) + else: + reductor = dispatch_table.get(cls) + if reductor: + rv = reductor(x) + else: + reductor = getattr(x, "__reduce_ex__", None) + if reductor is not None: + rv = reductor(4) + else: + reductor = getattr(x, "__reduce__", None) + if reductor: + rv = reductor() + else: + raise RuntimeError( + f"un(deep)copyable object of type {cls}" + ) + if isinstance(rv, str): + y = x + else: + y = self._reconstruct(x, memo, *rv) + + # If is its own copy, don't memoize. + if y is not x: + memo[d] = y + self._keep_alive(x, memo) # Make sure x lives at least as long as d + return y + + def _keep_alive(self, x, memo): + """Keeps a reference to the object x in the memo. + + Because we remember objects by their id, we have + to assure that possibly temporary objects are kept + alive by referencing them. + We store a reference at the id of the memo, which should + normally not be used unless someone tries to deepcopy + the memo itself... 
+ """ + try: + memo[id(memo)].append(x) + except KeyError: + # aha, this is the first one :-) + memo[id(memo)] = [x] + + def _reconstruct( + self, x, memo, func, args, state=None, listiter=None, dictiter=None + ): + deep = memo is not None + if deep and args: + args = (self.deepcopy_with_tensor_offload(arg, memo) for arg in args) + y = func(*args) + if deep: + memo[id(x)] = y + + if state is not None: + if deep: + state = self.deepcopy_with_tensor_offload(state, memo) + if hasattr(y, "__setstate__"): + y.__setstate__(state) + else: + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + else: + slotstate = None + if state is not None: + y.__dict__.update(state) + if slotstate is not None: + for key, value in slotstate.items(): + setattr(y, key, value) + + if listiter is not None: + if deep: + for item in listiter: + item = self.deepcopy_with_tensor_offload(item, memo) + y.append(item) + else: + for item in listiter: + y.append(item) + if dictiter is not None: + if deep: + for key, value in dictiter: + key = self.deepcopy_with_tensor_offload(key, memo) + value = self.deepcopy_with_tensor_offload(value, memo) + y[key] = value + else: + for key, value in dictiter: + y[key] = value + return y diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_storage_utils.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_storage_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23d9dca34c0d033598d7c390992682faf54500b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_storage_utils.py @@ -0,0 +1,49 @@ +import os +from typing import Union + +from .filesystem import FileSystemReader, FileSystemWriter +from .storage import StorageReader, StorageWriter + + +def _storage_setup( + storage: Union[StorageReader, StorageWriter, None], + checkpoint_id: Union[str, os.PathLike, None], + reader: bool = False, +) -> Union[None, StorageReader, StorageWriter]: + if storage: + if checkpoint_id is not None: + storage.reset(checkpoint_id) + return storage + + if not checkpoint_id: + raise RuntimeError( + "`checkpoint_id` must be specified if " + "storage_reader/storage_writer is None." + ) + + targets: list[type[Union[StorageReader, StorageWriter]]] = [] + if reader: + targets = [ + FileSystemReader, + ] + else: + targets = [ + FileSystemWriter, + ] + try: + from ._fsspec_filesystem import FsspecReader, FsspecWriter + + targets.append(FsspecReader if reader else FsspecWriter) + except Exception: + pass + + for target in targets: + if target.validate_checkpoint_id(checkpoint_id): + storage = target(checkpoint_id) # type: ignore[call-arg] + storage.reset(checkpoint_id) + return storage + + raise RuntimeError( + "Cannot detect which StorageReader or StorageWriter to use. " + "Please specify the storage_reader/storage_writer." + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_traverse.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..14a9b13a17c7ebd9dec97af9ca23cc4dbd32fbc5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_traverse.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from collections.abc import Collection, Mapping, MutableMapping +from typing import Callable, cast, Optional, TypeVar, Union + +import torch +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.tensor import DTensor + + +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] +T = TypeVar("T") + +STATE_DICT_ITEM = object +CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] + +__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] + + +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: + return isinstance(value, torch.Tensor) + + +# TODO: update docstring for traverse.py +def traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping will be traversed and ``visitor`` will be applied to the leaf elements. + ``visitor`` will only be applied to elements in a list or a tuple, if the + container contains tensors or mappings. + """ + + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + return False + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif _is_terminal(value): + visitor(path, value) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def traverse_state_dict_v_2_3( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``torch.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. 
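+
+    Example (illustrative)::
+
+        >>> # xdoctest: +SKIP
+        >>> sd = {"opt": {"lr": 0.1, "steps": [torch.zeros(2), torch.ones(2)]}}
+        >>> traverse_state_dict_v_2_3(sd, lambda path, value: print(path))
+        ('opt', 'lr')
+        ('opt', 'steps', 0)
+        ('opt', 'steps', 1)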
+ """ + + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def set_element( + root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM +) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else []) + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(list[STATE_DICT_ITEM], cur_container), key) + + cur_container[key] = value + + +def get_element( + root_dict: STATE_DICT_TYPE, + path: OBJ_PATH, + default_value: Optional[T] = None, +) -> Optional[T]: + """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found.""" + cur_value = cast(CONTAINER_TYPE, root_dict) + for part in path: + if type(part) is int: + if not isinstance(cur_value, list) or len(cur_value) < part: + return default_value + elif not isinstance(cur_value, Mapping) or part not in cur_value: + return default_value + + cur_value = cast(CONTAINER_TYPE, cur_value[part]) + return cast(Optional[T], cur_value) + + +def _print_nested( + value: STATE_DICT_ITEM, + prefix: str = "", + print_fun: Callable[[str], None] = print, +) -> None: + if type(value) is ShardedTensor: + print_fun(f"{prefix} ShardedTensor size: {value.size()}") + for shard in value.local_shards(): + _print_nested( + shard.tensor, + f"{shard.metadata.shard_offsets} ", + print_fun=print_fun, + ) + elif type(value) is (DTensor): + print_fun(f"{prefix} DistributedTensor size: {value.size()}") + # TODO: add local offset for _local_tensor in print_nested. + _print_nested( + value._local_tensor, + print_fun=print_fun, + ) + elif isinstance(value, torch.Tensor): + print_fun(f"{prefix} Tensor size: {value.size()}") + else: + print_fun(f"{prefix} Type: {type(value)}") + + +def print_tensor( + path: OBJ_PATH, + value: STATE_DICT_ITEM, + print_fun: Callable[[str], None] = print, +) -> None: + """ + Use this callback with traverse_state_dict to print its content. + + By default the content is printed using the builtin ``print`` but this can + be change by passing a different ``print_fun` callable. 
+ """ + _print_nested(value, prefix=str(path), print_fun=print_fun) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/_version.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bca29496f3d04cd4b959a10c26bbe85f0bb8a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/_version.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Optional + + +_derived_version: Optional[str] = None diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/api.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/api.py new file mode 100644 index 0000000000000000000000000000000000000000..d99feca2767f83c5913d686c449b77a6a9a6704f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/api.py @@ -0,0 +1,42 @@ +import traceback as tb +from typing import Any + + +WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary] + +__all__ = ["CheckpointException"] + + +def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: + return (exc, tb.extract_tb(exc.__traceback__)) + + +def _is_wrapped_exception(obj: Any) -> bool: + if not isinstance(obj, tuple): + return False + if len(obj) != 2: + return False + return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary) + + +class CheckpointException(BaseException): + """Exception raised if failure was detected as part of a checkpoint load or save.""" + + def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]): + super().__init__(msg, failures) + self._failures = failures + + @property + def failures(self) -> dict[int, WRAPPED_EXCEPTION]: + """Return a dictionary mapping node ranks to their associated exceptions in case of failure.""" + return self._failures + + def __str__(self) -> str: + str = f"CheckpointException ranks:{self._failures.keys()}\n" + for rank, exc_pair in self._failures.items(): + exc, trace = exc_pair + str += f"Traceback (most recent call last): (RANK {rank})\n" + if trace is not None: + str += "".join(tb.format_list(trace)) + str += "".join(tb.format_exception_only(type(exc), value=exc)) + return str diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/default_planner.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/default_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..1f94bc5c6dbd31614c7797fe4ef5149ddb4ab669 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/default_planner.py @@ -0,0 +1,669 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import dataclasses +import io +import logging +import operator +from collections import ChainMap +from functools import reduce +from typing import Any, cast, Optional, Union + +import torch +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans +from torch.distributed.checkpoint._nested_dict import ( + FLATTEN_MAPPING, + flatten_state_dict, +) +from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors +from torch.distributed.checkpoint._traverse import set_element +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + StorageMeta, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.planner_helpers import ( + _compare_save_plans, + _contains_usable_plan, + _create_default_metadata_only_plan, + _create_read_items, + _create_write_items, + _init_state_dict, + _merge_delta_local_plans, +) +from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.tensor import DTensor + +from . import _version + + +logger: logging.Logger = logging.getLogger(__name__) + + +__all__ = [ + "DefaultSavePlanner", + "DefaultLoadPlanner", + "create_default_local_load_plan", + "create_default_global_load_plan", + "create_default_local_save_plan", + "create_default_global_save_plan", +] + + +# TODO: Update docstrings for default_planner.py +class DefaultSavePlanner(SavePlanner): + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + dedup_replicated_tensors: Optional[bool] = None, + dedup_save_to_lowest_rank: bool = False, + enable_plan_caching: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.mappings = {} + self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank + if dedup_replicated_tensors is not None: + logger.warning( + "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " + "deprecated, and no longer has any effect. Please remove this argument " + "from your call." + ) + self._cached_plans_key: str = self.__class__.__name__ + self._enable_plan_caching = enable_plan_caching + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + self.state_dict = state_dict + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> SavePlan: + plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + + if self._enable_plan_caching: + # If plans are equal, we can skip sending the plan to the coordinator. + if ( + self._cached_plans_key in SavePlanner._cached_save_plan + and _compare_save_plans( + plan, SavePlanner._cached_save_plan[self._cached_plans_key] + ) + ): + logger.info( + "No change in the local plan. 
Skipping sending the plan to the coordinator" + ) + return SavePlan([], usable=False) + else: + SavePlanner._cached_save_plan[self._cached_plans_key] = plan + + return self.plan + + def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]: + return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) + + def _create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + deduped_plans = self._dedup_save_plans(all_plans) + + global_plan, metadata = create_default_global_save_plan(deduped_plans) + + if self.flatten_state_dict: + # | does not work for Python 3.8 or older version. + # merged_mappings = reduce( + # lambda x, y: x | y, (p.planner_data for p in global_plan) + # ) + planner_data_dict = [p.planner_data for p in global_plan] + merged_mappings = dict(ChainMap(*planner_data_dict)) + metadata = dataclasses.replace(metadata, planner_data=merged_mappings) + + if not _validate_global_plan(global_plan, metadata): + raise ValueError("Failed to validate global plan") + + return global_plan, metadata + + def _create_global_plan_with_caching( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], list[SavePlan], Metadata]: + """ + Create global plan with caching. + Returns a tuple of global_plan_delta, global_plan, metadata. + """ + global_plan_delta: list[SavePlan] = [] + + if self._cached_plans_key not in SavePlanner._cached_all_plans: + # Case 1: If the plans are not cached, the cache will be hydrated with the + # all_plans, global_plans (Deduped), and metadata. + + # Cache the original all_plans + SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans + global_plan, metadata = self._create_global_plan(all_plans) + # Cache the deduped and validated global_plan + SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan + # Cache the metadata + SavePlanner._cached_metadata[self._cached_plans_key] = metadata + # If plans are not cached, global_plan delta will be the same as global plan. + return global_plan, global_plan, metadata + + # Case 2: Plans are cached + if not _contains_usable_plan(all_plans): + # Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans). + # Global plan delta will be empty plans to avoid the collective overhead. + # We can reuse the deduped global plan and metadata from the cache directly. + global_plan_delta = [SavePlan([], usable=False)] * len(all_plans) + global_plan = SavePlanner._cached_global_plan[self._cached_plans_key] + metadata = SavePlanner._cached_metadata[self._cached_plans_key] + else: + # Case 2.2: Plans are cached but the local plans have changed. + # We will merge the changed local plans with the cached local plans. + # Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached. + # Global plan delta will be created by comparing the new global plan with the cached global plan. + # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead. 
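+            # For example (hypothetical): if only rank 3's local plan changed, the
+            # other ranks sent unusable placeholder plans; after merging, only the
+            # entry for rank 3 differs from the cached global plan, so only that
+            # entry is sent back while the other ranks fall back to their cached
+            # final plans in `finish_plan`.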
+ merged_plans = _merge_delta_local_plans( + SavePlanner._cached_all_plans[self._cached_plans_key], all_plans + ) + # Cache the updated local plans + SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans + global_plan, metadata = self._create_global_plan(merged_plans) + + if self._cached_plans_key in self._cached_global_plan: + for cached_plan, new_plan in zip( + SavePlanner._cached_global_plan[self._cached_plans_key], global_plan + ): + if _compare_save_plans(cached_plan, new_plan): + global_plan_delta.append(SavePlan([], usable=False)) + else: + global_plan_delta.append(new_plan) + + # Cache the new global plan and the metadata + SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan + SavePlanner._cached_metadata[self._cached_plans_key] = metadata + + return global_plan_delta, global_plan, metadata + + def create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + global_plan_delta: list[SavePlan] = [] + if self._enable_plan_caching: + # If the plans are cached, we only need to send the global plan delta to be scattered + # across ranks. Ranks will use the cached final plans instead. + ( + global_plan_delta, + global_plan, + metadata, + ) = self._create_global_plan_with_caching(all_plans) + else: + global_plan, metadata = self._create_global_plan(all_plans) + # If the caching is not enabled, global delta plan will always be same as the new global plan. + global_plan_delta = global_plan + + self.global_plan = global_plan + self.metadata = metadata + + return global_plan_delta, self.metadata + + def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan: + finished_plan: SavePlan = new_plan + + if not new_plan.usable: + finished_plan = SavePlanner._cached_final_save_plan[self._cached_plans_key] + else: + finished_plan = new_plan + SavePlanner._cached_final_save_plan[self._cached_plans_key] = new_plan + return finished_plan + + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + finished_plan: SavePlan = new_plan + + if self._enable_plan_caching: + finished_plan = self._finish_plan_with_caching(new_plan) + + self.plan = finished_plan + return self.plan + + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + object = self.lookup_object(write_item.index) + return self.transform_object(write_item, object) + + def lookup_object(self, index: MetadataIndex) -> Any: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_object(self, write_item: WriteItem, object: Any): + """Extension from the planner interface to make it easy to extend the default planner.""" + if write_item.type == WriteItemType.BYTE_IO: + bytes = io.BytesIO() + torch.save(object, bytes) + object = bytes + return object + + +class DefaultLoadPlanner(LoadPlanner): + """ + DefaultLoadPlanner that adds multiple features on top of LoadPlanner. + + In particular it adds the following: + + flatten_state_dict: Handle state_dict with nested dicts + flatten_sharded_tensors: For FSDP in 2D parallel mode + allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint. 
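+
+    Illustrative usage sketch (``model`` and the checkpoint path are placeholders):
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> sd = {"model": model.state_dict()}
+    >>> dcp.load(
+    >>>     sd,
+    >>>     storage_reader=FileSystemReader("/tmp/checkpoint"),
+    >>>     planner=DefaultLoadPlanner(allow_partial_load=True),
+    >>> )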
+ """ + + original_state_dict: STATE_DICT_TYPE + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + allow_partial_load: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.original_state_dict = {} + self.mappings = {} + self.allow_partial_load = allow_partial_load + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + _init_state_dict(state_dict) + self.original_state_dict = state_dict + + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + + self.state_dict = state_dict + self.metadata = metadata + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> LoadPlan: + assert self.metadata is not None + if self.flatten_state_dict: + # To support checkpoints that are saved before v2.4, we have to + # differentiate if the missing keys are due to old checkpoints. + # The contracts are: + # 1. There are 3 cases when we found a missing key. + # 1.1 Actual missing key, but allow_partial_load is False + # 1.2 Actual missing key, but allow_partial load is True + # 1.3 Old checkpoint, but allow_partial_load is False + # 1.4 Old checkpoint, but allow_partial_load is True + # 2. If we found a missing key, we first convert the keys back to + # the key format of v2.3 + # 3. If the previous missing keys are in the v2.3 keys, we assume + # this is a old checkpoint. + # 4. Pass the state_dict to `create_default_local_load_plan()`, + # which has the logic to check missing for allow_partial_load. + # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to + # `create_default_local_load_plan()`. The logic here is to determine + # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). + current_keys = set(self.state_dict.keys()) + load_keys = set(self.metadata.state_dict_metadata.keys()) + missing_keys = load_keys - current_keys + if missing_keys: + _version._derived_version = "2_3" + old_state_dict, old_mappings = flatten_state_dict( + self.original_state_dict + ) + old_keys = set(old_state_dict.keys()) + if old_keys & missing_keys: + self.state_dict, self.mappings = old_state_dict, old_mappings + # _derived_version is only used by flatten_state_dict now. + # Set it back to None so that later we can save to a new version. 
+ _version._derived_version = None + + return create_default_local_load_plan( + self.state_dict, self.metadata, not self.allow_partial_load + ) + + def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + return create_default_global_load_plan(global_plan) + + def finish_plan(self, new_plan: LoadPlan) -> LoadPlan: + return new_plan + + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + if self.flatten_state_dict: + set_element( + self.original_state_dict, + self.mappings[read_item.dest_index.fqn], + torch.load(value, weights_only=False), + ) + else: + self.state_dict[read_item.dest_index.fqn] = torch.load( + value, weights_only=False + ) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + pass + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): + """Extension from the planner interface to make it easy to extend the default planner.""" + return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths) + + +class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata. + Useful for loading in state_dict without first initializing a model, such as + when converting a DCP checkpoint into a Torch save file. + + . N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner + + .. warning:: + Because the entire state dict is initialized, It's recommended to only utilize + this LoadPlanner on a single rank or process to avoid OOM. + + """ + + def __init__(self, keys=None, *args, **kwargs): + self.keys = keys + super().__init__(*args, **kwargs) + + def _should_include_key(self, key: str, metadata: Metadata) -> bool: + if self.keys is None: + return True + + if key in self.keys: + True + + unflattened_keys: list[str] = [] + planner_data = metadata.planner_data.get(key) + for unflattened_key in planner_data: + if unflattened_keys: + unflattened_keys.append( + ".".join([unflattened_keys[-1], str(unflattened_key)]) + ) + + else: + unflattened_keys.append(unflattened_key) + + if any(unflattened_key in self.keys for unflattened_key in unflattened_keys): + return True + + return False + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + assert not state_dict + assert metadata is not None + + # rebuild the state dict from the metadata + for k, v in metadata.state_dict_metadata.items(): + if not self._should_include_key(k, metadata): + continue + + if isinstance(v, TensorStorageMetadata): + v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] + if metadata.planner_data is not None and k in metadata.planner_data: + set_element(state_dict, metadata.planner_data[k], v) + else: + state_dict[k] = v + + super().set_up_planner(state_dict, metadata, is_coordinator) + + +def create_default_local_load_plan( + state_dict: dict[str, Any], metadata: Metadata, strict: bool = True +) -> LoadPlan: + requests = [] + """ + Create the ``LoadPlan`` used by DefaultLoadPlanner. 
+ + It produces one read item per value in ``state_dict`` using the metadata in ``metadata``. + + The default behavior is to match key exactly between state_dict and metadata. + It handles resharding by issuing multiple read requests against storage in order to match + load requirements. + """ + + for fqn, obj in state_dict.items(): + # ignore state_dict keys which do not exist in `state_dict` if strict=False + if fqn not in metadata.state_dict_metadata: + if strict: + raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.") + else: + continue + + md = metadata.state_dict_metadata[fqn] + if ( + isinstance(md, TensorStorageMetadata) + and getattr(obj, "size", None) is not None + and md.size != obj.size() + ): + raise ValueError( + f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}", + ) + # Since DTensor supports submesh, adding extra check to ensure _create_read_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_read_items(fqn, md, obj) + else: + requests += _create_read_items(fqn, md, obj) + + return LoadPlan(requests) + + +def create_default_global_load_plan( + all_plans: list[LoadPlan], +) -> list[LoadPlan]: + """ + Create global load plan used by DefaultLoadPlanner. + + The default load behavior involved no global coordination and this function + currently doesn't change the local plans. + """ + return all_plans + + +def create_default_local_save_plan( + state_dict: dict[str, Any], is_coordinator: bool +) -> SavePlan: + """ + Create the ``SavePlan`` used by DefaultSavePlanner. + + On non-coordinator ranks, this function ignores tensors and non-tensor objects, + only producing writes for ShardedTensor objects. + + On the coordinator rank, produce writes for all values. + """ + requests = [] + for fqn, obj in state_dict.items(): + # Since DTensor supports submesh, adding extra check to ensure _create_write_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_write_items(fqn, obj) + else: + # For the plain tensor and non-tensor values, add the request for all + # the ranks. Coordinator will decides whether to deduplicate the + # values based on the keys. + requests += _create_write_items(fqn, obj) + + return SavePlan(requests) + + +def create_default_global_save_plan( + all_plans: list[SavePlan], + rewrite_index_hints: bool = True, +) -> tuple[list[SavePlan], Metadata]: + """ + Create the global plan and metadata used by DefaultSavePlanner. + + Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans. + + The only global planning change is to update index hints in all ``MetadataIndex`` objects if + ``rewrite_index_hints`` is True. 
+ """ + md: dict[str, STORAGE_TYPES] = {} + new_plans = [] + for plan in all_plans: + new_items = [] + for item in plan.items: + if not item.type == WriteItemType.SHARD: + assert item.index.fqn not in md + + if item.type == WriteItemType.BYTE_IO: + md[item.index.fqn] = BytesStorageMetadata() + new_items.append(item) + else: + assert item.tensor_data is not None + tensor_md = cast( + TensorStorageMetadata, + md.setdefault( + item.index.fqn, + TensorStorageMetadata( + properties=item.tensor_data.properties, + size=item.tensor_data.size, + chunks=[], + ), + ), + ) + new_item = item + if rewrite_index_hints: + new_index = dataclasses.replace( + item.index, index=len(tensor_md.chunks) + ) + new_item = dataclasses.replace(item, index=new_index) + new_items.append(new_item) + + assert item.tensor_data.chunk is not None, f""" + Cannot create MD for tensor without bounds. + FQN: {item.index.fqn} + """ + tensor_md.chunks.append(item.tensor_data.chunk) + new_plans.append(dataclasses.replace(plan, items=new_items)) + return (new_plans, Metadata(md)) + + +def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata: + """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``.""" + plan = _create_default_metadata_only_plan(state_dict) + _, md = create_default_global_save_plan([plan]) + return md + + +def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool: + """Check if two boxes overlap. Tuples are (offset, lengths).""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(box0.offsets) + for i in range(ndims): + if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]: + return False + if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]: + return False + + return True + + +def _check_box_bounds( + outer_box_size: torch.Size, inner_box: ChunkStorageMetadata +) -> bool: + for i in range(len(outer_box_size)): + if inner_box.offsets[i] < 0: + return False + if inner_box.sizes[i] < 0: + return False + if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]: + return False + + return True + + +def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool: + all_good = True + for key, value in metadata.state_dict_metadata.items(): + if isinstance(value, BytesStorageMetadata): + continue + if len(value.size) == 0: + continue + chunks_volume = 0 + for chunk_idx, chunk0 in enumerate(value.chunks): + # Compute the volume + if not _check_box_bounds(value.size, chunk0): + logger.warning( + """ + key:%s has out of bounds chunk: + tensor-size:%s chunk: %s + """, + key, + value.size, + chunk0, + ) + all_good = False + chunks_volume += reduce(operator.mul, chunk0.sizes, 1) + + # Check for overlap + for chunk1 in value.chunks[chunk_idx + 1 :]: + if _check_box_overlap(chunk0, chunk1): + logger.warning( + "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1 + ) + all_good = False + + # Check whether combined chunk cover the whole tensor + tensor_volume = reduce(operator.mul, value.size, 1) + if chunks_volume != tensor_volume: + logger.warning( + """ + key:%s invalid fill tensor-volume: + %s chunks-volume: %s + """, + key, + tensor_volume, + chunks_volume, + ) + all_good = False + + return all_good diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/filesystem.py 
b/phivenv/Lib/site-packages/torch/distributed/checkpoint/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba706ab19b615e21f495e2b44cc5b639b56edf0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/filesystem.py @@ -0,0 +1,976 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import io +import json +import operator +import os +import pickle +import queue +import threading +import uuid +import warnings +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from io import UnsupportedOperation +from pathlib import Path +from typing import Any, Callable, cast, IO, Optional, Union + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +import torch +from torch import Tensor +from torch._utils import _get_available_device_type, _get_device_module +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._extension import ( + ExtensionRegistry, + StreamTransformExtension, +) +from torch.distributed.checkpoint._hf_utils import ( + CUSTOM_METADATA_KEY, + DCP_VERSION_KEY, + FORMAT_KEY, + FORMAT_VALUE, + HF_DCP_VERSION, +) +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta +from torch.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.staging import BlockingAsyncStager +from torch.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) +from torch.distributed.checkpoint.utils import _create_file_view +from torch.futures import Future + + +__all__ = [ + "FileSystemWriter", + "FileSystemReader", + "FileSystem", + "FileSystemBase", + "SerializationFormat", +] + +_metadata_fn: str = ".metadata" + + +@dataclass +class _StorageInfo: + """This is the per entry storage info.""" + + relative_path: str + offset: int + length: int + transform_descriptors: Optional[Sequence[str]] = None + + def __getstate__(self): + return {k: v for k, v in self.__dict__.items() if v is not None} + + +@dataclass +class _StoragePrefix: + prefix: str + + +class SerializationFormat(Enum): + TORCH_SAVE = "torch_save" + SAFETENSORS = "safetensors" + + +DEFAULT_SUFFIX = ".distcp" + + +def _generate_uuid() -> str: + return str(uuid.uuid4()) + + +class _TensorLoader(ABC): + @abstractmethod + def add(self, size: int, obj: object) -> None: + pass + + @abstractmethod + def start_loading(self) -> None: + pass + + @abstractmethod + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + pass + + +class _SerialCpuLoader(_TensorLoader): + def __init__(self, resolve_fun: Callable) -> None: + self.resolve_fun = resolve_fun + self.items: list[tuple[int, object]] = [] + + def add(self, size: int, obj: object) -> None: + self.items.append((size, obj)) + + def start_loading(self) -> None: + pass + + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + for _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + tensor = tensor.cpu() + if tensor.storage().size() != tensor.numel(): + tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + def __init__( + self, + resolve_fun: Callable, + stream: Optional[torch.Stream] = None, + inflight_threshhold: int = 
1_000_000, + ) -> None: + self.resolve_fun = resolve_fun + self.items: list[tuple[int, object]] = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = ( + stream.device_type if stream else _get_available_device_type() + ) + self.device_module = _get_device_module(self.device_type) + self.stream = cast( + torch.cuda.Stream, stream or self.device_module.current_stream() + ) + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self) -> bool: + return self.idx >= len(self.items) + + def _drain(self) -> list[tuple[torch.Tensor, object]]: + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self) -> None: + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if tensor.device.type == self.device_type: + tensor = tensor.to(device="cpu", non_blocking=True) + elif tensor.device == torch.device("cpu"): + if ( + tensor.untyped_storage().size() + != tensor.numel() * tensor.itemsize + ): + # this forces the tensor to be both contiguous and with minimal storage + tensor = tensor.clone() + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self) -> Iterable[tuple[torch.Tensor, object]]: + assert self._done + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, size: int, obj: object) -> None: + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((size, obj)) + + def start_loading(self) -> None: + if self.started: + return + self.started = True + self.items.sort(key=operator.itemgetter(0)) + self._refill() + + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +class _StorageWriterTransforms: + """ + This is experimental, and will likely move elsewhere in the + future. It lives here to minimize changes while we are still + learning and gathering feedback. + """ + + def __init__( + self, extensions: Optional[Sequence[StreamTransformExtension]] = None + ) -> None: + """ + If the extensions arg is None, this means the implementation + should provide whatever defaults it chooses. An empty + sequence indicates no extensions should be used. At this + time, the default extensions sequence is empty. + """ + self.extensions = () if extensions is None else extensions + + def transform_save_stream( + self, write_item: WriteItem, raw_stream: io.IOBase + ) -> tuple[IO[bytes], list[str]]: + # In order to avoid leaking fds, transformers' close must + # cascade to wrapped streams, but since this function can + # append to the raw stream, we can't close the actual stream. + # So, we use this to put a wrapper around the raw stream's + # close() to make it a noop, and it gets closed once all files + # are appended. 
+ + class NoCloseWriter(io.IOBase): + def __init__(self, raw: io.IOBase): + self.raw = raw + + def writeable(self) -> bool: + return True + + def write(self, b: Buffer) -> int: + return self.raw.write(b) + + def close(self): + self.flush() + self.raw.flush() + # but not close. + + transform_to = cast(IO[bytes], NoCloseWriter(raw_stream)) + + for ex in self.extensions: + transform_to = ex.transform_to(transform_to) + + return (transform_to, [ex.get_descriptor() for ex in reversed(self.extensions)]) + + +def _item_size(item: WriteItem) -> int: + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + + buckets: list[list[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + # TODO replace with headq + idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item( + transforms: _StorageWriterTransforms, + stream: io.IOBase, + data: Union[io.BytesIO, torch.Tensor], + write_item: WriteItem, + storage_key: str, + serialization_format: SerializationFormat, +) -> WriteResult: + offset = stream.tell() + + (transform_to, transform_descriptors) = transforms.transform_save_stream( + write_item, stream + ) + + if write_item.type == WriteItemType.BYTE_IO: + assert isinstance(data, io.BytesIO) + transform_to.write(data.getbuffer()) + else: + assert isinstance(data, torch.Tensor) + assert data.device == torch.device("cpu") + if serialization_format == SerializationFormat.TORCH_SAVE: + torch.save(data, transform_to) + + transform_to.close() + + if serialization_format == SerializationFormat.TORCH_SAVE or isinstance( + data, io.BytesIO + ): + length = stream.tell() - offset + else: + length = data.numel() * data.element_size() + + # For consistency with earlier versions, leave this field out of the + # metadata if there are no extensions. + info_transform_descriptors = ( + None if len(transform_descriptors) == 0 else transform_descriptors + ) + + return WriteResult( + index=write_item.index, + size_in_bytes=length, + storage_data=_StorageInfo( + storage_key, + offset, + length, + transform_descriptors=info_transform_descriptors, + ), + ) + + +def _write_files_from_queue( + create_stream: Callable, + file_queue: queue.Queue, + result_queue: queue.Queue, + planner: SavePlanner, + transforms: _StorageWriterTransforms, + inflight_threshhold: int, + use_fsync: bool, + thread_count: int, + serialization_format: SerializationFormat, +) -> None: + try: + while True: + file_name, storage_key, write_items = file_queue.get_nowait() + loader: _TensorLoader + + custom_backend_name = torch._C._get_privateuse1_backend_name() + custom_device_mod = getattr(torch, custom_backend_name, None) + + # TODO: Using the OverlappingCpuLoader with multiple threads creates significant + # performance degradation, observed as being related to cuda stream syncs. 
We + # should try to fix this and use _OverlappingCpuLoader for all threaded cases + if ( + thread_count == 1 + and ( + torch.cuda.is_available() + or (custom_device_mod and custom_device_mod.is_available()) + ) + and inflight_threshhold > 0 + ): + loader = _OverlappingCpuLoader( + planner.resolve_data, + inflight_threshhold=inflight_threshhold, + ) + else: + loader = _SerialCpuLoader( + planner.resolve_data, + ) + + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + for write_item in tensor_w: + loader.add(_item_size(write_item), write_item) + loader.start_loading() + + bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + write_results = [] + + with create_stream(file_name, "wb") as stream: + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append( + _write_item( + transforms, + stream, + data, + write_item, + storage_key, + serialization_format, + ) + ) + + tensor_dict = {} + metadata_dict = {} + for tensor, write_item in loader.values(): + assert tensor.is_cpu + write_results.append( + _write_item( + transforms, + stream, + tensor, + write_item, # type: ignore[arg-type] + storage_key, + serialization_format, + ) + ) + tensor_dict[write_item.index.fqn] = tensor # type: ignore[attr-defined] + metadata_dict[write_item.index.fqn] = { # type: ignore[attr-defined] + "saved_offsets": write_item.tensor_data.chunk.offsets # type: ignore[attr-defined] + } + + if serialization_format == SerializationFormat.SAFETENSORS: + from safetensors.torch import save # type: ignore[import-not-found] + + stream.write( + save( + tensor_dict, + metadata={ + CUSTOM_METADATA_KEY: json.dumps(metadata_dict), + DCP_VERSION_KEY: str(HF_DCP_VERSION), + FORMAT_KEY: FORMAT_VALUE, + }, + ) + ) + + if use_fsync: + try: + os.fsync(stream.fileno()) + except (AttributeError, UnsupportedOperation): + os.sync() + stream.close() + result_queue.put(write_results) + except queue.Empty: + pass + + +class FileSystemBase(ABC): + @contextmanager + @abstractmethod + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: ... + + @abstractmethod + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: ... + + @abstractmethod + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: ... + + @abstractmethod + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ... + + @abstractmethod + def mkdir(self, path: Union[str, os.PathLike]) -> None: ... + + @classmethod + @abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ... + + @abstractmethod + def exists(self, path: Union[str, os.PathLike]) -> bool: ... + + @abstractmethod + def rm_file(self, path: Union[str, os.PathLike]) -> None: ... 
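+
+# ``FileSystem`` below is the concrete local-disk implementation of
+# ``FileSystemBase``; it is the back-end instantiated by ``_FileSystemWriter``
+# and ``FileSystemReader`` in this module.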
+ + +class FileSystem(FileSystemBase): + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + if not isinstance(path, Path): + path = Path(path) + with path.open(mode) as stream: + yield cast(io.IOBase, stream) + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path / suffix + + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + if not isinstance(path, Path): + path = Path(path) + + path.rename(cast(Path, new_path)) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + if not isinstance(path, Path): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return True + + if "://" in str(checkpoint_id): + return False + + for p in Path(checkpoint_id).parents: + if p.exists() and os.access(str(p), os.W_OK): + return True + + return False + + def exists(self, path: Union[str, os.PathLike]) -> bool: + if not isinstance(path, Path): + path = Path(path) + return path.exists() + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + if not isinstance(path, Path): + path = Path(path) + path.unlink() + + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + if not isinstance(path, Path): + path = Path(path) + return [str(p) for p in path.iterdir()] + + +class _FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. 
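+
+        Illustrative usage sketch (the path is a placeholder; callers normally
+        construct the public ``FileSystemWriter`` subclass rather than this class):
+
+        >>> # xdoctest: +SKIP("undefined vars")
+        >>> writer = FileSystemWriter("/tmp/checkpoint", thread_count=4)
+        >>> dcp.save(state_dict, storage_writer=writer)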
+ """ + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.single_file_per_rank = single_file_per_rank + self.sync_files = sync_files + self.thread_count = thread_count + self.per_thread_copy_ahead = per_thread_copy_ahead + self.save_id = _generate_uuid() + self.overwrite = overwrite + self.transforms = _StorageWriterTransforms(_extensions) + self.serialization_format = serialization_format + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.save_id = _generate_uuid() + + def set_up_storage_writer(self, is_coordinator: bool) -> None: + pass + + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + self.fs.mkdir(self.path) + if self.fs.exists(self.metadata_path): + if self.overwrite: + warnings.warn( + f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}." + " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" + " maintain this functionality or False to raise when an existing checkpoint is found." + ) + else: + raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") + + return plan + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) + for i, plan in enumerate(plans) + ] + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + file_queue: queue.Queue = queue.Queue() + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.thread_count, plan.items): + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, bucket)) + else: + for item in plan.items: + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, [item])) + + return self._write_data(planner, file_queue) + + def _write_data( + self, + planner: SavePlanner, + file_queue: queue.Queue, + ) -> Future[list[WriteResult]]: + result_queue: queue.Queue = queue.Queue() + + threads = [] + for _ in range(1, self.thread_count): + t = threading.Thread( + target=_write_files_from_queue, + args=( + self.fs.create_stream, + file_queue, + result_queue, + planner, + self.transforms, + self.per_thread_copy_ahead, + self.sync_files, + self.thread_count, + self.serialization_format, + ), + ) + t.start() + threads.append(t) + + _write_files_from_queue( + create_stream=self.fs.create_stream, + file_queue=file_queue, + result_queue=result_queue, + planner=planner, + transforms=self.transforms, + inflight_threshhold=self.per_thread_copy_ahead, + use_fsync=self.sync_files, + thread_count=self.thread_count, + serialization_format=self.serialization_format, + ) + + for t in threads: + t.join() + + res = [] + try: + while True: + res += result_queue.get_nowait() + except queue.Empty: + fut: Future[list[WriteResult]] = Future() + fut.set_result(res) + return fut + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + storage_md = {} + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + + 
metadata.storage_meta = self.storage_meta() + + tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp")) + with self.fs.create_stream(tmp_path, "wb") as metadata_file: + pickle.dump(metadata, metadata_file) + if self.sync_files: + try: + os.fsync(metadata_file.fileno()) + except (AttributeError, UnsupportedOperation): + os.sync() + + # delete in-case other checkpoints were present. + if self.fs.exists(self.metadata_path): + self.fs.rm_file(self.metadata_path) + + self.fs.rename(tmp_path, self.metadata_path) + + def storage_meta(self) -> Optional[StorageMeta]: + return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id) + + @property + def metadata_path(self) -> Union[str, os.PathLike]: + return cast(Path, self.fs.concat_path(self.path, _metadata_fn)) + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to save the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class _StorageReaderTransforms: + """ + This is experimental, and will likely move elsewhere in the + future. It lives here to minimize changes while we are still + learning and gathering feedback. + """ + + def __init__(self, extension_registry: Optional[ExtensionRegistry] = None) -> None: + self.extension_registry = ( + ExtensionRegistry() if extension_registry is None else extension_registry + ) + + def transform_load_stream( + self, + read_item: ReadItem, + transform_descriptors: Sequence[str], + raw_stream: IO[bytes], + ) -> IO[bytes]: + extensions = self.extension_registry.from_descriptor_list(transform_descriptors) + transform_from = raw_stream + for ex in extensions: + if isinstance(ex, StreamTransformExtension): + transform_from = ex.transform_from(transform_from) + return transform_from + + +class FileSystemReader(StorageReader): + def __init__( + self, + path: Union[str, os.PathLike], + _extension_registry: Optional[ExtensionRegistry] = None, # EXPERIMENTAL + ) -> None: + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.storage_data: dict[Any, Any] = {} + self.load_id = _generate_uuid() + self.transforms = _StorageReaderTransforms(_extension_registry) + + def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]: + return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length)) + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + self.storage_data = {} + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.load_id = _generate_uuid() + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + # group requests by file + per_file: dict[str, list[ReadItem]] = {} + for read_item in plan.items: + item_md: _StorageInfo = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + for relative_path, reqs in per_file.items(): + new_path = self.fs.concat_path(self.path, relative_path) + with self.fs.create_stream(new_path, "rb") as stream: + # TODO sort by offset and cache the reading + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(stream, item_md) + transform_from = self.transforms.transform_load_stream( + req, + # This field wasn't present in older + # implementations so provide a fallback. 
+ item_md.transform_descriptors or (), + file_slice, + ) + + if req.type == LoadItemType.BYTE_IO: + read_bytes = io.BytesIO(transform_from.read(-1)) + read_bytes.seek(0) + planner.load_bytes(req, read_bytes) + else: + if transform_from.seekable(): + seekable = transform_from + else: + # torch.load requires a seekable input, so read the transform + # stream now and store the output if needed + seekable = io.BytesIO(transform_from.read(-1)) + seekable.seek(0) + + tensor = cast( + Tensor, + torch.load( + seekable, + map_location="cpu", + weights_only=True, + ), + ) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + # Implementing the abstract function in StorageReader + def read_metadata(self) -> Metadata: + path = self.fs.concat_path(self.path, ".metadata") + with self.fs.create_stream(path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id + + return metadata + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + self.storage_data = metadata.storage_data + assert self.storage_data is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]: + return plans + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to load the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + cache_staged_state_dict: bool = False, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. 
Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + _FileSystemWriter.__init__( + self, + path=path, + single_file_per_rank=single_file_per_rank, + sync_files=sync_files, + thread_count=thread_count, + per_thread_copy_ahead=per_thread_copy_ahead, + overwrite=overwrite, + _extensions=_extensions, + serialization_format=serialization_format, + ) + BlockingAsyncStager.__init__( + self, + cache_staged_state_dict=cache_staged_state_dict, + ) + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Override of AsyncStager.stage""" + # in the async case, the state dict is already on CPU, so maintaining this + # buffer makes no sense + self.per_thread_copy_ahead = 0 + return super().stage(state_dict) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/format_utils.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/format_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..601c8d882178358a6bbac8c00952cc72132c8ab9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/format_utils.py @@ -0,0 +1,280 @@ +# mypy: allow-untyped-defs +import argparse +import os +from enum import Enum +from typing import cast, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint._nested_dict import flatten_state_dict +from torch.distributed.checkpoint.default_planner import ( + _EmptyStateDictLoadPlanner, + DefaultLoadPlanner, +) +from torch.distributed.checkpoint.metadata import ( + Metadata, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner +from torch.distributed.checkpoint.planner_helpers import _create_chunk_list +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict +from torch.distributed.checkpoint.state_dict_saver import _save_state_dict +from torch.distributed.checkpoint.storage import StorageReader +from torch.futures import Future + + +__all__ = [ + "dcp_to_torch_save", + "torch_save_to_dcp", + "BroadcastingTorchSaveReader", + "DynamicMetaLoadPlanner", +] + + +class BroadcastingTorchSaveReader(StorageReader): + """ + StorageReader for reading a Torch Save file. This reader will read the entire checkpoint + on the coordinator rank, and then broadcast and shard each tensor to all ranks. + + . N.B. Intended to be used with DynamicMetaLoadPlanner + + .. warning:: + Current implementation only supports loading Tensors. 
+ + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def __init__( + self, + checkpoint_id: Optional[Union[str, os.PathLike]] = None, + coordinator_rank: int = 0, + ) -> None: + self.checkpoint_id = checkpoint_id + self.coordinator_rank = coordinator_rank + + def read_metadata(self) -> Metadata: + """Extends the default StorageReader to support building the metadata file""" + # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from + # the disk + return Metadata(state_dict_metadata={}) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + """ + Reads torch save data on the coordinator rank, and broadcast afterwards + this incurrs a communication cost, but avoids having to load + the entire checkpoint on each rank, hopefully preventing OOM issues + """ + planner = cast(DefaultLoadPlanner, planner) + + # data is read in on the coordinator rank, and broadcast afterwards + # this incurrs a communication cost, but it avoids having to load + # the entire checkpoint on each rank, hopefully preventing OOM issues + # TODO: read on each host, instead of only the coordinator + if self.is_coordinator: + assert self.checkpoint_id is not None + torch_state_dict = torch.load( + self.checkpoint_id, map_location="cpu", weights_only=False + ) + if planner.flatten_state_dict: + torch_state_dict, _ = flatten_state_dict(torch_state_dict) + else: + torch_state_dict = None + + for req in plan.items: + if req.type == LoadItemType.BYTE_IO: + raise RuntimeError( + f"Non-tensor value identified at {req.storage_index.fqn}. " + f"At this time {type(self).__name__} only supports loading Tensors." 
+ ) + + # Broadcast the tensor from the coordinator rank + if self.is_coordinator: + pg_device = dist.distributed_c10d._get_pg_default_device() + tensor = torch_state_dict[req.storage_index.fqn].to(pg_device) + else: + tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn]) + + dist.broadcast(tensor, src=self.coordinator_rank, async_op=False) + + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes, " + f"{target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + """Implementation of the StorageReader method""" + self.is_coordinator = is_coordinator + if self.is_coordinator: + assert dist.get_rank() == self.coordinator_rank + + assert self.checkpoint_id is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """Implementation of the StorageReader method""" + return plan + + def prepare_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + """Implementation of the StorageReader method""" + return global_plan + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """Implementation of the StorageReader method""" + self.checkpoint_id = checkpoint_id + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """Implementation of the StorageReader method""" + return os.path.isfile(checkpoint_id) + + +class DynamicMetaLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, + avoiding the need to read metadata from disk. This is useful when reading formats which don't have a + metadata file, like Torch Save files. + + . N.B. Intended to be used with BroadcastingTorchSaveReader + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict""" + super().set_up_planner(state_dict, metadata, is_coordinator) + + state_dict_metadata: dict[str, STORAGE_TYPES] = {} + for key, tensor in self.state_dict.items(): + if not torch.is_tensor(tensor): + raise RuntimeError( + f"Non-tensor value identified at {key}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + state_dict_metadata[key] = TensorStorageMetadata( + TensorProperties(dtype=tensor.dtype), + tensor.size(), + _create_chunk_list(tensor), + ) + self.metadata = Metadata(state_dict_metadata=state_dict_metadata) + + +def dcp_to_torch_save( + dcp_checkpoint_dir: Union[str, os.PathLike], + torch_save_path: Union[str, os.PathLike], +): + """ + Given a directory containing a DCP checkpoint, this function will convert it into a + Torch save file. + + Args: + dcp_checkpoint_dir: Directory containing the DCP checkpoint. 
+ torch_save_path: Filename to store the converted Torch save file. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + sd: STATE_DICT_TYPE = {} + _load_state_dict( + sd, + storage_reader=FileSystemReader(dcp_checkpoint_dir), + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + torch.save(sd, torch_save_path) + + +def torch_save_to_dcp( + torch_save_path: Union[str, os.PathLike], + dcp_checkpoint_dir: Union[str, os.PathLike], +): + """ + Given the location of a torch save file, converts it into a DCP checkpoint. + + Args: + torch_save_path: Filename of the Torch save file. + dcp_checkpoint_dir: Directory to store the DCP checkpoint. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + + state_dict = torch.load(torch_save_path, weights_only=False) + # we don't need stateful behavior here because the expectation is anything loaded by + # torch.load would not contain stateful objects. + _save_state_dict( + state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True + ) + + +if __name__ == "__main__": + + class FormatMode(Enum): + TORCH_TO_DCP = "torch_to_dcp" + DCP_TO_TORCH = "dcp_to_torch" + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument( + "mode", + type=str, + help="Conversion mode", + choices=[m.value for m in FormatMode], + default=FormatMode.TORCH_TO_DCP, + ) + parser.add_argument("src", type=str, help="Path to the source model") + parser.add_argument("dst", type=str, help="Path to the destination model") + args = parser.parse_args() + + print( + f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'" + ) + checkpoint_missing_warning = ( + f"No checkpoint found at {args.src}. Skipping conversion." 
+ ) + if args.mode == FormatMode.TORCH_TO_DCP.value: + if os.path.isfile(args.src): + torch_save_to_dcp(args.src, args.dst) + else: + print(checkpoint_missing_warning) + elif args.mode == FormatMode.DCP_TO_TORCH.value: + if os.path.isdir(args.src): + dcp_to_torch_save(args.src, args.dst) + else: + print(checkpoint_missing_warning) + else: + raise ValueError(f"Unknown conversion mode: {args.mode}") diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/hf_storage.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/hf_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..b49410fe3d8627a41aca0b570c9c9d025523aabc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/hf_storage.py @@ -0,0 +1,334 @@ +# mypy: allow-untyped-defs +import dataclasses +import json +import queue +from typing import Any, Optional + +import torch +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _get_dtype, + _get_safetensors_file_metadata, + _HFStorageInfo, + _metadata_fn, + CUSTOM_METADATA_KEY, + DATA_KEY, + DATA_OFFSETS_KEY, + DEFAULT_EXTRA_METADATA_KEY, + DTYPE_KEY, + SAVED_OFFSETS_KEY, + SHAPE_KEY, + SUFFIX, +) +from torch.distributed.checkpoint.filesystem import SerializationFormat +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + StorageMeta, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, +) +from torch.distributed.checkpoint.storage import WriteResult +from torch.futures import Future + + +__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] + + +class HuggingFaceStorageWriter(FsspecWriter): + """ + A writer that writes to a huggingface repository in the huggingface format. + Uses Fsspec back-end to communicate with back-end storage. + Fsspec registration of the storage solution is required. + """ + + def __init__( + self, + path: str, + fqn_to_index_mapping: Optional[dict[str, int]] = None, + token: Optional[str] = None, + save_sharded: bool = False, + ) -> None: + """ + Initialize the huggingface writer pointing to path. + + Args: + path: hf directory where the checkpoint will be read from. + Needs to have .safetensors files, but can be from any fsspec supported storage, + including localFS and hf://. + fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. + Indices are from 1 to N, where N is the number of files. If not provided, + the tensors will be written to a single file. If none, then all the tensors on the + same rank will be written to the same file. + token: The token to use to authenticate with huggingface hub. + save_sharded: If True, save the checkpoint as a sharded checkpoint where every rank saves its own shard. + Default is False which assumes full tensors are being saved. 
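+
+        Illustrative usage sketch (the repository path and the index mapping
+        are placeholders):
+
+        >>> # xdoctest: +SKIP("undefined vars")
+        >>> writer = HuggingFaceStorageWriter(
+        >>>     path="hf://username/my-model",
+        >>>     fqn_to_index_mapping={"model.weight": 1, "model.bias": 2},
+        >>> )
+        >>> dcp.save(state_dict, storage_writer=writer)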
+ + """ + + if token is not None: + super().__init__( + path=path, + token=token, + serialization_format=SerializationFormat.SAFETENSORS, + ) + else: + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + ) + self._fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping + self._save_sharded = save_sharded + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [] + for i, plan in enumerate(plans, start=1): + storage_data: dict[str, Any] = {} + if self._fqn_to_index_mapping is not None: + storage_data["fqn_to_index_mapping"] = self._fqn_to_index_mapping + if self._save_sharded: + storage_data["shard_index"] = i + + new_plans.append(dataclasses.replace(plan, storage_data=storage_data)) + + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + if len(plan.items) == 0: + fut: Future = Future() + fut.set_result([]) + return fut + + # storage_plan is a map from key to file index + storage_data: dict[str, Any] = plan.storage_data + storage_plan: Optional[dict[str, int]] = None + shard_index: Optional[int] = None + if "fqn_to_index_mapping" in storage_data: + storage_plan = storage_data["fqn_to_index_mapping"] + if "shard_index" in storage_data: + shard_index = storage_data["shard_index"] + + buckets = self._split_by_storage_plan(storage_plan, plan.items) + highest_index = max(storage_plan.values()) if storage_plan is not None else 1 + + file_queue: queue.Queue = queue.Queue() + for file_index, write_items in buckets.items(): + file_name = _gen_file_name(file_index, highest_index, shard_index) + file_queue.put( + (self.fs.concat_path(self.path, file_name), file_name, write_items) + ) + + return super()._write_data(planner, file_queue) + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + if self._save_sharded: + return + + metadata_to_write = {} + storage_md = {} + total_size = 0 + for wr_list in results: + storage_md.update( + {wr.index.fqn: wr.storage_data.relative_path for wr in wr_list} + ) + total_size += sum([wr.storage_data.length for wr in wr_list]) + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = storage_md + + metadata_path = self.fs.concat_path(self.path, f"{_metadata_fn}") + with self.fs.create_stream(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + def _split_by_storage_plan( + self, storage_plan: Optional[dict[str, int]], items: list[WriteItem] + ) -> dict[int, list[WriteItem]]: + # storage_plan is a map from key to index + if storage_plan is None: + return {1: items} + + buckets = {} + for item in items: + key = item.index.fqn + + idx = storage_plan[key] + if idx not in buckets: + buckets[idx] = [item] + else: + buckets[idx].append(item) + + return buckets + + @property + def metadata_path(self) -> str: + return _metadata_fn + + +class HuggingFaceStorageReader(FsspecReader): + """ + A reader that reads from a huggingface repository in the huggingface format. + Uses in Fsspec back-end to communicate with storage. + Fsspec registration of the storage solution is required. + """ + + def __init__(self, path: str, token: Optional[str] = None) -> None: + """ + Initialize the huggingface reader pointing to path. + + Args: + path: hf directory where the checkpoint will be read from. + Needs to have .safetensors file, but can be from any fsspec supported storage, + including localFS and hf://. 
+ token: The token to use to authenticate with huggingface hub. + """ + + if token is not None: + super().__init__(path=path, token=token) + else: + super().__init__(path=path) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import deserialize # type: ignore[import-not-found] + + per_file: dict[str, list[ReadItem]] = {} + + for read_item in plan.items: + item_md: _HFStorageInfo = self.storage_data[read_item.storage_index] + file_name = item_md.relative_path + per_file.setdefault(file_name, []).append(read_item) + + for file_name, reqs in per_file.items(): + with self.fs.create_stream(file_name, "rb") as stream: + # TODO: make this more efficient by doing offset reads instead of a + # full deserialization of the file + deserialized = deserialize(stream.read()) + deserialized_dict: dict[str, dict[str, Any]] = { + tensor_info[0]: tensor_info[1] for tensor_info in deserialized + } + + for req in reqs: + item_md = self.storage_data[req.storage_index] + + tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY] + + tensor = torch.frombuffer( + tensor_bytes, + dtype=item_md.dtype, + ) + tensor = tensor.reshape(item_md.shape) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + def read_metadata(self) -> Metadata: + state_dict_metadata: dict[str, TensorStorageMetadata] = {} + storage_data: dict[MetadataIndex, _HFStorageInfo] = {} + + safetensors_files = [] + for file in self.fs.ls(self.path): + if file.endswith(SUFFIX): + safetensors_files.append(file) + + for safetensor_file in safetensors_files: + with self.fs.create_stream(safetensor_file, "rb") as f: + safetensors_metadata, _ = _get_safetensors_file_metadata(f) + custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) + + dcp_sharding_info = None + if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY): + dcp_sharding_info = json.loads( + custom_metadata.get(CUSTOM_METADATA_KEY) + ) + + for key, val in safetensors_metadata.items(): + if key == DEFAULT_EXTRA_METADATA_KEY: + continue + + # construct state_dict_metadata + if dcp_sharding_info is not None: + offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + else: + offset = [0] * len(val[SHAPE_KEY]) + + if key not in state_dict_metadata: + state_dict_metadata[key] = TensorStorageMetadata( + properties=TensorProperties( + dtype=_get_dtype(val[DTYPE_KEY]) + ), + size=torch.Size( + [ + saved + offset + for saved, offset in zip(val[SHAPE_KEY], offset) + ] + ), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=torch.Size(val[SHAPE_KEY]), + ) + ], + ) + else: + state_dict_metadata[key].chunks.append( + ChunkStorageMetadata( + torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]) + ) + ) + size = list(state_dict_metadata[key].size) + for i in range(len(size)): + size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i]) + state_dict_metadata[key].size = torch.Size(size) + + # construct storage data + if dcp_sharding_info is not None: + metadata_index = MetadataIndex( + fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] + ) + else: + metadata_index = MetadataIndex( + fqn=key, offset=[0] * len(val[SHAPE_KEY]) + ) + 
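+                    # Map this (fqn, offset) pair to the safetensors file that holds the
+                    # shard, together with its byte range, shape and dtype, so read_data
+                    # can locate and reconstruct the tensor at load time.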
storage_data[metadata_index] = _HFStorageInfo( + relative_path=safetensor_file, + offset=val[DATA_OFFSETS_KEY][0], + length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], + shape=torch.Size(val[SHAPE_KEY]), + dtype=_get_dtype(val[DTYPE_KEY]), + ) + + metadata = Metadata( + state_dict_metadata=state_dict_metadata, # type: ignore[arg-type] + storage_data=storage_data, + ) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id # type: ignore[union-attr] + + return metadata diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/logger.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..848fcb83983e594f4ac256a0b1c34e919385eeca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/logger.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +import functools +import logging +import time +from typing import Any, Callable, TypeVar +from typing_extensions import ParamSpec +from uuid import uuid4 + +import torch.distributed.c10d_logger as c10d_logger +from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME + + +logger = logging.getLogger() + + +__all__: list[str] = [] + +global _dcp_logger +_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]: + """ + Extracts log data from dcp method args + """ + msg_dict = {} + + # checkpoint ID can be passed in through the serializer or through the checkpoint id directly + storage_writer = kwargs.get("storage_writer", None) + storage_reader = kwargs.get("storage_reader", None) + planner = kwargs.get("planner", None) + + checkpoint_id = kwargs.get("checkpoint_id", None) + if not checkpoint_id and (serializer := storage_writer or storage_reader): + checkpoint_id = getattr(serializer, "checkpoint_id", None) + + msg_dict["checkpoint_id"] = ( + str(checkpoint_id) if checkpoint_id is not None else checkpoint_id + ) + + # Uniquely identify a _dcp_method_logger wrapped function call. 
+ msg_dict["uuid"] = str(uuid4().int) + + if storage_writer: + msg_dict["storage_writer"] = storage_writer.__class__.__name__ + + if storage_reader: + msg_dict["storage_reader"] = storage_reader.__class__.__name__ + + if planner: + msg_dict["planner"] = planner.__class__.__name__ + + return msg_dict + + +def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]: + msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs) + msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs)) + + return msg_dict + + +def _dcp_method_logger( + log_exceptions: bool = False, **wrapper_kwargs: Any +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore + """This method decorator logs the start, end, and exception of wrapped events.""" + + def decorator(func: Callable[_P, _T]): + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + msg_dict = _get_msg_dict( + func.__name__, *args, **{**wrapper_kwargs, **kwargs} + ) + + # log start event + msg_dict["event"] = "start" + t0 = time.time_ns() + msg_dict["time"] = t0 + msg_dict["log_exceptions"] = log_exceptions + _dcp_logger.debug(msg_dict) + + # exceptions + try: + result = func(*args, **kwargs) + except BaseException as error: + if log_exceptions: + msg_dict["event"] = "exception" + msg_dict["error"] = f"{error}" + msg_dict["time"] = time.time_ns() + _dcp_logger.error(msg_dict) + raise + + # end event + msg_dict["event"] = "end" + t1 = time.time_ns() + msg_dict["time"] = time.time_ns() + msg_dict["times_spent"] = t1 - t0 + _dcp_logger.debug(msg_dict) + + return result + + return wrapper + + return decorator + + +def _init_logger(rank: int): + logger.setLevel(logging.INFO) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter( + f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + logger.addHandler(ch) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/logging_handlers.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3833d1d5b1a760604e701d4bf24627aac12eae2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/logging_handlers.py @@ -0,0 +1,14 @@ +import logging + +from torch.distributed.logging_handlers import _log_handlers + + +__all__: list[str] = [] + +DCP_LOGGER_NAME = "dcp_logger" + +_log_handlers.update( + { + DCP_LOGGER_NAME: logging.NullHandler(), + } +) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/metadata.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f2a19a14025bffbf4b0035fe9ffde8583770ae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/metadata.py @@ -0,0 +1,184 @@ +# mypy: allow-untyped-defs +import os +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional, Union + +import torch +from torch.distributed.checkpoint.stateful import StatefulT + + +__all__ = [ + "ChunkStorageMetadata", + "TensorStorageMetadata", + "BytesStorageMetadata", + "Metadata", + "MetadataIndex", + "TensorProperties", + "StorageMeta", +] + + +@dataclass +class ChunkStorageMetadata: + """ + Each chunk is expected to have the same properties of the TensorStorageMetadata + that includes it. 
+ """ + + offsets: torch.Size + sizes: torch.Size + + +class _MEM_FORMAT_ENCODING(Enum): + """Describe the memory format of a tensor.""" + + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: torch.dtype = field(default_factory=torch.get_default_dtype) + # This field is deprecated. + layout: torch.layout = field(default=torch.strided) + # This field is deprecated. + requires_grad: bool = False + # This field is deprecated. + memory_format: torch.memory_format = field(default=torch.contiguous_format) + # This field is deprecated. + pin_memory: bool = False + + def __getstate__(self): + # Since torch.memory_format cannot be pickled! + memory_format = self.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + else: + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") + + return ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state + + if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format + else: + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) + + self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class TensorStorageMetadata: + properties: TensorProperties + size: torch.Size + chunks: list[ChunkStorageMetadata] + + +@dataclass +class BytesStorageMetadata: + pass + + +STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata] +STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]] + + +@dataclass +class StorageMeta: + checkpoint_id: Union[str, os.PathLike, None] = None + save_id: Optional[str] = None + load_id: Optional[str] = None + modules: list[str] = field(default_factory=list) + + +@dataclass +class Metadata: + """This class represents the metadata of the checkpoint.""" + + # Keys are the same from the `state_dict` used. + state_dict_metadata: dict[str, STORAGE_TYPES] + # It is the responsibility of the planner and storage plugins to ensure + # backward compatibility of the planner_data and storage_data. DCP will + # also ensure the backward compatibility of the metadata in this file and + # the metadata of the built-in planner and storage plugins. 
+ planner_data: Any = None + storage_data: Any = None + storage_meta: Optional[StorageMeta] = None + + +@dataclass(frozen=True) +class MetadataIndex: + """This class represents a lookup key for items in a state dict or Metadata.""" + + fqn: str + """Fully Qualified Name of the object""" + + offset: Optional[torch.Size] = None + """If the object is a tensor, offset into the tensor we're looking for""" + + index: Optional[int] = field(hash=False, compare=False, default=None) + """ + Index hint when searching for tensor chunk to speedup lookups (optional) + + A common representation of a sharded tensor is as a list of chunks so to + find the index in such a list you need to linear search it. + + When constructing an instance of MetadataIndex that points to that list, + one can provide the index as a hint and it will be probed first before + the linear search and thus making it significantly faster. + """ + + def __init__( + self, + fqn: str, + offset: Optional[Sequence[int]] = None, + index: Optional[int] = None, + ): + # We must use object.__setattr__ due to frozen=True + object.__setattr__(self, "fqn", fqn) + object.__setattr__(self, "index", index) + if offset is not None: + object.__setattr__(self, "offset", torch.Size(offset)) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/optimizer.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc58e5dc836d27cabc09b8ce7db4355926cacd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/optimizer.py @@ -0,0 +1,357 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +from collections.abc import Sequence +from typing import cast, Optional, Union + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._shard.sharded_tensor.metadata import ( + TensorProperties as ShardTensorProperties, +) +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec +from torch.distributed.checkpoint._nested_dict import unflatten_state_dict +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner +from torch.distributed.checkpoint.planner_helpers import ( + _create_read_items, + create_read_items_for_chunk_list, +) +from torch.distributed.checkpoint.state_dict_loader import load_state_dict +from torch.distributed.checkpoint.storage import StorageReader +from torch.distributed.checkpoint.utils import ( + _element_wise_add, + _element_wise_sub, + _normalize_device_info, +) +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DTensor + + +STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]] + + +# TODO: Update docstrings for optimizer.py +__all__ = [ + "load_sharded_optimizer_state_dict", +] + + +def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: + if device_type == "cpu": + 
return "cpu" + device_module = _get_device_module(device_type) + if device_module.is_available(): + return _normalize_device_info( + device_type, global_rank % device_module.device_count() + ) + return "cpu" + + +def _create_colwise_spec( + pg: Optional[dist.ProcessGroup] = None, +) -> ChunkShardingSpec: + pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type + if pg is None: + placements = [ + f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}" + for idx in range(dist.get_world_size()) + ] + else: + placements = [ + f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}" + for idx in range(pg.size()) + ] + return ChunkShardingSpec( + dim=0, + placements=cast(list[Union[_remote_device, str]], placements), + ) + + +def _is_nested_tensor(val: torch.Tensor) -> bool: + if type(val) is ShardedTensor: + if len(val.local_shards()) == 0: + return False + if type(val.local_shards()[0].tensor) is ShardedTensor: + return True + if type(val.local_shards()[0].tensor) is DTensor: + raise ValueError("Cannot handle DTensor nested inside ShardedTensor") + elif type(val) is DTensor and ( + type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor + ): + raise ValueError("Cannot handle nested DTensor") + return False + + +def _alloc_tensor( + props: TensorProperties, size: Sequence[int], device_type: str = "cuda" +) -> torch.Tensor: + if device_type == "cpu": + device = cast(torch.device, _get_device_module(device_type).current_device()) + else: + device = torch.device( + device_type, _get_device_module(device_type).current_device() + ) + + return torch.empty( + size=size, + dtype=props.dtype, + layout=props.layout, + requires_grad=props.requires_grad, + pin_memory=props.pin_memory, + device=device, + ) + + +def _get_state_dict_2d_layout( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]: + """ + Load the right TP slice of the optimizer state. + + This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata. + We take advantage of the model state_dict producing a sliced ST to figure out what we need to load. + This is pretty fragile and it might be easier for FSDP to compute this info for us. + Returns a dictionary where keys are the same of the state_dict and the value is a tuple of + (offset, size) for the current rank TP slice. + N.B. The state_dict *MUST* come from FSDP.sharded_state_dict. 
+ """ + specs: STATE_DICT_2D_LAYOUT = {} + dp_pg: Optional[dist.ProcessGroup] = None + for key, value in state_dict.items(): + specs[key] = (None, value.size()) + if _is_nested_tensor(value): + assert len(value.local_shards()) == 1, ( + "Cannot handle ST with multiple shards" + ) + assert isinstance(value, ShardedTensor), ( + "Can only handle nested ShardedTensor" + ) + shard = value.local_shards()[0] + specs[key] = ( + shard.metadata.shard_offsets, + shard.metadata.shard_sizes, + ) + dp_pg = shard.tensor._process_group # type: ignore[attr-defined] + + return ( + specs, + dp_pg, + ) + + +class _ReaderWithOffset(DefaultLoadPlanner): + translation: dict[MetadataIndex, MetadataIndex] + state_dict: STATE_DICT_TYPE + metadata: Metadata + + def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None: + super().__init__() + self.fqn_to_offset = fqn_to_offset + self.metadata = Metadata({}) + self.state_dict = {} + self.translation = {} + + def create_local_plan(self) -> LoadPlan: + requests = [] + self.translation = {} + for fqn, obj in self.state_dict.items(): + md = self.metadata.state_dict_metadata[fqn] + if not isinstance(obj, ShardedTensor): + requests += _create_read_items(fqn, md, obj) + continue + + if fqn not in self.fqn_to_offset: + requests += _create_read_items(fqn, md, obj) + continue + + offset = self.fqn_to_offset[fqn] + + assert len(obj.local_shards()) == 1 + original_shard = obj.local_shards()[0] + local_chunks = [ + ChunkStorageMetadata( + offsets=torch.Size( + _element_wise_add(original_shard.metadata.shard_offsets, offset) + ), + sizes=torch.Size(original_shard.metadata.shard_sizes), + ) + ] + + reqs = create_read_items_for_chunk_list( + fqn, cast(TensorStorageMetadata, md), local_chunks + ) + # TODO: The ReadItems will have a displaced MetadataIndex, fix it. + # TODO: we should change _create_sharded_read_items to have more ergonomic API + for ri in reqs: + assert ri.dest_index.offset is not None + original_offset = _element_wise_sub(ri.dest_index.offset, offset) + original_index = dataclasses.replace( + ri.dest_index, offset=torch.Size(original_offset) + ) + self.translation[ri.dest_index] = original_index + + requests += reqs + return LoadPlan(requests) + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + return super().lookup_tensor(self.translation.get(index, index)) + + +def load_sharded_optimizer_state_dict( + model_state_dict: STATE_DICT_TYPE, + optimizer_key: str, + storage_reader: StorageReader, + planner: Optional[LoadPlanner] = None, +) -> STATE_DICT_TYPE: + """ + Load a state_dict in conjunction with FSDP sharded optimizer state. + + This is the current recommended way to checkpoint FSDP. 
+ >>> # xdoctest: +SKIP + >>> import torch.distributed.checkpoint as dist_cp + >>> # Save + >>> model: torch.nn.Model + >>> optim_params = model.parameters() + >>> optim = torch.optim.SGD(optim_params, lr=0.01) + >>> # Save + >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + >>> state_dict = { + >>> "optimizer": FSDP.optim_state_dict(model, optim), + >>> "model": model.state_dict() + >>> } + >>> dist_cp.save_state_dict( + >>> state_dict=optim_state, + >>> storage_writer=dist_cp.FileSystemWriter("checkpoint"), + >>> planner=dist_cp.DefaultSavePlanner(), + >>> ) + >>> + >>> # Load + >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): + >>> model_state_dict = model_tp.state_dict() + >>> checkpoint = { + >>> "model": model_state_dict + >>> } + >>> dist_cp.load_state_dict( + >>> state_dict=checkpoint, + >>> storage_reader=dist_cp.FileSystemReader(checkpoint_file), + >>> planner=dist_cp.DefaultLoadPlanner(), + >>> ) + >>> model.load_state_dict(checkpoint["model_state"]) + >>> + >>> optim_state = dist_cp.load_sharded_optimizer_state_dict( + >>> model_state_dict, + >>> optimizer_key="optimizer", + >>> storage_reader=dist_cp.FileSystemReader("checkpoint"), + >>> ) + >>> + >>> flattened_osd = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state["optimizer"] + >>> ) + >>> + >>> optim.load_state_dict(flattened_osd) + """ + metadata = storage_reader.read_metadata() + + layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict) + dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type + device_module = _get_device_module(dp_pg_device_type) + + if dp_pg is None: + placements = [] + for i in range(dist.get_world_size()): + device_info = _normalize_device_info( + dp_pg_device_type, i % device_module.device_count() + ) + placements.append(f"rank:{i}/{device_info}") + sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type] + else: + sharding_spec = _create_colwise_spec(dp_pg) + + # Create a state_dict for optimizer state + state_dict: STATE_DICT_TYPE = {} + + fqn_to_offset: dict[str, Sequence[int]] = {} + for key, value in metadata.state_dict_metadata.items(): + key_path = metadata.planner_data[key] + if key_path[0] != optimizer_key: + continue + + if isinstance(value, BytesStorageMetadata): + state_dict[key] = "" + continue + + # value: TensorStorageMetadata + if value.size.numel() == 1: + state_dict[key] = _alloc_tensor( + value.properties, value.size, dp_pg_device_type + ) + elif dp_pg is None: + state_dict[key] = _create_chunk_sharded_tensor( + _alloc_tensor(value.properties, value.size, dp_pg_device_type), + rank=dist.get_rank(), + world_size=dist.get_world_size(), + num_devices_per_node=device_module.device_count(), + pg=_get_default_group(), + ) + else: + spec_key = key_path[2] + alloc_size = layout_specs.get(spec_key, (None, value.size))[1] + + properties = ShardTensorProperties( + dtype=value.properties.dtype, + layout=value.properties.layout, + requires_grad=value.properties.requires_grad, + memory_format=value.properties.memory_format, + pin_memory=value.properties.pin_memory, + ) + + st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties) + local_shards = [] + current_rank = dist.get_rank(dp_pg) + for shard_md in st_md.shards_metadata: + if cast(_remote_device, shard_md.placement).rank() != current_rank: + continue + local_shards.append( + Shard( + tensor=_alloc_tensor( + value.properties, shard_md.shard_sizes, dp_pg_device_type + ), + metadata=shard_md, + ) + ) + 
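+            # Only the shards whose placement maps to the current rank were allocated
+            # above; the ShardedTensor below is assembled from those local shards plus
+            # the global sharding metadata, so the optimizer shard layout matches the model's.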
+ st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, st_md, process_group=dp_pg + ) + + if spec_key in layout_specs and layout_specs[spec_key][0] is not None: + fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0]) + + state_dict[key] = st + + # Whether we unflatten before or after doesn't matter + load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + # FIXME the type of planner is wrong in load_state_dict + planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner, + ) + + state_dict = unflatten_state_dict(state_dict, metadata.planner_data) + + return state_dict diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/planner.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/planner.py new file mode 100644 index 0000000000000000000000000000000000000000..f362438f1f85d735f790d1c594a80dd25cd26444 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/planner.py @@ -0,0 +1,450 @@ +import abc +import io +import operator +from dataclasses import dataclass +from enum import auto, Enum +from functools import reduce +from typing import Any, Optional, Union + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + StorageMeta, + TensorProperties, +) + + +__all__ = [ + "WriteItemType", + "LoadItemType", + "BytesIOWriteData", + "TensorWriteData", + "WriteItem", + "ReadItem", + "SavePlan", + "LoadPlan", + "SavePlanner", + "LoadPlanner", +] + + +class WriteItemType(Enum): + TENSOR = auto() + SHARD = auto() + BYTE_IO = auto() + + +class LoadItemType(Enum): + TENSOR = auto() + BYTE_IO = auto() + + +@dataclass(frozen=True) +class BytesIOWriteData: + nbytes: int + + +@dataclass(frozen=True) +class TensorWriteData: + chunk: ChunkStorageMetadata + properties: TensorProperties + size: torch.Size + + +@dataclass(frozen=True) +class WriteItem: + """Dataclass which holds information about what needs to be written to storage.""" + + index: MetadataIndex + type: WriteItemType + + # Size of bytesIO data to be written. + bytes_io_data: Optional[BytesIOWriteData] = None + + # Value present if it's a tensor write + tensor_data: Optional[TensorWriteData] = None + + def tensor_storage_size(self) -> Optional[int]: + """ + Calculates the storage size of the underlying tensor, or None if this is not a tensor write. + + Returns: + Optional[int] storage size, in bytes of underlying tensor if any. + """ + if self.tensor_data is None: + return None + + numels = reduce(operator.mul, self.tensor_data.size, 1) + dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype) + return numels * dtype_size + + +@dataclass(frozen=True) +class ReadItem: + # Read Item + type: LoadItemType + + # Index into the state_dict + dest_index: MetadataIndex + # Offsets into destination tensor + dest_offsets: torch.Size + + # Index into the checkpoint + storage_index: MetadataIndex + # Offset into the checkpoint data + storage_offsets: torch.Size + + # Size of the hypercube to copy + lengths: torch.Size + + +@dataclass(frozen=True) +class SavePlan: + items: list[WriteItem] + storage_data: Any = None + planner_data: Any = None + # This is used to indicate that the ranks should + # use the cached plans to write data instead. 
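+    # When False, the rank's cached plan is reused in place of this one.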
+ usable: bool = True + + +@dataclass +class LoadPlan: + items: list[ReadItem] + storage_data: Any = None + planner_data: Any = None + + +class SavePlanner(abc.ABC): + """ + Abstract class defining the protocol used by save_state_dict to plan the save process. + + SavePlanners are stateful objects that can be used to customize the whole save process. + + SavePlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during save_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of a checkpoint save. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `SavePlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the SavePlan from all ranks and make any global decision. + + 4) finish_plan - called on all ranks. + This gives each rank a chance to adjust to global planning decisions. + + 5) resolve_data - called multiple times on each rank + Lookups a value on the `state_dict` for the storage layer to write. + + Users are recommended to extend DefaultSavePlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are 3 usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the save process as it + doesn't requite understanding the intrincacies of how SavePlan works: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultSavePlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> storage_meta: Optional[StorageMeta], + >>> is_coordinator: bool, + >>> ) -> None: + >>> # prefix all keys with `foo_`` + >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator) + + Modifying local plan and lookup in tandem. 
This is useful when fine control of how data is persisted + + >>> # xdoctest: +SKIP("undefined vars") + >>> class FP16Planner(DefaultSavePlanner): + >>> def create_local_plan(self): + >>> plan = super().create_local_plan() + >>> for p in plan: + >>> if p.tensor_data is not None: + >>> p.tensor_data.properties.dtype = torch.float16 + >>> return plan + >>> + >>> def resolve_data(self, write_item): + >>> item = super().resolve_data(write_item) + >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16) + + Using the global planning step to make central decisions that can't be made individually by each rank + + >>> # xdoctest: +SKIP("undefined vars") + >>> from itertools import zip_longest + >>> from dataclasses import replace + >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): + >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 + >>> # This sample doesn't handle ShardedTensors + >>> def create_global_plan(self, all_plans): + >>> iters = [iter(all_plans[0].items)] * len(all_plans) + >>> items_per_rank = [ + >>> [item for item in items if item is not None] + >>> for items in zip(*zip_longest(*iters), strict=True) + >>> ] + >>> all_plans = [ + >>> replace(plan, items=items) + >>> for plan, items in zip(all_plans, items_per_rank, strict=True) + >>> ] + >>> return super().create_global_plan(all_plans) + + Finally, some planners need to save additional metadata in the checkpoint, this is + accomplished by having each rank contribute their data items in the local plan and + the global planner aggregate them: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class SaveExtraDataPlanner(DefaultSavePlanner): + >>> def create_local_plan(self) -> SavePlan: + >>> plan = super().create_local_plan() + >>> return replace(plan, planner_data="per-rank-data") + >>> + >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + >>> global_plan, metadata = super().create_global_plan(all_plans) + >>> merged_data = [p.planner_data for p in global_plan] + >>> metadata = replace(metadata, planner_data=merged_data) + >>> return global_plan, metadata + """ + + # Save plan for the current rank as computed by `create_local_plan` API + # Cached on the local rank. + _cached_save_plan: dict[str, SavePlan] = {} + # Final save plan for the current rank. + # This is created by merging the plan created by `create_local_plan` API + # and the result of `create_global_plan` for the given rank. + # This is the final plan computed by the `finish_plan` API that gets + # sent to the `write_data`. + # Cached on the local rank. + _cached_final_save_plan: dict[str, SavePlan] = {} + # Collection of all the local plans from all the ranks. + # This is the input to the `create_global_plan` API. + # Cached on the coordinator rank. + _cached_all_plans: dict[str, list[SavePlan]] = {} + # Global checkpoint plan as computed by `create_global_plan` API. + # Cached on the coordinator rank. + _cached_global_plan: dict[str, list[SavePlan]] = {} + # Metadata for the global checkpoint plan as computed by `create_global_plan` API. + # Cached on the coordinator rank. + _cached_metadata: dict[str, Metadata] = {} + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this planner to save ``state_dict``. + + Implementations should save those values as they won't be provided lated in the save process. 
+ + This is called on all ranks. + """ + + @abc.abstractmethod + def create_local_plan(self) -> SavePlan: + """ + Compute the save plan for the current rank. + + This will be aggregated and passed to create_global_plan. + Planner specific data can be passed through SavePlan::planner_data. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + """ + Compute the global checkpoint plan and return the local plan of each rank. + + This is called on the coordinator rank only. + """ + + @abc.abstractmethod + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + """ + Merge the plan created by `create_local_plan` and the result of `create_global_plan`. + + This is called on all ranks. + """ + + @abc.abstractmethod + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + """ + Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety. + + Lookup the object associated with ``write_item`` in ``state_dict`` and apply any + transformation (such as serialization) prior to the storage layer consuming it. + + Called on each rank multiple times, at least once per WriteItem in the final SavePlan. + + This method should be idempotent and thread-save. StorageWriter implementations + are free to call it as frequently as they need. + + Any transformation that allocates memory should be lazily done when his method + is called in order to reduce peak memory required by checkpointing. + + When returning tensors, they can be on any device or format, they can be views too. + It's the storage layer responsibility to figure out how to save them. + """ + + +class LoadPlanner: + """ + Abstract class defining the protocol used by load_state_dict to plan the load process. + + LoadPlanner are stateful objects that can be used to customize the whole load process. + + LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during load_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of loading a checkpoint. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `LoadPlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the LoadPlan from all ranks and make any global decision. + + 4) load_bytes - called multiple times on each rank + This is called once per non-tensor value in state_dict. + + 5) resolve_tensor and commit_tensor - called multiple times on each rank + They are called in pair for each Tensor value in state_dict. + + Users are recommended to extend DefaultLoadPlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are two usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the load process as it + doesn't requite understanding the intrincacies of how LoadPlan works. 
We need + to keep a reference to the original state_dict as load happens in place so + we need to be able to perform it in place + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultLoadPlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> metadata: Metadata, + >>> is_coordinator: bool, + >>> ) -> None: + >>> self.original_state_dict = state_dict + >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} + >>> + >>> if self.flatten_sharded_tensors: + >>> state_dict = _flatten_sharded_tensors(state_dict) + >>> + >>> if self.flatten_state_dict: + >>> state_dict, self.mappings = flatten_state_dict(state_dict) + >>> + >>> self.state_dict = state_dict + >>> self.metadata = metadata + >>> self.is_coordinator = is_coordinator + >>> + >>> def load_bytes(self, read_item, value): + >>> # Remove the "foo_" prefix + >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False) + + + Modifying resolve_tensor and commit_tensor to handle load time transformation. + + >>> # xdoctest: +SKIP("undefined vars") + >>> class MetaModelMaterialize(DefaultSavePlanner): + >>> def resolve_tensor(self, read_item): + >>> tensor = super().resolve_tensor(read_item) + >>> return torch.empty_like(tensor, device="cpu") + >>> + >>> def commit_tensor(self, read_item, tensor): + >>> self.state_dict[read_item.dest_index.fqn] = tensor + """ + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this instance to load data into ``state_dict``. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_local_plan(self) -> LoadPlan: + """ + Create a LoadPlan based on state_dict and metadata provided by set_up_planner. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + """ + Compute the global load plan and return plans for each rank. + + . N.B. This is called on the coordinator rank only + """ + + @abc.abstractmethod + def finish_plan(self, central_plan: LoadPlan) -> LoadPlan: + """Accept the plan from coordinator and return final LoadPlan.""" + + @abc.abstractmethod + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + """ + Load the item described by ``read_item``and ``value``. + + This method is expected to modify in-place the underlying state_dict. + + The contents of ``value`` are defined by the SavePlanner used to produce + the checkpoint being loaded. + """ + + def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO: + """ + Return the BytesIO to be used by the StorageReader to load `read_item`. + + The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents. + """ + raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented") + + @abc.abstractmethod + def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor: + """ + Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`. + + The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. + If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data + back to the one in state_dict. 
+ """ + + @abc.abstractmethod + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + """ + Call once the StorageReader finished loading data into ``tensor``. + + The provided tensor is the same one returned by the call to ``resolve_tensor``. + This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to + copying it back to the one in the state_dict. + + The contents of tensor will follow its device synchronization model. + """ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..eef7d86b7826edac34d0ff25bc203dd9eab39ec6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import io +from typing import Any, Callable, cast + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from .planner import ( + LoadItemType, + ReadItem, + SavePlan, + TensorWriteData, + WriteItem, + WriteItemType, +) +from .resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) + + +__all__: list[str] = ["create_read_items_for_chunk_list"] + + +def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool: + """ + Compare the two Save plans and return True if they are equal. + + Args: + plan (SavePlan): First SavePlan to compare. + other_plan (SavePlan): Second SavePlan to compare. + + Returns: + True if the two plans are equal, False otherwise. + """ + if plan.usable != other_plan.usable: + return False + + # Both the plans should have the same number of items + if len(plan.items) != len(other_plan.items): + return False + + # Both the plans should have the same write items. + for plan_item, other_plan_item in zip(plan.items, other_plan.items): + # Write item type should be same + if plan_item.type != other_plan_item.type: + return False + + plan_metadata_index = plan_item.index + other_plan_metadata_index = other_plan_item.index + + # Write item metadata_index should be same + if ( + plan_metadata_index.fqn != other_plan_metadata_index.fqn + or plan_metadata_index.offset != other_plan_metadata_index.offset + or plan_metadata_index.index != other_plan_metadata_index.index + ): + return False + + # Write item tensor_data should be present in both the write items plans, if it exists in either of them. + tensor_data = plan_item.tensor_data + other_tensor_data = other_plan_item.tensor_data + if (tensor_data and not other_tensor_data) or ( + not tensor_data and other_tensor_data + ): + return False + + if tensor_data and other_tensor_data: + # Write item tensor_data size should be same + if tensor_data.size != other_tensor_data.size: + return False + + # Write item tensor_data chunk should be present in both the write items, if it exists in either of them. 
+ chunk = tensor_data.chunk + other_chunk = other_tensor_data.chunk + if (chunk and not other_chunk) or (not chunk and other_chunk): + return False + + # Write item tensor_data chunk offsets and sizes should be same + if chunk and other_chunk: + if ( + chunk.offsets != other_chunk.offsets + or chunk.sizes != other_chunk.sizes + ): + return False + + return True + + +def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool: + """ + Check if any delta plan is usable, indicating the plan has changed. + + Args: + delta_plans (List[SavePlan]): A list of delta plans to check. + Returns: + True if any delta plan is usable, False otherwise. + """ + return any(delta_plan and delta_plan.usable for delta_plan in delta_plans) + + +def _merge_delta_local_plans( + cached_plans: list[SavePlan], + delta_plans: list[SavePlan], +) -> list[SavePlan]: + """ + Merge a list of delta plans into a single plan. + + Args: + cached_plans (List[SavePlan]): A list of cached plans. + delta_plans (List[SavePlan]): A list of delta plans to merge. It can contain empty plans + + Returns: + A single merged plan. If a delta plan is not usable, use the cached plan. Otherwise, use the delta plan. + """ + merged_plans = [] + + for cached_plan, delta_plan in zip(cached_plans, delta_plans): + if delta_plan and not delta_plan.usable: + merged_plans.append(cached_plan) + else: + merged_plans.append(delta_plan) + + return merged_plans + + +def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size() + ) + + +def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size(shard_md.shard_offsets), + sizes=torch.Size(shard_md.shard_sizes), + ) + + +def _sharded_tensor_metadata( + sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> TensorWriteData: + shard_properties = sharded_tensor.metadata().tensor_properties + + properties = TensorProperties( + dtype=shard_properties.dtype, + layout=shard_properties.layout, + requires_grad=shard_properties.requires_grad, + memory_format=shard_properties.memory_format, + pin_memory=shard_properties.pin_memory, + ) + + return TensorWriteData( + chunk=_chunk_for_shard(shard_md), + properties=properties, + size=sharded_tensor.metadata().size, + ) + + +def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem: + sizes, offsets = compute_local_shape_and_global_offset( + tensor.shape, tensor.device_mesh, tensor.placements + ) + sizes, offsets = torch.Size(sizes), torch.Size(offsets) + + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=offsets, + sizes=sizes, + ), + properties=TensorProperties.create_from_tensor(tensor.to_local()), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_shard( + fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> WriteItem: + offsets = torch.Size(shard_md.shard_offsets) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), + ) + + +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: + offsets = torch.Size([0] * len(tensor.size())) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.TENSOR, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + 
properties=TensorProperties.create_from_tensor(tensor), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_bytesio(fqn: str, bytes: Any): + return WriteItem( + index=MetadataIndex(fqn), + type=WriteItemType.BYTE_IO, + ) + + +def _create_read_item_for_byteio( + dest_index, dest_offset, storage_index, storage_offset, length +): + return ReadItem( + type=LoadItemType.BYTE_IO, + dest_index=dest_index, + dest_offsets=torch.Size((dest_offset,)), + storage_index=storage_index, + storage_offsets=torch.Size((storage_offset,)), + lengths=torch.Size((length,)), + ) + + +def _create_read_item_for_tensor( + dest_index, dest_offsets, storage_index, storage_offsets, lengths +): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=torch.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=torch.Size(storage_offsets), + lengths=torch.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: list[ChunkStorageMetadata], +) -> list[ReadItem]: + """ + Create a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. + + Args: + fqn (str) : The state_dict FQN to pass to ``ReadItem``. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. + """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + _dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor( + saved_shard=storage_md, current_shard=shard + ): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: + requests = [] + for fqn, obj in state_dict.items(): + if isinstance(obj, DTensor): + requests.append(_create_write_items_for_dtensor(fqn, obj)) + elif isinstance(obj, ShardedTensor): + requests.extend( + _create_write_item_for_shard(fqn, obj, shard_md) + for shard_md in obj.metadata().shards_metadata + ) + elif isinstance(obj, torch.Tensor): + requests.append(_create_write_item_for_tensor(fqn, obj)) + else: + requests.append(_create_write_item_for_bytesio(fqn, obj)) + return SavePlan(requests) + + +def _create_write_items(fqn: str, object: Any) -> list[WriteItem]: + if hasattr(object, "__create_write_items__"): + # DTensor implements _Checkpointable + return object.__create_write_items__(fqn, object) + elif isinstance(object, ShardedTensor): + return [ + _create_write_item_for_shard(fqn, object, shard.metadata) + for shard in object.local_shards() + ] + elif isinstance(object, torch.Tensor): + return 
[_create_write_item_for_tensor(fqn, object)] + else: + return [_create_write_item_for_bytesio(fqn, object)] + + +def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: + sizes, offsets = compute_local_shape_and_global_offset( + tensor.shape, tensor.device_mesh, tensor.placements + ) + sizes, offsets = torch.Size(sizes), torch.Size(offsets) + return ChunkStorageMetadata( + offsets=offsets, + sizes=sizes, + ) + + +def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]: + if hasattr(tensor, "__create_chunk_list__"): + # DTensor implements _Checkpointable + local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(tensor, ShardedTensor): + local_chunks = [ + _chunk_for_shard(shard.metadata) for shard in tensor.local_shards() + ] + elif isinstance(tensor, torch.Tensor): + local_chunks = [_create_chunk_from_tensor(tensor)] + else: + raise ValueError( + "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] " + f",but got {type(tensor)}" + ) + + return local_chunks + + +def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]: + if not isinstance(md, BytesStorageMetadata): + try: + local_chunks = _create_chunk_list(obj) + except ValueError as ex: + raise ValueError( + f"Invalid checkpoint metadata for {fqn}, " + + f"expected BytesStorageMetadata but found {type(md)}", + ) from ex + + return create_read_items_for_chunk_list(fqn, md, local_chunks) + else: + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] + + +def _init_state_dict(state_dict: dict[str, Any]) -> Any: + """ + Initializes meta tensor if the meta tensor is DTensor or torch.Tensor. + """ + + def dtensor_func(value: DTensor): + device = getattr(value, "device", None) + if device == torch.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + torch.device, _get_device_module(device_type).current_device() + ) + new_local_tensor = torch.empty_like(value.to_local(), device=device) + # We need to pass shape and stride explicitly, since DTensor might be + # sharded unevenly. + dtensor = DTensor.from_local( + new_local_tensor, + device_mesh=value.device_mesh, + placements=value.placements, + shape=value.size(), + stride=value.stride(), + ) + return dtensor + else: + return value + + def sharded_tensor_func(value: Any): + device = getattr(value, "device", None) + if device == torch.device("meta"): + raise RuntimeError( + f"Found unsupported type {type(value)} for meta device loading." + ) + else: + return value + + def tensor_func(value: torch.Tensor): + device = getattr(value, "device", None) + if device == torch.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + torch.device, _get_device_module(device_type).current_device() + ) + tensor = torch.empty_like(value, device=device) + return tensor + else: + return value + + _iterate_state_dict( + state_dict, + dtensor_func, + sharded_tensor_func, + tensor_func, + ) + + +def _iterate_state_dict( + iter_object: Any, + dtensor_func: Callable, + sharded_tensor_func: Callable, + tensor_func: Callable, +): + """ + Iterate through the state dict, applying the given functions to each tensor type + and update the state dict in place. + + Args: + iter_object (Any): the target state_dict. 
+ sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + + # TODO: let state_dict_util._iterate_state_dict() to support in place option + so we don't need to have two versions of _iterate_state_dict. + """ + + if isinstance(iter_object, DTensor): + return dtensor_func(iter_object) + elif isinstance(iter_object, ShardedTensor): + return sharded_tensor_func(iter_object) + elif isinstance(iter_object, torch.Tensor): + return tensor_func(iter_object) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + return iter_object + elif isinstance(iter_object, dict): + for key, value in iter_object.items(): + iter_object[key] = _iterate_state_dict( + value, dtensor_func, sharded_tensor_func, tensor_func + ) + return iter_object + elif isinstance(iter_object, (list, tuple)): + ret = [ + _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func) + for v in iter_object + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) # type: ignore[assignment] + return ret diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/resharding.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/resharding.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d858e2e3ef6de01511cedec26d2b5dc1f52dbf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/resharding.py @@ -0,0 +1,69 @@ +from torch.distributed.checkpoint.metadata import ChunkStorageMetadata + + +__all__: list[str] = [] + + +def _check_shard_metadata_pair_overlap( + shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata +) -> bool: + """Check if two shards overlap.""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.offsets) + for i in range(ndims): + if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]: + return False + if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]: + return False + + return True + + +def _shards_get_overlap_region_wrt_saved_tensor( + saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata +) -> list[tuple[int, int, int, int]]: + """ + Return the overlapping region between saved_shard and current_shard. + + There returned list has the same number of elements as the tensor's dimension. + For each element, we produce a tuple with the following contents: + (dimension, `saved_shard` offset, `current_shard` offset, length) + + Offsets are relative to each shard. 
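For a concrete, hedged illustration with made-up one-dimensional chunks: a saved shard at offsets [0] with sizes [8] and a current shard at offsets [4] with sizes [8] overlap on four elements, so the function reports dimension 0, an offset of 4 into the saved shard, an offset of 0 into the current shard, and a length of 4.

    import torch
    from torch.distributed.checkpoint.metadata import ChunkStorageMetadata

    saved = ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8]))
    current = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([8]))
    # _shards_get_overlap_region_wrt_saved_tensor(saved_shard=saved, current_shard=current)
    # -> [(0, 4, 0, 4)]  i.e. (dim, saved-shard offset, current-shard offset, length)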
+ """ + narrows = [] + for dim, ( + saved_shard_offset, + current_shard_offset, + saved_shard_size, + current_shard_size, + ) in enumerate( + zip( + saved_shard.offsets, + current_shard.offsets, + saved_shard.sizes, + current_shard.sizes, + ) + ): + min_range_end = min( + saved_shard_offset + saved_shard_size, + current_shard_offset + current_shard_size, + ) + + length = min_range_end - max(current_shard_offset, saved_shard_offset) + + if saved_shard_offset > current_shard_offset: + offset_for_saved_tensor = 0 + offset_for_current_tensor = saved_shard_offset - current_shard_offset + else: + offset_for_saved_tensor = current_shard_offset - saved_shard_offset + offset_for_current_tensor = 0 + + narrows.append( + (dim, offset_for_saved_tensor, offset_for_current_tensor, length) + ) + + return narrows diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/staging.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/staging.py new file mode 100644 index 0000000000000000000000000000000000000000..0f93e9688e5d4f03f7660ee48cb476d44df21abf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/staging.py @@ -0,0 +1,115 @@ +from typing import Optional +from typing_extensions import Protocol, runtime_checkable + +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +__all__ = ["AsyncStager", "BlockingAsyncStager"] + + +@runtime_checkable +class AsyncStager(Protocol): + """ + This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users + to customize how data is staged previous to executing the usual dcp.save path in parallel. + The expected order of operations (concretely defined in `torch.distributed.state_dict_saver.async_save`) + is the following: + + 1. AsyncStager.stage_data(state_dict): + This call gives the AsyncStager the opportunity to 'stage' + the state_dict. The expectation and purpose of staging in this context is to create a "training-safe" + representation of the state dict, meaning that any updates to module data after staging is complete + should not be reflected in the state dict returned from this method. For example, in the default + case a copy of the entire state dict is created on CPU RAM and returned here, allowing users + to continue training without risking changes to data which is being serialized. + + 2. dcp.save is called on the state_dict returned from stage in parallel. This call is responsible + for serializing the state_dict and writing it to storage. + + 3. If AsyncStager.should_synchronize_after_execute is True, this method will be called immediately after + the serialization thread starts and before returning from dcp.async_save. If this is set to False, + the assumption is the user has defined a custom synchronization point for the the purpose of further + optimizing save latency in the training loop (for example, by overlapping staging with the + forward/backward pass), and it is the respondsibility of the user to call `AsyncStager.synchronize_staging` + at the appropriate time. + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = True + + @property + def should_synchronize_after_execute(self) -> bool: + """ + Whether to synchronize after executing the stage. 
+ """ + + return self._synchronize_after_execute + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is + inoculated from any updates incurred after the stage call is complete. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement stage method" + ) + + def synchronize_staging(self) -> None: + """ + In the case `stage` is async in some way, this method should be called to ensure staging + is complete and it is safe to begin modifying the original `state_dict` + """ + + +class BlockingAsyncStager(AsyncStager): + """ + An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. + This implementation also provides an option to optimize stage latency using pinned memory. + + N.B. synchronize_staging is a no-op in this case. + + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = False + + def __init__( + self, + cache_staged_state_dict: bool = False, + type_check: bool = False, + ): + """ + Initializes the BlockingAsyncStager. + + Args: + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. + type_check: Whether to perform a type check during cpu_offload. Defaults to False. + + """ + self.cache_staged_state_dict = cache_staged_state_dict + self.type_check = type_check + self.state_dict_cache: Optional[STATE_DICT_TYPE] = None + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Returns a copy of `state_dict` on the CPU. + """ + + if not self.cache_staged_state_dict: + staged_state_dict = _create_cpu_state_dict(state_dict) + _copy_state_dict(state_dict, staged_state_dict, type_check=self.type_check) + return staged_state_dict + + if self.state_dict_cache is None: + self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True) + return _copy_state_dict(state_dict, self.state_dict_cache) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. 
+ """ diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..f12f2993e99d102e695ef7b4ee5cbab4ea0882a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict.py @@ -0,0 +1,1498 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import gc +import warnings +from collections.abc import Generator, Iterable +from dataclasses import asdict, dataclass, field +from itertools import chain +from typing import Any, Callable, cast, no_type_check, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._state_dict_utils import ( + _broadcast_state_dict, + _distribute_state_dict, + _flatten_state_dict, + _gather_state_dict, + _offload_state_dict_to_cpu, + _unflatten_state_dict, +) +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp._common_utils import ( + _get_module_fsdp_state_if_fully_sharded_module, + FSDP_WRAPPED_MODULE, +) +from torch.distributed.tensor import DTensor +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils._pytree import tree_map_only + + +__all__ = [ + "FQNS_T", + "PrimitiveType", + "ValueType", + "DictValueType", + "ListDictValueType", + "OptimizerStateType", + "StateDictOptions", + "get_model_state_dict", + "get_optimizer_state_dict", + "get_state_dict", + "set_model_state_dict", + "set_optimizer_state_dict", + "set_state_dict", +] + + +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" + +FQNS_T = set[str] +PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] +ValueType = Union[ + PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"] +] +DictValueType = dict[str, ValueType] +ListDictValueType = list[DictValueType] +OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]] + + +_patched_state_dict: set[Callable] = set() + + +@contextlib.contextmanager +def _gc_context(): + is_enabled = gc.isenabled() + gc.disable() + try: + yield + finally: + if is_enabled: + gc.enable() + + +@dataclass +class StateDictOptions: + """ + This dataclass specifies how get_state_dict/set_state_dict will work. + + - ``full_state_dict``: if this is set to True, all the tensors in the + returned state_dict will be gathered. No ShardedTensor and DTensor + will be in the returned state_dict. + + - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if + ``full_state_dict`` is also true, then only the rank0 will get the + state_dict and all other ranks will get empty state_dict. + + - ``ignore_frozen_params``: if the value is True, the returned state_dict + won't contain any frozen parameters -- the ``requires_grad`` is False. + The default value is False. + + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option + indicates whether to keep the submodule prefixes from the state_dict keys. 
+ or example, if the submodule is ``module.pretrain`` and the full FQN of + the parameter is ``pretrain.layer1.weight`` of the param. When this option + is True, the parameter's key in the returned state_dict will be + ``pretrain.layer1.weight``. If the options is False, the key will be + ``layer1.weight``. + Note that if ``keep_submodule_prefixes`` is False, there may be conflicted + FQNs, hence there should be only one submodule in ``submodules``. + + - ``strict``: the ``strict`` option when ``set_state_dict`` calls + model.load_state_dict(). + + - ``broadcast_from_rank0``: when the option is True, rank0 should receive a + full state_dict and will broadcast the tensors in the state_dict/ + optim_state_dict one by one to other ranks. Other ranks will receive + the tensors and shard according to the local shards in the model and + optimizer. ``full_state_dict`` must be set to True when using this option. + This option currently only supports DTensor, not the legacy ShardedTensor. + """ + + full_state_dict: bool = False + cpu_offload: bool = False + ignore_frozen_params: bool = False + keep_submodule_prefixes: bool = True + strict: bool = True + broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False + dsd_fqn_modifiers: str = "_fqn_modifiers" + + +@dataclass +class _StateDictInfo(StateDictOptions): + fqn_param_mapping: dict[ + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], + ] = field(default_factory=dict) + shared_params_mapping: dict[ + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], + ] = field(default_factory=dict) + submodule_prefixes: set[str] = field(default_factory=set) + handle_model: bool = True + handle_optim: bool = True + fsdp_context: Callable = contextlib.nullcontext + fsdp_modules: list[nn.Module] = field(default_factory=list) + + +def _get_fqns( + model: nn.Module, + name: str, + dsd_fqn_modifiers: str = "_fqn_modifiers", + skip_ddp_prefix: bool = True, + skip_compiler_prefix: bool = True, +) -> FQNS_T: + """ + This API is used to convert the name of a parameter to the FQNs. For FSDP + without `use_orig_params`, the name of FlatParameter can be mapped to + multiple original parameters. As a result, the return type of this function + is `set[str]`. + + Args: + module (nn.Module): the root model. + name (str): the name + skip_ddp_prefix (bool): whether to skip DDP's `module` prefix + + Returns: + The canonical FQNs based on the model traversal. + """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") + if "." not in name: + return {name} + + obj_names = name.split(".") + fqn_obj_names = [] + curr_obj = model + for i, curr_obj_name in enumerate(obj_names): + if isinstance(curr_obj, DDP): + assert curr_obj_name == "module" + curr_obj = curr_obj.module + if not skip_ddp_prefix: + fqn_obj_names.append(curr_obj_name) + elif isinstance(curr_obj, FSDP): + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: + prefix = ".".join(fqn_obj_names) + flat_param = getattr(curr_obj, _FLAT_PARAM) + if prefix: + prefix = f"{prefix}." 
+ return {f"{prefix}{fqn}" for fqn in flat_param._fqns} + curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) + if curr_obj_name != FSDP_WRAPPED_MODULE: + fqn_obj_names.append(curr_obj_name) + curr_obj = getattr(curr_obj, curr_obj_name) + elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): + assert curr_obj_name == "_orig_mod" + curr_obj = curr_obj._orig_mod + if not skip_compiler_prefix: + fqn_obj_names.append(curr_obj_name) + else: + # In some modules, _fqn_modifiers would not shown in the state_dict keys, + # skip them in the fqn to ensure load stat dict successfully for them. + if hasattr(curr_obj, dsd_fqn_modifiers): + if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get( + curr_obj_name + ): + if hasattr(curr_obj, removed_fqn): + curr_obj = getattr(curr_obj, removed_fqn) + fqn_obj_names.append(curr_obj_name) + if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: + if i != len(obj_names) - 1: + raise RuntimeError("Expect `_extra_state` to be the last obj name") + else: + curr_obj = getattr(curr_obj, curr_obj_name) + + return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} + + +class _EXTRA_STATE: + pass + + +def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"): + visited_modules: set[nn.Module] = set() + + def recurse(module: nn.Module, curr_fqn: str) -> Generator: + visited_modules.add(module) + + curr_fqn = f"{curr_fqn}." if curr_fqn else "" + for name, submodule in module.named_children(): + if submodule in visited_modules: + continue + # if user have state_dict_hooks in their model, they can add the state_dict key changes + # at dsd_fqn_modifiers in input to align with the function of state_dict_hook + if ( + hasattr(module, dsd_fqn_modifiers) + and name in getattr(module, dsd_fqn_modifiers)().values() + ): + # skip _fqn_modifiers here thus remove the last `.` added + new_fqn = curr_fqn[:-1] + else: + new_fqn = f"{curr_fqn}{name}" + yield from recurse(submodule, new_fqn) + + for name, obj in chain( + module.named_buffers(recurse=False), module.named_parameters(recurse=False) + ): + if name in module._non_persistent_buffers_set: + continue + new_fqn = f"{curr_fqn}{name}" + yield new_fqn, obj + + if ( + getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) + != nn.Module.get_extra_state + ): + new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" + yield new_fqn, _EXTRA_STATE() + + yield from recurse(model, "") + + +def _verify_options( + model: nn.Module, + optims: tuple[torch.optim.Optimizer, ...], + optim_only: bool, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> _StateDictInfo: + """ + Verify the model and options passed by the user and generates _StateDictInfo. + """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + ) + if optim_only and not optims: + raise RuntimeError( + "Optimizers are not passed in but optim_only is set to True." 
+ ) + + options = options or StateDictOptions() + + fqn_param_mapping: dict[ + Union[str, torch.Tensor], Union[set[str], torch.Tensor] + ] = {} + shared_params_mapping: dict[ + Union[str, torch.Tensor], Union[set[str], torch.Tensor] + ] = {} + for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + + fqns = _get_fqns(model, name) + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() + for fqn in fqns: + if not isinstance(param, _EXTRA_STATE): + fqn_param_mapping[fqn] = param + + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + + submodule_prefixes: set[str] = set() + if submodules: + submodules = set(submodules) + for name, module in model.named_modules(): + if module not in submodules: + continue + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "Submodule FQN should only have 1 instance" + submodule_prefixes.update(f"{fqn}." for fqn in fqns) + + if options.broadcast_from_rank0 and not options.full_state_dict: + raise ValueError( + "full_state_dict must be True when broadcast_from_rank0 is True." + ) + fsdp_modules = FSDP.fsdp_modules(model) + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig + fsdp_context: Callable + if fsdp_modules: + # FSDP API only work if at least one FSDP instance exists. + if options.full_state_dict: + state_dict_config = FullStateDictConfig( + offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload + ) + optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + rank0_only=(options.cpu_offload or options.broadcast_from_rank0), + ) + state_dict_type = StateDictType.FULL_STATE_DICT + else: + state_dict_config = ShardedStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT + + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="FSDP.state_dict_type", category=FutureWarning + ) + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + + fsdp_context = functools.partial( + fsdp_state_dict_type_without_warning, + module=model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ) + else: + fsdp_context = contextlib.nullcontext + + return _StateDictInfo( + **asdict(options), + fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, + submodule_prefixes=submodule_prefixes, + fsdp_context=fsdp_context, + fsdp_modules=cast(list[nn.Module], fsdp_modules), + handle_model=not optim_only, + handle_optim=(len(optims) > 0), + ) + + +def _verify_state_dict( + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + for module in info.fsdp_modules: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + 
assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module." + + # Verify if the model_state_dict and optim_state_dict are valid. This API + # should give the users an explicit error message to debug or report. + if ( + info.handle_model + and not model_state_dict + and not info.submodule_prefixes + and not info.ignore_frozen_params + and not (info.cpu_offload and info.full_state_dict) + and info.strict + and not info.broadcast_from_rank0 + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save " + "or load, but model state_dict is empty." + f"rank = {dist.get_rank()=}." + ) + + if info.handle_optim: + if ( + not optim_state_dict + and not (info.cpu_offload and info.full_state_dict) + and (not info.broadcast_from_rank0) + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save, " + f"or load but optim state_dict is empty. {optim_state_dict}" + ) + + for key in model_state_dict.keys(): + if _FLAT_PARAM in key: + raise RuntimeError( + f"{key} contains {_FLAT_PARAM}. This can happen if the model " + "is not the root module." + ) + + +def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable: + call = getattr(obj, api) + if call in _patched_state_dict: + call = functools.partial(getattr(obj.__class__, api), self=obj) + return call + + +def _maybe_full_or_cpu_state_dict( + state_dict: dict[str, Any], info: _StateDictInfo +) -> dict[str, Any]: + if info.full_state_dict: + ranks_only = ( + () + if (not info.cpu_offload or not torch.distributed.is_initialized()) + else (0,) + ) + return _gather_state_dict( + state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + ) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + +@torch.no_grad() +def _get_model_state_dict( + model: nn.Module, info: _StateDictInfo +) -> dict[str, ValueType]: + if not info.handle_model: + return {} + + with info.fsdp_context(): + state_dict = _state_dict_fn(model, "state_dict")() + + for key in list(state_dict.keys()): + fqns = _get_fqns(model, key) + assert len(fqns) == 1, (key, fqns) + fqn = next(iter(fqns)) + if fqn != key: + # As we only support FSDP, DDP, and TP, the only cases are + # wrapper-based DDP and compiler. Verify if the assumption + # is correct. + def verify(key, fqn) -> bool: + if len(fqn) >= len(key): + return False + fqn_split = fqn.split(".") + key_split = key.split(".") + fqn_idx = 0 + for key_idx, key_name in enumerate(key_split): + if key_name == fqn_split[fqn_idx]: + fqn_idx += 1 + if fqn_idx == len(fqn_split): + return key_idx == len(key_split) - 1 + elif key_name in ("module", "_orig_mod"): + continue + else: + return False + return True + + if not verify(key, fqn): + raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") + state_dict[fqn] = state_dict.pop(key) + + if info.submodule_prefixes: + new_state_dict: dict[str, ValueType] = {} + # TODO: make this faster. 
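        # (The nested scan below performs up to
        #  len(state_dict) * len(submodule_prefixes) string-prefix comparisons.)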
+ for fqn in state_dict.keys(): + for prefix in info.submodule_prefixes: + if not fqn.startswith(prefix): + continue + if info.keep_submodule_prefixes: + new_state_dict[fqn] = state_dict[fqn] + else: + new_fqn = fqn[len(prefix) :] + new_state_dict[new_fqn] = state_dict[fqn] + state_dict = new_state_dict + + if info.ignore_frozen_params: + for key, param in model.named_parameters(): + if param.requires_grad: + continue + fqns = _get_fqns(model, key) + for fqn in fqns: + state_dict.pop(fqn) + + return _maybe_full_or_cpu_state_dict(state_dict, info) + + +@torch.no_grad() +def _load_model_state_dict( + model: nn.Module, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> _IncompatibleKeys: + if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): + return _IncompatibleKeys({}, {}) + + local_state_dict = {} + for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers): + fqns = _get_fqns(model, key, info.dsd_fqn_modifiers) + fqns_with_prefix = _get_fqns( + model, + key, + info.dsd_fqn_modifiers, + skip_ddp_prefix=False, + skip_compiler_prefix=False, + ) + + for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix): + if ( + not info.broadcast_from_rank0 or dist.get_rank() == 0 + ) and fqn != fqn_with_prefix: + load_value = state_dict.pop(fqn, None) + if load_value is None: + if info.strict: + raise RuntimeError(f"Missing key: {fqn}.") + else: + state_dict[fqn_with_prefix] = load_value + local_state_dict[fqn_with_prefix] = value + + assign = False + if info.broadcast_from_rank0 or info.full_state_dict: + devices = set() + for key, value in local_state_dict.items(): + if torch.is_tensor(value) and value.dim() > 0: + devices.add(value.device) + # In lora state_dict, there could be multiple devices, with meta device inside. + # Take the other device in the broadcast/distribtue, and set assign to True + if torch.device("meta") in devices: + devices.remove(torch.device("meta")) + assign = True + if len(devices) == 0: + devices.add(dist.distributed_c10d._get_pg_default_device()) + elif len(devices) > 1: + raise ValueError("Multiple devices found") + + if info.broadcast_from_rank0: + _broadcast_state_dict( + state_dict, + local_state_dict, + device=devices.pop(), + strict=info.strict, + cpu_offload=info.cpu_offload, + ) + elif info.full_state_dict: + _distribute_state_dict(state_dict, local_state_dict, device=devices.pop()) + state_dict.update(local_state_dict) + + with info.fsdp_context(): + return cast( + _IncompatibleKeys, + _state_dict_fn(model, "load_state_dict")( + state_dict=state_dict, strict=info.strict, assign=assign + ), + ) + + +def _init_optim_state(optim: torch.optim.Optimizer) -> None: + """ + Initialize optim states by calling the step() with zero grads. + """ + if optim.state: + # The optimizer state is initialized. + return + + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.grad is not None: + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.requires_grad: + param.grad = torch.zeros_like(param) + + # Some optimizers will update parameters regardless of grads due to lr, so + # make lr to zero when calling `step()`. 
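    # (For example, with Adam this zero-grad, zero-lr step only materializes the
    #  "step", "exp_avg" and "exp_avg_sq" entries in optim.state; parameter values
    #  are unchanged because the update is scaled by lr == 0.)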
+ lrs = [] + for param_group in optim.param_groups: + if "lr" in param_group: + lrs.append(param_group["lr"]) + param_group["lr"] = ( + torch.tensor(0.0) + if isinstance(param_group["lr"], torch.Tensor) + else 0.0 + ) + optim.step(closure=None) + # Whether to recover the "lr" should not matter too much as we will + # restore checkpointing later. + for param_group in optim.param_groups: + if "lr" in param_group: + param_group["lr"] = lrs.pop(0) + optim.zero_grad(set_to_none=True) + + +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_group": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_group.layer1.weight.lr" : 0.1, + "param_group.layer2.weight.lr" : 0.1, + "param_group.layer1.weight.betas" : (0.9, 0.95), + "param_group.layer2.weight.betas" : (0.9, 0.95), + } + + Note that if any of the value is a container, like the betas in the example, + this API won't flattent it. + """ + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float)): + raise NotImplementedError( + "Flattening optimizer state_dict only supports " + "tensor, int, float states now. " + f"Type is {type(v)}." + ) + + ret: dict[str, ValueType] = {} + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + for k, v in cast(DictValueType, state).items(): + _raise_if_type_not_supported(v) + ret[f"{_STATE}.{fqn}.{k}"] = v + + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(list[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + See the docstring of _flatten_optim_state_dict() for more detail. + """ + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + # If a parameter is shared, only one of the FQN will be used. + # So we need to verify which if this fqn is actually used in + # the state_dict. 
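                    # (Concretely: a shared FQN counts as "used" only when some
                    #  flattened f"{_PG}.<fqn>.<hyperparameter>" key is present in the
                    #  incoming state_dict; otherwise the FQN is skipped below.)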
+ if fqn in info.shared_params_mapping: + in_params = False + for k in param_group.keys(): + if k == _PARAMS: + continue + flatten_key = f"{_PG}.{fqn}.{k}" + if flatten_key in state_dict: + in_params = True + break + else: + in_params = True + + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) # typing + params.append(fqn) + if not param.requires_grad: + continue + state[fqn] = {} + for state_name in optim.state[param].keys(): + cast(DictValueType, state[fqn])[state_name] = state_dict[ + f"{_STATE}.{fqn}.{state_name}" + ] + + first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0] + for k in param_group.keys(): + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + +@torch.no_grad() +def _get_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + info: _StateDictInfo, +) -> OptimizerStateType: + if not info.handle_optim: + return {} + + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} + for optim in optimizers: + _init_optim_state(optim) + osd = _state_dict_fn(optim, "state_dict")() + if info.fsdp_modules: + with info.fsdp_context(): + osd = FSDP.optim_state_dict(model, optim, osd) + + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + # There are no easy ways to do this conversion systematically. + # We can only use a string replacement without correctness check. + if not osd: + continue + for k in list(osd[_STATE].keys()): + if "_orig_mod" in k: + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params + else: + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)))) + fqn_pid_mapping = {} + for key, param in model.named_parameters(): + fqns = _get_fqns(model, key) + assert len(fqns) == 1 + fqn = next(iter(fqns)) + if param not in param_pid_mapping: + continue + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + + for key in list(osd[_STATE].keys()): + fqn = fqn_pid_mapping[key] + osd[_STATE][fqn] = osd[_STATE].pop(key) + + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] + + if not osd: + continue + + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) + + if info.flatten_optimizer_state_dict: + optim_state_dict = cast( + OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) + ) + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) + + +def _split_optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> OptimizerStateType: + """ + Extract the corresponding optim state_dict from ``optim_state_dict`` for + ``optim`` and return the result optim state_dict. + + Args: + model (nn.Module): the root model. + optim (torch.optim.Optimizer): the optimizer. 
+ optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that + contains the optim state_dict of ``optim``. + info (_StateDictInfo): state dict information. + + Returns: + The optim state_dict of ``optim``. + """ + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + pg_mapping: dict[int, int] = {} + + if all( + isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys() + ): + return optim_state_dict + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) + params.append(fqn) + if param.requires_grad: + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + if len(param_group[_PARAMS]) == 0: + # Param_group with empty params. + ret = [] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if len(cast(list[str], loaded_param_group[_PARAMS])) == 0: + ret.append(loaded_param_group) + if len(ret) != 1: + raise ValueError( + "There are param groups that have zero parameters. " + "In such a case, DSD only support exactly one param group " + "with zero parameters." + "But the loaded state_dict has zero or more than one param groups " + "that have zero parameters." + ) + if len(optim_state_dict[_PG]) != len(optim.param_groups): + raise ValueError( + "When there is a parameter group that has zero parameters, " + "multiple optimizers are not supported." + ) + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): + pg_idx = pg_mapping.get(id(param_group), -1) + if pg_idx == -1: + continue + + for key, value in param_group.items(): + if key == _PARAMS: + continue + # TODO: check if value is the same if exists. + pg_state[pg_idx][key] = value + + return return_osd + + +@torch.no_grad() +def _load_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + if not info.handle_optim: + return + + for optim in optimizers: + _init_optim_state(optim) + if state_dict: + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict( + model, optim, state_dict, info + ) + else: + optim_state_dict = _unflatten_optim_state_dict( + optim, cast(dict[str, ValueType], state_dict), info + ) + else: + optim_state_dict = {} + if info.fsdp_modules: + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. 
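            # (The loop below also restores compiler prefixes such as "_orig_mod" in
            #  the optimizer keys so that they match what
            #  FSDP.optim_state_dict_to_load expects for a torch.compile-wrapped model.)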
+ for original_fqn, _ in model.named_parameters(): + fqns = _get_fqns(model, original_fqn) + fqns_with_compiler = _get_fqns( + model, original_fqn, skip_compiler_prefix=False + ) + if fqns == fqns_with_compiler: + continue + + assert len(fqns) == 1 + fqn = fqns.pop() + fqn_with_compiler = fqns_with_compiler.pop() + for g in optim_state_dict[_PG]: + val = cast(dict[str, Any], g) + params = [ + key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS] + ] + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) + for k in list(osd_state.keys()): + if fqn in k: + osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) + + with info.fsdp_context(): + optim_state_dict = FSDP.optim_state_dict_to_load( + model, optim, optim_state_dict + ) + elif info.full_state_dict: + info.full_state_dict = False + local_state_dict = _get_optim_state_dict(model, (optim,), info) + info.full_state_dict = True + device = None + + def _device(t): + if t.dim() > 0: + nonlocal device + if device is None: + device = t.device + elif device != t.device: + raise ValueError("Device mismatch") + return t + + _ = tree_map_only(torch.Tensor, _device, local_state_dict) + assert device is not None + flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) + flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) + if info.broadcast_from_rank0: + _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + else: + _distribute_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd.keys(): + if optim_key not in flatten_local_osd: + assert optim_key in osd_mapping + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] + optim_state_dict = _unflatten_state_dict( + flatten_local_osd, local_osd_mapping + ) + for pg in optim_state_dict[_PG]: + if _PARAMS not in pg: + cast(dict[str, ValueType], pg)[_PARAMS] = [] + + # Note that we do not have to convert the FQN back to param id here if + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. + _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) + + +def get_model_state_dict( + model: nn.Module, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> dict[str, ValueType]: + """ + Return the model state_dict of ``model``. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``model``. 
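        A minimal local sketch (illustrative only; assumes a plain ``nn.Linear`` that
        is not wrapped by any parallelism):

            >>> # xdoctest: +SKIP
            >>> import torch.nn as nn
            >>> from torch.distributed.checkpoint.state_dict import (
            ...     StateDictOptions,
            ...     get_model_state_dict,
            ... )
            >>> model = nn.Linear(4, 4)
            >>> msd = get_model_state_dict(
            ...     model, options=StateDictOptions(cpu_offload=True)
            ... )
            >>> sorted(msd.keys())
            ['bias', 'weight']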
+ + :rtype: typing.Dict[str, ValueType] + """ + with _gc_context(): + info = _verify_options( + model, + (), + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + _verify_state_dict(model_state_dict, {}, info) + return model_state_dict + + +def get_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> OptimizerStateType: + """ + Return the combined state_dict for optimizers. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``optimizers``. + + :rtype: OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=True, + submodules=submodules, + options=options, + ) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict({}, optim_state_dict, info) + return optim_state_dict + + +def get_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> tuple[dict[str, ValueType], OptimizerStateType]: + """ + Return the model state_dict and optimizers state_dict. + + ``get_state_dict`` can process any module that is parallelized by PyTorch + FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any + combination of these parallelisms. The main functions of ``get_state_dict`` + are: 1.) returning a model and optimizer state_dict that can be resharded + with a different number of trainers and/or different parallelisms. + 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call + these APIs. + 3.) sanity checking the result state_dict. + + The keys of the result state dictionary are the canonical FQNs (Fully + Qualified Names). A canonical FQN refers to the FQN based on a parameter's + position in an nn.Module hierarchy. More specifically, a canonical FQN to a + parameter is the FQN returned by ``module.named_parameters()`` or + ``module.named_buffers()`` when the module is not distributed by any + parallelisms. Since the optimizer internally uses parameter IDs to represent + a parameter, there will be a conversion from the parameter IDs to the + canonical FQNs when calling this API. + + ``get_state_dict`` can also process a module that is not parallelized. In + such a case, ``get_state_dict`` only performs one function -- converting the + optimizer parameter IDs to the canonical FQNs. 
+ + Example: + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch.distributed.checkpoint.state_dict import get_state_dict + + >>> fsdp_model = FSDP(copy.deepcopy(model)) + >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + >>> ddp_model = DDP(copy.deepcopy(model)) + >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + + + >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( + ... fsdp_model, fsdp_optim + ... ) + + >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), + >>> # the asserts will fail. + >>> assert ddp_state_dict == fsdp_state_dict + >>> assert ddp_optim_state == fsdp_optim_state_dict + + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + ``Tuple`` that contain model state_dict and optimizer state_dict. + + :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] + """ + + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict(model_state_dict, optim_state_dict, info) + return model_state_dict, optim_state_dict + + +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]], +) -> dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + ) + cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict) + new_state_dict: dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" + prefix = f"{next(iter(fqns))}." + new_state_dict.update( + {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} + ) + return new_state_dict + else: + return cast(dict[str, ValueType], state_dict) + + +def set_model_state_dict( + model: nn.Module, + model_state_dict: dict[str, ValueType], + *, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. 
+ model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + info = _verify_options(model, (), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + +def set_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + optim_state_dict: OptimizerStateType, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Load the optimizers state_dict. + + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after + ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be + initialized correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + + +def set_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict and optimizers state_dict. + + The counterpart of ``get_state_dict`` to set the state_dict to the model and + optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not + have to be returned by ``get_state_dict`` but must meet the following + requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, + 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, + 3) optimizer state_dict cannot contain the parameter IDs; the keys should be + the canonical FQNs. + + WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()`` + is called on the optimizers. Otherwise, the optimizer states won't be initialized + correctly. 
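    A minimal round-trip sketch (illustrative only; assumes a plain, non-distributed
    model, with the load performed before any ``backward()`` as required by the
    warning above):

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>> from torch.distributed.checkpoint.state_dict import (
        ...     get_state_dict,
        ...     set_state_dict,
        ... )
        >>> model = nn.Linear(4, 4)
        >>> optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        >>> msd, osd = get_state_dict(model, optim)
        >>> _ = set_state_dict(
        ...     model, optim, model_state_dict=msd, optim_state_dict=osd
        ... )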
+ + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys of the model state_dict. + * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. + + :type model_state_dict: typing.Dict[str, ValueType] + :type optim_state_dict: typing.OptimizerStateType + """ + + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, optimizers, optim_only=not model_state_dict, options=options + ) + + _verify_state_dict(model_state_dict, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + return _load_model_state_dict(model, model_state_dict, info) + + +# TODO: correct the state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_model_state_dict( + model: nn.Module, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_model_state_dict, + model=model, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + model.state_dict = state_dict_call + + _load_state_dict_call = functools.partial( + set_model_state_dict, + model=model, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(model_state_dict=state_dict) + + model.load_state_dict = load_state_dict_call + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + + +# TODO: correct the load_state_dict function signature. +# TODO: this API is not yet fully tested. 
Make it private +@no_type_check +def _patch_optimizer_state_dict( + model: nn.Module, + *, + optimizers: tuple[torch.optim.Optimizer, ...], + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Note that if there are multiple optimizers, all of the optimizers will be patched. + So users only need to call one of the state_dict() to get the full result. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + _load_state_dict_call = functools.partial( + set_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(optim_state_dict=state_dict) + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + for optim in optimizers: + optim.state_dict = state_dict_call + optim.load_state_dict = load_state_dict_call diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict_loader.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ef1c472578e3b4d3d785441bd3b6368c89ba1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict_loader.py @@ -0,0 +1,328 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import os +import warnings +from typing import Any, cast, Optional, Union +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner +from torch.distributed.checkpoint.logger import _dcp_method_logger +from torch.distributed.checkpoint.stateful import Stateful + +from ._storage_utils import _storage_setup +from .default_planner import DefaultLoadPlanner +from .planner import LoadPlan, LoadPlanner +from .storage import StorageReader +from .utils import _api_bc_check, _DistWrapper, _profile + + +__all__ = ["load_state_dict", "load"] + + +@deprecated( + "`load_state_dict` is deprecated and will be removed in future versions. " + "Please use `load` instead.", + category=FutureWarning, +) +def load_state_dict( + state_dict: dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + """This method is deprecated. Please switch to 'load'.""" + storage_reader.reset() + with _profile(): + # TODO: test returning `load` here instead. 
+ return _load_state_dict( + state_dict, + storage_reader, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) +@_api_bc_check +def load( + state_dict: dict[str, Any], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + planner: Optional[LoadPlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, +) -> None: + """ + Load a checkpoint into a distributed state dict in SPMD style. + + Each rank must have the same keys in their ``state_dict`` provided to this + API. Mismatched keys may result in hangs or errors. If unsure, you can use + the ``utils._assert_same_keys`` API to check (but may incur communication + costs). + + Each rank will try to read the least amount of data necessary + to fulfill the requested `state_dict`. When loading :class:`ShardedTensor` + or :class:`DTensor` instances, each rank only reads data for their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + load will first call ``state_dict`` before attempting deserialization, followed by + ``load_state_dict`` once the deserialization is complete. + For each non-``Stateful`` object, load will deserialize the object, and then replace + it in the ``state_dict`` with the deserialized object. + + .. warning:: + All tensors in ``state_dict`` must be allocated on their + destination device *prior to* calling this function. + + All non-tensor data is loaded using `torch.load()` and modified in place + on state_dict. + + .. warning:: + Users must call `load_state_dict` on the root module to ensure load + pos-processing and non-tensor data properly propagates. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + state_dict (Dict[str, Any]): The state_dict to load the checkpoint into. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[LoadPlanner]): + Instance of LoadPlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + no_dist (bool): If ``True``, this function will assume the intent is to load + a checkpoint without using cross-rank synchronization. (Default: ``False``) + Returns: + None. + + Examples + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + >>> optimizer = Adagrad(my_model.parameters()) + >>> model_state_dict = my_model.state_dict() + >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( + ... "/checkpoint/1" + ... 
) + + >>> torch.distributed.checkpoint.load_state_dict( + >>> state_dict=model_state_dict, + >>> storage_reader=fs_storage_reader, + >>> ) + + >>> # module.load_state_dict() function might have customized steps + >>> # to flush the state_dict, must call it to + >>> # ensure correct behavior. + >>> my_model.load_state_dict(model_state_dict) + + .. note:: + load_state_dict uses collectives to coordinate reads across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``torch.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that each + rank has an individual GPU, via ``torch.cuda.set_device()``. + """ + + no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process." + ) + + with _profile(): + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + # All ranks must have the same keys in their `state_dict` provided to + # this API. See documentation for more details. + # Here we simply sort the keys to ensure that all ranks load values in + # the same order. + keys = sorted(state_dict.keys()) + + statetful_sd = {} + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + statetful_sd[key] = ( + elem.state_dict() if isinstance(elem, Stateful) else elem + ) + + _load_state_dict( + state_dict=statetful_sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=planner, + ) + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + if isinstance(elem, Stateful): + # If the state_dict is a Stateful object, + # DCP does an in-place load in the original state dict. + elem.load_state_dict(statetful_sd[key]) + else: + # Otherwise, replace the state_dict with the loaded state_dict. 
+                state_dict[key] = statetful_sd[key]
+
+
+def _load_state_dict(
+    state_dict: dict[str, Any],
+    storage_reader: StorageReader,
+    process_group: Optional[dist.ProcessGroup] = None,
+    coordinator_rank: int = 0,
+    no_dist: bool = False,
+    planner: Optional[LoadPlanner] = None,
+) -> None:
+    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
+
+    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
+    if planner is None:
+        planner = DefaultLoadPlanner()
+
+    ckpt_kwargs = {}
+    if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
+        ckpt_kwargs["checkpoint_id"] = ckpt_id
+        ckpt_kwargs["process_group"] = distW.group
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def local_step():
+        assert planner is not None
+        metadata = storage_reader.read_metadata()
+        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
+        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
+
+        local_plan = planner.create_local_plan()
+        local_plan = storage_reader.prepare_local_plan(local_plan)
+        return local_plan
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def global_step(all_local_plans):
+        assert planner is not None
+        all_local_plans = planner.create_global_plan(all_local_plans)
+        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
+        return all_local_plans
+
+    central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def read_data():
+        assert planner is not None
+        final_local_plan = planner.finish_plan(central_plan)
+        all_reads = storage_reader.read_data(final_local_plan, planner)
+
+        all_reads.wait()
+        return None
+
+    _ = distW.all_gather("read", read_data)
+
+
+def _load_state_dict_from_keys(
+    keys: Optional[Union[set[str], str]] = None,
+    *,
+    checkpoint_id: Union[str, os.PathLike, None] = None,
+    storage_reader: Optional[StorageReader] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> dict[str, Any]:
+    """
+    Load only the specified keys from the checkpoint. If no keys are specified, the entire
+    checkpoint will be loaded. Note that this method completely loads the checkpoint into the
+    current process and is not distributed.
+
+    .. warning::
+
+        All non-tensor data is loaded using `torch.load()`
+
+    .. note:
+        As opposed to the usual pattern, this function does not take a state dict as input
+        and does not load in place. Instead, a new state dict is directly initialized and read
+        from file.
+
+    .. note:
+        If no process group is initialized, this function will assume the intent
+        is to load a checkpoint into the local process. This can be useful in the
+        case of local inference, and when using regular Tensors (as opposed to DTensor
+        or ShardedTensor)
+
+    .. note:
+        Rank 0 is assumed to be the coordinator rank.
+
+    Args:
+        keys (Optional[Union[set[str], str]]):
+            Loads any key specified in this set. If no keys are specified, the entire checkpoint
+            is loaded.
+        checkpoint_id (Union[str, os.PathLike, None]):
+            The ID of this checkpoint instance. The meaning of the checkpoint_id
+            depends on the storage. It can be a path to a folder or to a file.
+            It can also be a key if the storage is a key-value store.
+            (Default: ``None``)
+        storage_reader (Optional[StorageReader]):
+            Instance of StorageReader used to perform reads. If this is not
+            specified, DCP will automatically infer the reader based on the
+            checkpoint_id. If checkpoint_id is also None, an exception will
+            be raised.
(Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + State dict from specified keys + """ + torch._C._log_api_usage_once( + "torch.distributed.checkpoint._load_state_dict_from_keys" + ) + + no_dist = not (dist.is_available() and dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process." + ) + + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + if isinstance(keys, str): + keys = {keys} + + sd: dict[str, Any] = {} + _load_state_dict( + state_dict=sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=_EmptyStateDictLoadPlanner(keys=keys), + ) + + return sd diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict_saver.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..d15a93727483734346363d83e53c65d956d4a9ed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/state_dict_saver.py @@ -0,0 +1,379 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import os +import warnings +from concurrent.futures import Future +from enum import Enum +from typing import cast, Optional, Union +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint._async_executor import ( # noqa: TC001 + _AsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._async_process_executor import ( + _ProcessBasedAsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._async_thread_executor import ( + _ThreadBasedAsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._storage_utils import _storage_setup +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from torch.distributed.checkpoint.logger import _dcp_method_logger +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner +from torch.distributed.checkpoint.staging import AsyncStager +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.storage import StorageWriter +from torch.distributed.distributed_c10d import _get_default_group + +from .utils import _api_bc_check, _DistWrapper, _profile + + +__all__ = ["save_state_dict", "save", "async_save", "AsyncCheckpointerType"] + + +class AsyncCheckpointerType(Enum): + """Enum for async checkpointer type.""" + + THREAD = "thread" + PROCESS = "process" + + +@deprecated( + "`save_state_dict` is deprecated and will be removed in future versions." + "Please use `save` instead.", + category=FutureWarning, +) +def save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + """This method is deprecated. Please switch to 'save'.""" + storage_writer.reset() + + # TODO: test returning `save` here instead. 
+ with _profile(): + return _save_state_dict( + state_dict, + storage_writer, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) # type: ignore[arg-type] +@_api_bc_check +def save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, +) -> Metadata: + """ + Save a distributed model in SPMD style. + + This function is different from ``torch.save()`` as it handles + ``ShardedTensor`` , and ``DTensor`` by having each rank only save their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + save will call ``state_dict`` before serialization. + + .. warning:: + There is no guarantees of Backwards Compatibility across PyTorch versions + for saved state_dicts. + + .. warning:: + If using the `process_group` argument, make sure that only its ranks + call `save_state_dict` and that all data in state_dict belong to it. + + .. note:: + When saving checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of + the shard_group should be calling `save_state_dict` and the corresponding process + group needs to be passed in. + + .. note:: + If no process group is available, this function assumes the intention is to save the + state_dict in the local process. + + .. note: + Rank 0 is assumed to be the coordinator rank. + + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform writes. If this is not + specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + no_dist (bool): + If ``True``, this function will assume the intent is to load + a checkpoint without using cross-rank synchronization. + (Default: ``False``) + + Returns: + Metadata: Metadata object for the saved checkpoint. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) + >>> torch.distributed.checkpoint.save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + + .. note:: + save_state_dict uses collectives to coordinate writes across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``torch.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that + each rank has an individual GPU, via ``torch.cuda.set_device()``. 
+ """ + torch._C._log_api_usage_once("torch.distributed.checkpoint.save") + + no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process." + ) + + with _profile(): + storage_writer = cast( + StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) + ) + + return _save_state_dict( + state_dict=_stateful_to_state_dict(state_dict), + storage_writer=storage_writer, + process_group=process_group, + no_dist=no_dist, + planner=planner, + ) + + +@_dcp_method_logger(log_exceptions=True) +def async_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD, +) -> Future: + """Asynchronous version of ``save``. This code first de-stages the state_dict on to the + staging storage (defaults to CPU memory), and then calls the `save` in a separate thread. + + .. warning:: + This feature is experimental and subject to change. + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform 'stage' and 'save'. If + this is not specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + Future: A future holding the resultant Metadata object from `save`. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) + >>> checkpoint_future = torch.distributed.checkpoint.async_save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + >>> + >>> # ... do some work ... 
+ >>> + >>> checkpoint_future.result() + + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save") + + if dist.is_available() and dist.is_initialized(): + pg = process_group or _get_default_group() + assert ( + torch.device("cpu") in pg._device_types # type: ignore[attr-defined] + ), ( + "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" + ) + + storage_writer = cast( + StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) + ) + + state_dict = _stateful_to_state_dict(state_dict) + + @_dcp_method_logger(log_exceptions=True) + def stage_state_dict(): + if isinstance(storage_writer, AsyncStager): + staged_state_dict = storage_writer.stage(state_dict) + else: # provides bwc for storage_writers not implementing AsyncStager + staged_state_dict = _create_cpu_state_dict(state_dict) + _copy_state_dict(state_dict, staged_state_dict, type_check=False) + + return staged_state_dict + + staged_state_dict = stage_state_dict() + + executor: _AsyncCheckpointExecutor = ( + _ProcessBasedAsyncCheckpointExecutor() + if async_checkpointer_type == AsyncCheckpointerType.PROCESS + else _ThreadBasedAsyncCheckpointExecutor() + ) + + f: Future = executor.execute_save( + staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + ) + + @_dcp_method_logger(log_exceptions=True) + def maybe_synchronize_staging(): + if ( + isinstance(storage_writer, AsyncStager) + and storage_writer.should_synchronize_after_execute + ): + storage_writer.synchronize_staging() + + maybe_synchronize_staging() + + return f + + +@_dcp_method_logger(log_exceptions=True) +def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object.""" + stateful_state_dict = {} + for key, elem in state_dict.items(): + stateful_state_dict[key] = ( + elem.state_dict() if isinstance(elem, Stateful) else elem + ) + return stateful_state_dict + + +def _save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metadata = None + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + assert planner is not None + storage_meta = storage_writer.storage_meta() + if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters: + warnings.warn( + "The function definition for SavePlanner.set_up_planner has been updated" + " to include the storage_meta argument. Please update your implementation" + " to include this parameter." 
+ ) + planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type] + else: + planner.set_up_planner( + state_dict=state_dict, + storage_meta=storage_meta, + is_coordinator=distW.is_coordinator, + ) + storage_writer.set_up_storage_writer(distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + nonlocal global_metadata + + assert planner is not None + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step) + + @_dcp_method_logger(**ckpt_kwargs) + def write_data(): + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + all_writes = storage_writer.write_data(final_local_plan, planner) + + all_writes.wait() + return all_writes.value() + + @_dcp_method_logger(**ckpt_kwargs) + def finish_checkpoint(all_results): + assert global_metadata is not None + storage_writer.finish(metadata=global_metadata, results=all_results) + return global_metadata + + return distW.all_reduce("write", write_data, finish_checkpoint) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/stateful.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/stateful.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b045fe143193550285b8316633643834c8d2d5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/stateful.py @@ -0,0 +1,42 @@ +from typing import Any, TypeVar +from typing_extensions import Protocol, runtime_checkable + + +__all__ = ["Stateful", "StatefulT"] + + +@runtime_checkable +class Stateful(Protocol): + """ + Stateful protocol for objects that can be checkpointed and restored. + """ + + def state_dict(self) -> dict[str, Any]: + """ + Objects should return their state_dict representation as a dictionary. + The output of this function will be checkpointed, and later restored in + `load_state_dict()`. + + .. warning:: + Because of the inplace nature of restoring a checkpoint, this function + is also called during `torch.distributed.checkpoint.load`. + + + Returns: + Dict: The objects state dict + """ + + ... + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Restore the object's state from the provided state_dict. + + Args: + state_dict: The state dict to restore from + """ + + ... 
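Editorial aside, not part of the diff: the ``Stateful`` protocol above is what ``save``/``load`` test for when deciding whether to call ``state_dict``/``load_state_dict`` on an entry. A minimal conforming object might look like the sketch below; the class name and its single field are invented for illustration.

# Sketch of an object satisfying the Stateful protocol defined above.
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful


class TrainState:
    """Tracks a step counter so it can be checkpointed next to model/optimizer state."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict[str, Any]:
        # Called by save() before serialization, and by load() to build the
        # in-place target before deserialization (see state_dict_loader above).
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Called by load() once deserialization has completed.
        self.step = state_dict["step"]


# Stateful is a runtime_checkable Protocol, so conformance is structural:
assert isinstance(TrainState(), Stateful)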
+ + +StatefulT = TypeVar("StatefulT", bound=Stateful) diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/storage.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..46ce87f3f59807000e7a620867c5aa077f8f2326 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/storage.py @@ -0,0 +1,284 @@ +import abc +import os +from dataclasses import dataclass +from typing import Any, Optional, Union + +from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + SavePlan, + SavePlanner, +) +from torch.futures import Future + + +__all__ = ["WriteResult", "StorageWriter", "StorageReader"] + + +@dataclass(frozen=True) +class WriteResult: + index: MetadataIndex + + size_in_bytes: int + storage_data: Any + + +class StorageWriter(abc.ABC): + """ + Interface used by ``save_state_dict`` to write to storage. + + One StorageWriter instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expect the following sequence of calls. + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) set_up_storage_writer() + 2) (all ranks) prepare_local_plan() + 3) (coordinator) prepare_global_plan() + 4) (all ranks) write_data() + 5) (coordinator) finish() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint write is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint write. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def set_up_storage_writer(self, is_coordinator: bool) -> None: + """ + Initialize this instance. + + Args: + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in SavePlan::storage_data. + + Args: + plan (SavePlan): The local plan from the ``SavePlanner`` in use. + + Returns: + A transformed ``SavePlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + """ + Perform centralized planning of storage. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in SavePlan::storage_data. + + Args: + plans: A list of ``SavePlan`` instances, one for each rank. 
+ + Returns: + A list of transformed ``SavePlan`` after storage global planning + """ + + @abc.abstractmethod + def write_data( + self, plan: SavePlan, planner: SavePlanner + ) -> Future[list[WriteResult]]: + """ + Write all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``SavePlanner::resolve_data`` on each item + from the plan to get access to the underlying object to write. + + Subclasses should lazily call `resolve_data` as it can allocate memory. + In case of tensors, make following assumptions: + + - They might be on any device, including not matching the one on ``WriteItem::tensor_data`` + - They might be views or not contiguous. Only the projection needs to be saved. + + Args: + plan (SavePlan): The save plan to execute. + planner (SavePlanner): Planner object to be used to resolve items to data. + + Returns: + A future that completes to a list of WriteResult + """ + + @abc.abstractmethod + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + """ + Write the metadata and marks the current checkpoint as successful. + + The actual format/schema used for serializing `metadata` is an + implementation detail. The only requirement is that it's recoverable + in to the same object graph. + + Args: + metadata (Metadata): metadata for the new checkpoint + results: A list of WriteResults from all ranks. + + Returns: + None + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the storage. This allow + us to enable automatic storage selection. + """ + ... + + def storage_meta(self) -> Optional[StorageMeta]: + """ + Return the storage-specific metadata. This is used to store additional information + in a checkpoint that can be useful for providing request-level observability. StorageMeta + is passed to the ``SavePlanner`` during save calls. Returns None by default. + + TODO: provide an example + """ + return None + + +class StorageReader(abc.ABC): + """ + Interface used by ``load_state_dict`` to read from storage. + + One StorageReader instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expected the following sequence of calls by ``load_state_dict``: + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) read_metadata() + 2) (all ranks) set_up_storage_reader() + 3) (all ranks) prepare_local_plan() + 4) (coordinator) prepare_global_plan() + 5) (all ranks) read_data() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint read is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint read. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is more like a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def read_metadata(self) -> Metadata: + """ + Read the checkpoint metadata. + + Returns: + The metadata object associated with the checkpoint being loaded. 
+
+        """
+
+    @abc.abstractmethod
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+        """
+        Initialize this instance.
+
+        Args:
+            metadata (Metadata): The metadata schema to use.
+            is_coordinator (bool): Whether this instance is responsible for coordinating
+                the checkpoint.
+        """
+
+    @abc.abstractmethod
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        """
+        Perform storage-specific local planning.
+
+        While this method can produce a completely different plan, the recommended
+        way is to store storage specific data in LoadPlan::storage_data.
+
+        Args:
+            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.
+
+        Returns:
+            A transformed ``LoadPlan`` after storage local planning
+        """
+
+    @abc.abstractmethod
+    def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
+        """
+        Perform centralized planning of storage loading.
+
+        This method is only called on the coordinator instance.
+
+        While this method can produce a completely different plan, the preferred
+        way is to store storage specific data in LoadPlan::storage_data.
+
+        Args:
+            plans: A list of ``LoadPlan`` instances, one for each rank.
+
+        Returns:
+            A list of transformed ``LoadPlan`` after storage global planning
+        """
+
+    @abc.abstractmethod
+    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
+        """
+        Read all items from ``plan`` using ``planner`` to resolve the data.
+
+        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
+        object into the right place.
+
+        A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
+        tensors that it should load data into.
+
+        It's the StorageLayer's responsibility to properly schedule any cross device copies
+        required.
+
+        Args:
+            plan (LoadPlan): The local plan to execute on
+            planner (LoadPlanner): The planner object to use to resolve items.
+
+        Returns:
+            A future that completes once all reads are finished.
+        """
+
+    @classmethod
+    @abc.abstractmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        """
+        Check if the given checkpoint_id is supported by the storage. This allows
+        us to enable automatic storage selection.
+        """
+        ...
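Editorial aside, not part of the diff: the loader/saver entry points and the two storage interfaces above compose into a simple round trip. The sketch below is a single-process example under stated assumptions: the tiny ``Linear`` model and the temporary directory are invented, ``FileSystemWriter``/``FileSystemReader`` are the built-in filesystem implementations of these interfaces, and with no process group initialized ``save()``/``load()`` fall back to the ``no_dist`` path shown earlier.

# Sketch: single-process save/load round trip through the filesystem storage layer.
import tempfile

import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(4, 4)

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Exercises the StorageWriter sequence documented above:
    # set_up_storage_writer -> prepare_local_plan -> prepare_global_plan -> write_data -> finish
    dcp.save(
        {"model": model.state_dict()},
        storage_writer=dcp.FileSystemWriter(ckpt_dir),
    )

    # Exercises the StorageReader sequence documented above:
    # read_metadata -> set_up_storage_reader -> prepare_local_plan -> prepare_global_plan -> read_data
    target = {"model": torch.nn.Linear(4, 4).state_dict()}
    dcp.load(
        target,
        storage_reader=dcp.FileSystemReader(ckpt_dir),
    )
    # `target` now holds the saved weights, loaded in place into preallocated tensors.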
diff --git a/phivenv/Lib/site-packages/torch/distributed/checkpoint/utils.py b/phivenv/Lib/site-packages/torch/distributed/checkpoint/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..594b12e50082580382158e85f2e3b1bdba2c4e9e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/checkpoint/utils.py @@ -0,0 +1,477 @@ +# mypy: allow-untyped-defs +import cProfile +import inspect +import io +import itertools +import os +import warnings +from collections.abc import Sequence +from contextlib import contextmanager +from functools import wraps +from pstats import Stats +from typing import Any, Callable, cast, Optional, TypeVar, Union + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharded_tensor.shard import Shard + +from .api import ( + _is_wrapped_exception, + _wrap_exception, + CheckpointException, + WRAPPED_EXCEPTION, +) +from .metadata import MetadataIndex, STATE_DICT_TYPE + + +__all__ = ["find_tensor_shard", "find_state_dict_object"] + +T = TypeVar("T") +R = TypeVar("R") + + +def _get_failure_dict( + results: list[Union[T, WRAPPED_EXCEPTION]], +) -> dict[int, WRAPPED_EXCEPTION]: + return cast( + dict[int, WRAPPED_EXCEPTION], + {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, + ) + + +def _all_gather_keys( + local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None +) -> set[str]: + """Gathers all keys, and returns them sorted.""" + keys = list(local_dict.keys()) + gathered_keys: list[list[str]] = [None] * dist.get_world_size(group) # type: ignore[list-item] + + dist.all_gather_object(gathered_keys, keys, group=group) + return set(itertools.chain.from_iterable(gathered_keys)) + + +def _assert_same_keys( + state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None +) -> None: + """ + Asserts that all ranks have the same keys in their state dict. + This is a collective call which requires all ranks in ``process_group`` to + join. It will also induce cross-rank communication and block CPU. + """ + + if dist.get_world_size(process_group) == 1: + return + + all_keys = _all_gather_keys(state_dict, process_group) + my_keys = set(state_dict.keys()) + diff = all_keys - my_keys + if len(diff) > 0: + raise AssertionError( + f"Key(s) present in other ranks but not this one, difference: {diff}" + ) + + +class _DistWrapper: + """ + This is a wrapper around PG that provides a series of features around object collectives. + + It works without distributed initialized, where most collectives turns into nops. + + All variants that take functions are exception robust, meaning that if one or more + ranks raise errors, all ranks will observe those. 
+ """ + + def __init__( + self, + group: Optional[dist.ProcessGroup], + use_dist: bool, + coordinator_rank: int, + ): + self.group = group + self.use_dist = use_dist + self.coordinator_rank = coordinator_rank + if self.use_dist: + self.global_coordinator_rank = ( + dist.get_global_rank(group, coordinator_rank) + if group is not None + else coordinator_rank + ) + self.rank = dist.get_rank(group) + self.is_coordinator = self.rank == coordinator_rank + else: + self.global_coordinator_rank = 0 + self.rank = 0 + self.is_coordinator = True + + def get_rank(self) -> int: + return self.rank + + def get_world_size(self) -> int: + if self.use_dist: + return dist.get_world_size(self.group) + return 1 + + def broadcast_object(self, object: Optional[T]) -> T: + """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled.""" + object_list = [object] + if self.use_dist: + dist.broadcast_object_list( + object_list=object_list, + group=self.group, + src=self.global_coordinator_rank, + ) + return cast(T, object_list[0]) + + def gather_object(self, object: T) -> Optional[list[T]]: + """Implement functionality similar to c10d::gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = ( + cast(list[T], [None] * dist.get_world_size(self.group)) + if self.is_coordinator + else None + ) + + dist.gather_object( + obj=object, + object_gather_list=gather_objs if self.is_coordinator else None, + dst=self.global_coordinator_rank, + group=self.group, + ) + result = gather_objs + else: + result = [object] + return result + + def all_gather_object(self, object: T) -> list[T]: + """Implement functionality similar to c10d::all_gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = cast(list[T], [None] * dist.get_world_size(self.group)) + + dist.all_gather_object( + object_list=gather_objs, obj=object, group=self.group + ) + else: + gather_objs = [object] + return gather_objs + + def scatter_object(self, object_list: Optional[list[T]]) -> T: + """Implement functionality similar to c10d::scatter_object but without distributed enabled.""" + if self.use_dist: + gather_result = cast(list[T], [None]) + dist.scatter_object_list( + scatter_object_output_list=gather_result, + scatter_object_input_list=object_list if self.is_coordinator else None, + src=self.global_coordinator_rank, + group=self.group, + ) + + local_reply = gather_result[0] + else: + assert object_list is not None + local_reply = object_list[0] + return local_reply + + def reduce_scatter( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[list[T]], list[R]], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Scatter to each rank part of the result. + """ + local_data: Union[WRAPPED_EXCEPTION, T] + try: + local_data = map_fun() + except BaseException as e: + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + all_results: Optional[list[Union[R, CheckpointException]]] = None + if self.is_coordinator: + assert all_data is not None + node_failures = _get_failure_dict(all_data) + + if len(node_failures) == 0: + try: + # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? 
+ all_results = cast( + list[Union[R, CheckpointException]], + reduce_fun(cast(list[T], all_data)), + ) + except BaseException as e: + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + all_results = [ + CheckpointException(step, node_failures) + ] * self.get_world_size() + + result = self.scatter_object(all_results) + if isinstance(result, CheckpointException): + raise result + return result + + def all_reduce( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[list[T]], R], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Broadcast the reduced value to all ranks. + """ + local_data: Union[T, WRAPPED_EXCEPTION] + try: + local_data = map_fun() + except BaseException as e: + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + result: Optional[Union[R, CheckpointException]] = None + if self.is_coordinator: + assert all_data is not None + node_failures = _get_failure_dict(all_data) + if len(node_failures) == 0: + try: + result = reduce_fun(cast(list[T], all_data)) + except BaseException as e: + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + result = CheckpointException(step, node_failures) + + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(R, final_result) + + def all_gather( + self, + step: str, + map_fun: Callable[[], T], + ) -> list[T]: + """ + Compute a value on each rank, then all_gather them. + + This method operates in the following way: + Run ``map_cp`` on all ranks + all_gather the values to all ranks + """ + result: Union[T, WRAPPED_EXCEPTION] + try: + result = map_fun() + except BaseException as e: + result = _wrap_exception(e) + + all_results = self.all_gather_object(result) + + node_failures = _get_failure_dict(all_results) + if len(node_failures) > 0: + raise CheckpointException(step, node_failures) + return cast(list[T], all_results) + + def broadcast( + self, + step: str, + map_fun: Callable[[], T], + ) -> T: + """ + Compute a value on rank 0 and broadcast it. + + This method operates in the following way: + Run ``map_cp`` on rank 0 + broadcast the value + """ + result: Optional[Union[T, CheckpointException]] = None + if self.is_coordinator: + try: + result = map_fun() + except BaseException as e: + result = CheckpointException(step, {self.rank: _wrap_exception(e)}) + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(T, final_result) + + def barrier(self) -> None: + """ + Add a synchronization point across all processes when using distributed. + If torch.distributed is initialized, this function will invoke a barrier across the global process group. + If torch.distributed is not initialized, this function is a no-op. 
+ """ + if not self.use_dist: + return + dist.barrier(group=self.group) + + +def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: + if index.offset is None: + raise ValueError( + f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided" + ) + + shards = tensor.local_shards() + # index fast path + if index.index is not None: + if ( + len(shards) > index.index + and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset + ): + return shards[index.index] + + for shard in shards: + if torch.Size(shard.metadata.shard_offsets) == index.offset: + return shard + raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + + +def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: + if hasattr(tensor, "__get_tensor_shard__"): + # DTensor implements _Checkpointable + return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + if isinstance(tensor, ShardedTensor): + return _find_shard(tensor, index).tensor + if index.offset is not None: + # special case looking up a tensor by origin + if index.offset == torch.Size([0] * len(tensor.size())): + return tensor + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return tensor + + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + if index.fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{index.fqn}'") + obj = state_dict[index.fqn] + + if isinstance(obj, torch.Tensor): + return find_tensor_shard(obj, index) + elif index.offset is not None: + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return obj + + +def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]: + return [i_a + i_b for i_a, i_b in zip(a, b)] + + +def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]: + return [i_a - i_b for i_a, i_b in zip(a, b)] + + +class _ReaderView(io.IOBase): + def __init__(self, base_stream: io.IOBase, offset: int, len: int): + super().__init__() + self.offset = offset + self.len = len + self.base_stream = base_stream + self.seek(0) + + def seek(self, offset: int, whence: int = os.SEEK_SET, /) -> int: + if whence == os.SEEK_SET: + offset = self.offset + offset + elif whence == os.SEEK_END: + whence = os.SEEK_SET + offset = (self.offset + self.len) - offset + return self.base_stream.seek(offset, whence) + + def tell(self) -> int: + return self.base_stream.tell() - self.offset + + def readable(self) -> bool: + return self.base_stream.readable() + + def seekable(self) -> bool: + return self.base_stream.seekable() + + def readinto(self, b): + max_size = self.len - self.tell() + if max_size == 0: + return 0 + if len(b) > max_size: + b = memoryview(b)[:max_size] + return self.base_stream.readinto(b) # type: ignore[attr-defined] + + def read(self, size=-1): + max_size = self.len - self.tell() + if size == -1 or size > max_size: + size = max_size + return self.base_stream.read(size) + + +def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase: + # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader + return _ReaderView(file, offset, length) + + +def _normalize_device_info(device_type: str, device_id: int) -> str: + """Device info normalization.""" + if device_type == "cpu": + return "cpu" + return f"{device_type}:{device_id}" + + +# TODO: integrate with distributed logging flag +ENABLE_PROFILE = False + + 
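Editorial aside, not part of the diff: ``_ReaderView``/``_create_file_view`` above expose a byte range of a larger checkpoint file as an independent seekable stream, with ``read``/``seek``/``tell`` expressed relative to ``offset`` and capped at ``len``. The sketch below only illustrates that offset arithmetic; these are internal helpers, not a supported API, and the byte string is invented.

# Sketch: viewing a byte range of a stream through the internal _create_file_view helper.
import io

from torch.distributed.checkpoint.utils import _create_file_view

base = io.BytesIO(b"headerPAYLOADtrailer")
view = _create_file_view(base, offset=6, length=7)  # window over b"PAYLOAD"

assert view.read() == b"PAYLOAD"  # reads are capped at the view's length
view.seek(0)                      # positions are relative to the view's offset
assert view.read(3) == b"PAY"
assert view.tell() == 3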
+@contextmanager +def _profile(): + # Only log the profiling when it is enable and is on rank0 or dist is not + # available. + if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0): + profiler = cProfile.Profile() + profiler.enable() + try: + yield + finally: + profiler.disable() + stats = Stats(profiler) + stats.sort_stats("time").print_stats(10) + else: + yield + + +def _api_bc_check(func): + @wraps(func) + def inner_func(*args, **kwargs) -> Any: + if len(args) == 2: + warnings.warn( + f"The argument order of {func.__name__} has been changed. " + "Please check the document to avoid future breakages." + ) + sig = inspect.signature(func) + kwonlyargs = [ + p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY + ] + if "storage_writer" in kwonlyargs: + assert "storage_writer" not in kwargs, (args, kwargs) + kwargs["storage_writer"] = args[1] + elif "storage_reader" in kwonlyargs: + assert "storage_reader" not in kwargs, (args, kwargs) + kwargs["storage_reader"] = args[1] + else: + raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}") + return func(args[0], **kwargs) + else: + return func(*args, **kwargs) + + return inner_func diff --git a/phivenv/Lib/site-packages/torch/distributed/collective_utils.py b/phivenv/Lib/site-packages/torch/distributed/collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7d4c5c0f722d0b57534271690809479f5edb93 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/collective_utils.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 + + +""" +A set of primitive functions for performing collective ops. + +Each should also handle single rank scenario. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, cast, Generic, Optional, TypeVar, Union + +import torch.distributed as dist + + +T = TypeVar("T") + + +@dataclass +class SyncPayload(Generic[T]): + stage_name: Optional[str] + success: bool + payload: T + exception: Optional[Exception] = None + + +def broadcast( + data_or_fn: Union[T, Callable[[], T]], + *, + success: bool = True, + stage_name: Optional[str] = None, + rank: int = 0, + pg: Optional[dist.ProcessGroup] = None, +) -> T: + """ + Broadcasts the data payload from rank 0 to all other ranks. + Or if a function is passed, execute it in rank 0 and broadcast result to all other ranks. + + Can be used to broadcast a failure signal to stop all ranks. + + If the function raises an exception, all ranks will raise. + + Args: + data_or_fn: the data to broadcast or function to execute and broadcast result. + success: False to stop all ranks. + stage_name: the name of the logical stage for synchronization and debugging + rank: rank to broadcast data or execute function and broadcast results. 
+ pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + the value after synchronization + + Example usage: + >> id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg) + """ + + if not success and data_or_fn is not None: + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) + + payload: Optional[T] = None + exception: Optional[Exception] = None + # if no pg is passed then execute if rank is 0 + if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + # broadcast the exception type if any to all ranks for failure categorization + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + broadcast_list = [sync_obj] + dist.broadcast_object_list(broadcast_list, src=rank, group=pg) + assert len(broadcast_list) == 1 + sync_obj = broadcast_list[0] + + # failure in any rank will trigger a throw in every rank. + if not sync_obj.success: + error_msg = f"Rank {rank} failed" + if stage_name is not None: + error_msg += f": stage {sync_obj.stage_name}" + if sync_obj.exception is not None: + error_msg += f": exception {sync_obj.exception}" + raise RuntimeError(error_msg) from sync_obj.exception + + return cast(T, sync_obj.payload) + + +def all_gather( + data_or_fn: Union[T, Callable[[], T]], + stage_name: Optional[str] = None, + pg: Optional[dist.ProcessGroup] = None, +) -> list[T]: + """ + A simple all_gather primitive with basic synchronization guard logic, + by checking payload from all ranks has the same stage name. + + Args: + data_or_fn: the data to be all gathered across ranks or function to be executed + stage_name: the sync stage name for out-of-sync protection + pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + a list of synced data from all ranks + + Example usage: + >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) + """ + payload: Optional[T] = None + exception: Optional[Exception] = None + success = True + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + # List of success/failure across all ranks. + total_list = [None] * dist.get_world_size(pg) + all_gather_object_enforce_type(pg, total_list, sync_obj) + # Each rank will throw RuntimeError in case of failure on any rank. 
+ stage_name = cast(SyncPayload[T], total_list[0]).stage_name + exception_list: list[tuple[int, Exception]] = [] + ret_list: list[T] = [] + error_msg: str = "" + + for i, sp in enumerate(cast(list[SyncPayload[T]], total_list)): + if sp.stage_name != stage_name: + error_msg += ( + f"Unexpected stage name received from rank {i}: {sp.stage_name} " + ) + continue + if not sp.success and sp.exception is not None: + exception_list.append((i, sp.exception)) + continue + ret_list.append(sp.payload) + + if len(exception_list) > 0: + raise RuntimeError( # type: ignore[misc] + error_msg, exception_list + ) from exception_list[0] + return ret_list + else: + if not sync_obj.success: + raise RuntimeError( + f"all_gather failed with exception {sync_obj.exception}", + ) from sync_obj.exception + return [sync_obj.payload] # type: ignore[list-item] + + +# Note: use Any for typing for now so users can pass in +# either a list of None or target type placeholders +# otherwise pyre would complain +def all_gather_object_enforce_type( + pg: dist.ProcessGroup, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + object_list: list[Any], + # pyre-fixme[2]: Parameter must have a type other than `Any` + obj: Any, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) == type(y), +) -> None: + """ + Similar to plain all_gather_object but with additional type checking + AFTER gather is done to ensure basic consistency. + If check does not pass, all ranks will fail with exception. + + This is generally to prevent conditional logic leading to + unexpected messages being received. This is considered fatal code error, + but due to logic stacks this might happen implicitly in practice. + + The default check does not check sub type (considered different) + or covariance (considered same) but users can pass in custom checker + if more complicated check is needed. + """ + dist.all_gather_object(object_list, obj, group=pg) + + # conservative check + list_len = len(object_list) + if list_len == 0: + return + first_obj = object_list[0] + for i in range(1, list_len): + if not type_checker(first_obj, object_list[i]): + raise TypeError( + f"Object type at index {i} is {type(object_list[i])}, " + f"while first object type is {type(first_obj)}" + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/constants.py b/phivenv/Lib/site-packages/torch/distributed/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd54e4b3fa00c240950e3ef0636456a39b460fa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/constants.py @@ -0,0 +1,26 @@ +from datetime import timedelta +from typing import Optional + +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] + +# Default process group wide timeout, if applicable. +# This only applies to the non-nccl backends +# To make an attempt at backwards compatibility with THD, we use an +# extraordinarily high default timeout, given that THD did not have timeouts. +default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT +# Separate timeout for PGNCCL mainly because it's always been that way in the C++ layer, but until recently +# there was one default that applied across all backends in the python layer. +# Later, we could consider merging them back together at the c++ layer if we can align on a same value. +# (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1). 
+ +try: + from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT + + default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT +except ImportError: + # if C++ NCCL support is not compiled, we don't have access to the default nccl value. + # if anyone is actually trying to use nccl in this state, it should error. + default_pg_nccl_timeout = None diff --git a/phivenv/Lib/site-packages/torch/distributed/device_mesh.py b/phivenv/Lib/site-packages/torch/distributed/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..f91799a5f2143ff66e7ad24261e3f7f6050bdc73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/device_mesh.py @@ -0,0 +1,1052 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import math +import os +import threading +import warnings +from functools import reduce +from itertools import chain +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch.distributed import is_available +from torch.utils._typing_utils import not_none + + +__all__ = ["init_device_mesh", "DeviceMesh"] + + +if not is_available(): + import sys + + # We need to create the stubs when distributed is not available. + # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```), + # since it would try to import ``torch.distributed.device_mesh`` or + # ``torch.distributed.init_device_mesh`` but cannot find them. + + class _DeviceMeshStub: + pass + + def _init_device_mesh_stub(): + pass + + sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined] + sys.modules[ + "torch.distributed.device_mesh" + ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined] + + +else: + from torch._C._distributed_c10d import Backend as C10dBackend + from torch.distributed.distributed_c10d import ( + _get_default_group, + _resolve_process_group, + get_backend, + get_process_group_ranks, + get_rank, + get_world_size, + init_process_group, + is_initialized, + new_group, + ProcessGroup, + split_group, + ) + + logger = logging.getLogger(__name__) + + # only import numpy typing when type checking + if TYPE_CHECKING: + try: + from numpy.typing import ArrayLike + except ImportError: + logger.warning( + "DeviceMesh requires numpy >= 1.21 to be installed for type checking" + ) + + class _MeshEnv(threading.local): + def __init__(self) -> None: + self.mesh_stack: list[DeviceMesh] = [] + self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} + self.mesh_dim_group_options: dict[ + int, tuple[str, Optional[C10dBackend.Options]] + ] = {} + self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} + # Record flatten mesh name to its mesh dim index in root mesh. + self.flatten_name_to_root_dims: dict[ + DeviceMesh, dict[str, tuple[int, ...]] + ] = {} + + def get_current_mesh(self) -> "DeviceMesh": + if len(self.mesh_stack) == 0: + raise RuntimeError("No device mesh is currently active!") + return self.mesh_stack[-1] + + def create_sub_mesh( + self, + device_mesh: "DeviceMesh", + submesh_dim_names: tuple[str, ...], + submesh_dims: list[tuple[int, ...]], + ) -> "DeviceMesh": + # Get the submesh dim size from the submesh_dims. + # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want + # to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2]. 
+ # If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2]. + slice_dim_size = [ + reduce( + lambda x, y: x * device_mesh.mesh.size(y), + mesh_dim, + 1, + ) + for mesh_dim in submesh_dims + ] + + mesh_tensor = device_mesh.mesh + # slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims. + slice_dim_idx = [] + slice_dim_group_name = [] + # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the + # flattened mesh tensor. + num_dims_flatten = 0 + for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names): + # Currently, this only allows slicing out a contiguous flattened dim. + # TODO: we need to handle reconstructing a non-contiguous flattened dim. + if len(mesh_dim_indices) > 1: + # We need to move the start_dim and end_dim to the left if some dims are already flattened. + mesh_tensor = mesh_tensor.flatten( + start_dim=mesh_dim_indices[0] - num_dims_flatten, + end_dim=mesh_dim_indices[-1] - num_dims_flatten, + ) + # If some dims are already flattened, we need to adjust the slice_dim_idx accordingly. + # For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened, + # then the final slice_dim_idx should be [0, 1, 2]. + slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) + num_dims_flatten += len(mesh_dim_indices) - 1 + slice_dim_group_name.append( + self.root_to_flatten_mapping[device_mesh][ + mesh_dim_name + ]._dim_group_names[0] + ) + else: + slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) + slice_dim_group_name.append( + device_mesh._dim_group_names[mesh_dim_indices[0]] + ) + + # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. + mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) + for idx in slice_dim_idx: + mesh_dims_remained_idx.remove(idx) + + # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx] + # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with + # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. + pg_ranks_by_dim = mesh_tensor.permute( + *mesh_dims_remained_idx, *slice_dim_idx + ).reshape(-1, *slice_dim_size) + + cur_rank = device_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_nd, + mesh_dim_names=submesh_dim_names, + _init_backend=False, + ) + if cur_rank in mesh_nd: + res_submesh = submesh + + res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined] + self.child_to_root_mapping[res_submesh] = device_mesh + + return res_submesh + + def create_flatten_mesh( + self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None + ) -> "DeviceMesh": + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + + flatten_dims_in_root = [ + not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name) + for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names) + ] + + if not mesh_dim_name: + mesh_dim_name = "_".join( + [ + not_none(root_mesh.mesh_dim_names)[dim] + for dim in flatten_dims_in_root + ] + ) + + # Check whether the mesh_dim_name for flattened mesh is valid. 
+ self.flatten_name_to_root_dims.setdefault(root_mesh, {}) + invalid_dim_names = chain( + *list(not_none(root_mesh.mesh_dim_names)), + *self.flatten_name_to_root_dims[root_mesh].keys(), + ) + if mesh_dim_name in invalid_dim_names: + raise RuntimeError( + f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ", + f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. " + f"Please specify another valid mesh_dim_name.", + ) + + # Quick return if the flatten mesh has been created before. + # TODO: If we decide to restrict flatten initialization once, we should remove + # this check and throw an error if the flatten mesh is already created before. + if ( + root_mesh in self.root_to_flatten_mapping + and mesh_dim_name in self.root_to_flatten_mapping[root_mesh] + ): + return self.root_to_flatten_mapping[root_mesh][mesh_dim_name] + + flattened_mesh_dim_size = math.prod(device_mesh.mesh.size()) + + remained_dims_in_root = list(range(root_mesh.mesh.ndim)) + for flatten_dim_in_root in flatten_dims_in_root: + remained_dims_in_root.remove(flatten_dim_in_root) + + pg_ranks_by_dim = root_mesh.mesh.permute( + *remained_dims_in_root, *flatten_dims_in_root + ).reshape(-1, flattened_mesh_dim_size) + + cur_rank = root_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + # need to init backend here since the flattened pg doesn't exist in root mesh. + flattened_mesh = DeviceMesh( + root_mesh.device_type, + mesh_nd, + mesh_dim_names=(mesh_dim_name,), + ) + if cur_rank in mesh_nd: + res_flattened_mesh = flattened_mesh + self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined] + self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = ( + res_flattened_mesh # type: ignore[possibly-undefined] + ) + self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple( + flatten_dims_in_root + ) # type: ignore[possibly-undefined] + + return res_flattened_mesh + + def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": + # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself. + # A root mesh is not created through slicing. + # We considers the root mesh of a root mesh is itself. + root_mesh = self.child_to_root_mapping.get(device_mesh, None) + return device_mesh if not root_mesh else root_mesh + + def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: + """ + Returns the index of the mesh dim in the root mesh. + The device_mesh passed in needs to be sliced out from the root mesh + or submesh of the root mesh. + """ + root_mesh = self.get_root_mesh(device_mesh) + child_mesh_dim_names = device_mesh.mesh_dim_names + if root_mesh and child_mesh_dim_names: + assert len(child_mesh_dim_names) == 1, ( + "The submesh can only be a 1D mesh." 
+ ) + child_mesh_dim_name = child_mesh_dim_names[0] + return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) + return None + + @staticmethod + def num_devices_per_host(device_type: str) -> int: + return _get_device_handle(device_type).device_count() + + @staticmethod + def num_hosts(device_type: str) -> int: + # ProcessGroup can't tell us this info so we have to infer it, assume + # homogeneous hardware for now + return get_world_size() // _MeshEnv.num_devices_per_host(device_type) + + def get_mesh_dim_by_name( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> int: + if ( + device_mesh.mesh_dim_names is None + or len(device_mesh.mesh_dim_names) == 0 + ): + raise KeyError( + "No `mesh_dim_names` found.", + ) + if mesh_dim_name not in device_mesh.mesh_dim_names: + raise KeyError( + f"Mesh dimension '{mesh_dim_name}' does not exist.", + f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}", + ) + return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) + + def _set_mesh_dim_group_options( + self, + dim: int, + backend: str, + pg_options: Optional[C10dBackend.Options] = None, + ) -> None: + self.mesh_dim_group_options[dim] = (backend, pg_options) + + def _get_slice_mesh_dims( + self, device_mesh, mesh_dim_names + ) -> list[tuple[int, ...]]: + """ + Validate whether the mesh_dim_names is valid for slicing the given device_mesh. + If valid, return dim indexes of the slice mesh in the device mesh. + """ + if device_mesh != self.get_root_mesh(device_mesh): + raise RuntimeError("Cannot create a submesh from a submesh.") + + # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names + # or its flattened mesh's mesh_dim_names. + self.flatten_name_to_root_dims.setdefault(device_mesh, {}) + flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh] + valid_mesh_dim_names = [ + *device_mesh.mesh_dim_names, + *flatten_name_to_root_dims, + ] + + if not all( + mesh_dim_name in valid_mesh_dim_names + for mesh_dim_name in mesh_dim_names + ): + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. " + f"Valid mesh_dim_names are {valid_mesh_dim_names}." + ) + + # Validate the order of the slice mesh dim indices. + # This needs to be in ascending order. + curr_idx = -1 + slice_mesh_dims = [] + for mesh_dim_name in mesh_dim_names: + if mesh_dim_name in flatten_name_to_root_dims: + mesh_indices = flatten_name_to_root_dims[mesh_dim_name] + # TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx + # should be mesh_indices[0] once we support non-contiguous slicing with flatten dim. + next_idx = mesh_indices[-1] + slice_mesh_dims.append(mesh_indices) + else: + next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name) + slice_mesh_dims.append((next_idx,)) + if next_idx <= curr_idx: + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. ", + f"Found mesh dim indices to slice: {slice_mesh_dims}. ", + "Mesh dim indices should be in ascending order.", + ) + curr_idx = next_idx + + return slice_mesh_dims + + def _get_all_submeshes( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> list["DeviceMesh"]: + """ + Return all the submeshes of a given mesh dimension of the device mesh. 
+ """ + mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( + -1, device_mesh.mesh.size(mesh_dim) + ) + + cur_rank = device_mesh.get_rank() + res_submeshes = [] + for mesh_1d in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_backend=False, + ) + submesh._dim_group_names = ( + [device_mesh._dim_group_names[mesh_dim]] + if cur_rank in mesh_1d + else [] + ) + res_submeshes.append(submesh) + + return res_submeshes + + _mesh_resources: _MeshEnv = _MeshEnv() + + def _get_device_handle(device_type: str = "cuda"): + """ + Get the module corresponding to the device_type which is cuda or cuda-like device. + For example, when the device_type is cuda, the module `torch.cuda` is returned. + Return None when there is no corresponding module for device_type, otherwise + return the corresponding module. + """ + return getattr(torch, device_type, None) + + class DeviceMesh: + """ + DeviceMesh represents a mesh of devices, where layout of devices could be + represented as a n-d dimension array, and each value of the n-d dimensional + array is the global id of the default process group ranks. + + DeviceMesh could be used to setup the N dimensional device connections across the cluster, + and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on + each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects + already (i.e. if user call `torch.cuda.set_device` before the DeviceMesh initialization), + and will select/set the device for the current process if user does not set the device + beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization. + + DeviceMesh can also be used as a context manager when using together with DTensor APIs. + + .. note:: + DeviceMesh follows SPMD programming model, which means the same PyTorch Python program + is running on all processes/ranks in the cluster. Therefore, users need to make sure the + `mesh` array (which describes the layout of devices) should be identical across all ranks. + Inconsistent `mesh` will lead to silent hang. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout + of devices, where the IDs are global IDs of the default process group. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + A reduction over the first dimension of mesh will reduce across + columns (0, 4), .. and (3, 7), a reduction over the second dimension + of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). 
+ >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + + device_type: str + mesh: torch.Tensor + mesh_dim_names: Optional[tuple[str, ...]] + + def __init__( + self, + device_type: str, + mesh: Union[torch.Tensor, "ArrayLike"], + *, + mesh_dim_names: Optional[tuple[str, ...]] = None, + _init_backend: bool = True, + ) -> None: + self.device_type = device_type + if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": + raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") + self.mesh = ( + mesh.detach().to(dtype=torch.int) + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None + + # private field to pre-generate DeviceMesh's hash + self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) + self._thread_id = None + + # Skip process group initialization if xla device or init backend is False + # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. + if device_type != "xla": + # always try to create default (world) pg, even if it is not initialized + # already. The world pg is used for device mesh identity (rank) on each + # process (we need to know if the current global rank is in the mesh or not). + if _init_backend: + self._setup_world_group_and_device() + self._init_process_groups() + + if is_initialized() and get_backend() == "threaded": + self._thread_id = threading.get_ident() + + # calculate the coordinates of the current global rank on the mesh + rank_coords = (self.mesh == get_rank()).nonzero() + assert rank_coords.size(0) in (0, 1) + self._coordinate_on_dim: Optional[list[int]] = ( + rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + ) + + def _setup_world_group_and_device(self): + default_initialized = is_initialized() + # TODO: think about how to allow pg options to be passed to world group + # or mesh dimension groups + if not default_initialized: + init_process_group() + + world_size = get_world_size() + if self.mesh.numel() > world_size: + raise RuntimeError( + f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" + ) + + # ONLY set the device if the current device is not initialized, if user already + # set the device before DeviceMesh init, we respect the user's choice. + device_handle = _get_device_handle(self.device_type) + if device_handle and not device_handle.is_initialized(): + # auto set the cuda/cuda-like device only if user has not set it, if there's LOCAL_RANK + # env variable from launchers, we use it to set the device. + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + logger.info( + "Setting default device for the current process based on LOCAL_RANK=%s", + local_rank, + ) + device_handle.set_device(local_rank) + else: + warnings.warn( + "It seems like you did not set/select the default device for the current process before the DeviceMesh " + "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. " + "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that " + "the underlying communicator (i.e. NCCL) can be initialized properly. " + "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the " + "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. 
" + ) + # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host + # NOTE: This device selection would only work for homogeneous hardware. + num_devices_per_host = device_handle.device_count() + if ( + world_size > num_devices_per_host + and world_size % num_devices_per_host != 0 + ): + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" + ) + device_handle.set_device(get_rank() % num_devices_per_host) + + return _get_default_group() + + def _init_process_groups(self): + # group_name associated with each mesh dimension, each + # mesh dimension should have one sub-group per rank + # + dim_group_names: list[str] = [] + default_group = _get_default_group() + + if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): + # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. + # Otherwise, create new pg. + ranks = list(range(get_world_size())) + dim_group = ( + new_group( + backend="cpu:gloo,cuda:nccl", + ranks=ranks, + group_desc="mesh_default", + ) + if torch.cuda.is_available() + and get_backend(default_group) == "gloo" + else default_group + ) + dim_group_names.append(dim_group.group_name) + else: + # create sub pgs base on the mesh argument specified + for dim in range(self.mesh.ndim): + # swap the current dim to the last dim + # then reshape to flatten out other dims + pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( + -1, self.mesh.size(dim) + ) + + # Respect dim group options specified via _MeshEnv.set_dim_group_options(). + # Inherit from the parent group if no options are specified for the group. + if dim in _mesh_resources.mesh_dim_group_options: + ( + backend, + pg_options, + ) = _mesh_resources.mesh_dim_group_options[dim] + else: + backend, pg_options = None, None + + # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description + # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. + # If the mesh doesn't not have a mesh_dim_names, then the group description of the + # subgroup would be `mesh_dim_0` and `mesh_dim_1`. + group_desc = ( + f"mesh_{self.mesh_dim_names[dim]}" + if self.mesh_dim_names + else f"mesh_dim_{dim}" + ) + + # If bound_device_id exists, it means the nccl communicator has been eagerly initialized + # so that we can use `split_group` to create subgroups through `ncclCommSplit`. + # In this case, we only need to make one API call (`split_group``) for the subgroup creation + # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create + # all the subgroups. + # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The + # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4 + # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups. + dim_group = None + has_split_group = False + if ( + bound_device_id := getattr( + default_group, "bound_device_id", None + ) + ) is not None and torch.cuda.is_available(): + dim_group = split_group( + parent_pg=default_group, + pg_options=pg_options, + split_ranks=pg_ranks_by_dim.tolist(), + group_desc=group_desc, + ) + has_split_group = True + + # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim` + # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup. 
+ # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim` + # along with appending information to the `dim_group_names` list whenever necessary. + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + + # We temporarily revert the reuse subgroup, since it breaks two internal tests. + # Temporarily reverting to resolve test timeout while root-causing. + # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. + if bound_device_id is None or not has_split_group: + dim_group = new_group( + ranks=subgroup_ranks, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + + # only add to dim_groups if the current rank in the subgroup + if self.get_rank() in subgroup_ranks: + if len(dim_group_names) > dim: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " + f"in {subgroup_ranks}!" + ) + dim_group_names.append(dim_group.group_name) # type: ignore[union-attr] + self._dim_group_names = dim_group_names + + def __enter__(self) -> "DeviceMesh": + # set this mesh as the current mesh in mesh env + _mesh_resources.mesh_stack.append(self) + return self + + # pyre-fixme[2]: Parameter must be annotated. + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + # pop this mesh from mesh env + _mesh_resources.mesh_stack.pop() + + def __repr__(self) -> str: + device_mesh_repr = ( + f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" + if not self.mesh_dim_names + else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" + ) + return device_mesh_repr + + def __hash__(self): + # lazily compute hash + self._hash = getattr(self, "_hash", None) + if not self._hash: + self._hash = hash( + ( + self._flatten_mesh_list, + self.mesh.shape, + self.device_type, + self.mesh_dim_names, + self._thread_id, + ) + ) + return self._hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeviceMesh): + return False + if id(self) == id(other): + return True + else: + return ( + self._flatten_mesh_list == other._flatten_mesh_list + and self.mesh.shape == other.mesh.shape + and self.device_type == other.device_type + and self.mesh_dim_names == other.mesh_dim_names + and self._thread_id == other._thread_id + ) + + def __getitem__( + self, mesh_dim_names: Union[str, tuple[str, ...]] + ) -> "DeviceMesh": + """ + Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. + The submesh created consists of the dimensions and the communicators indicated by + ``mesh_dim_names`` + + Args: + mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the + mesh dimension of the DeviceMesh to create the submesh for. + Returns: + A :class:`DeviceMesh` object + + The following program runs on each process/rank in an SPMD manner in a world size of 8. + In the first example: + Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). + Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). + Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). + Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). + Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). + Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). 
+ + In the second example: + Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). + Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). + Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). + Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize a 2D device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp")) + >>> tp_mesh = mesh_2d["tp"] + >>> dp_mesh = mesh_2d["dp"] + >>> + >>> # Initialize a 3D mesh. + >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp")) + >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh. + >>> dp_cp_mesh = mesh_3d["dp", "cp"] + >>> cp_dp_mesh = mesh_3d["cp", "dp"] + """ + if not self.mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + if mesh_dim_names == self.mesh_dim_names: + return self + else: + slice_mesh_dims = _mesh_resources._get_slice_mesh_dims( + self, mesh_dim_names + ) + # When using FakeTensorMode to trace the model, `create_sub_mesh()` will + # fail as it will require a real tensor to manipulate. + # `unset_fake_temporarily()` will allow us to materialize the tensors + # within `_mesh_resources`, which should not affect modling. + # + # Note that this should be orthogonal to torch.compile(). But whether + # we can compile device_mesh `slicing` (no graph break) is not verified + # yet and need a follow-up, + # TODO: compiler + device_mesh slicing. + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + submesh = _mesh_resources.create_sub_mesh( + self, mesh_dim_names, slice_mesh_dims + ) + return submesh + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: + """ + Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the + DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + A :class:`ProcessGroup` object. + """ + if not hasattr(self, "_dim_group_names"): + raise RuntimeError("DeviceMesh process groups not initialized!") + + if self.mesh.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + "If you want to get the list of all the ProcessGroups in the DeviceMesh," + "please use `get_all_groups()` instead.", + ) + + # Quick return if the current device_mesh is a 1D mesh. 
+ if self.mesh.ndim == 1 and mesh_dim is None: + return not_none(_resolve_process_group(self._dim_group_names[0])) + + root_mesh = _mesh_resources.get_root_mesh(self) + root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get( + root_mesh, None + ) + if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): + dim_group_name = root_to_flatten_mapping[ + mesh_dim # type: ignore[index] + ]._dim_group_names[0] + return not_none(_resolve_process_group(dim_group_name)) + else: + mesh_dim = ( + _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) + if isinstance(mesh_dim, str) + else mesh_dim + ) + assert isinstance(mesh_dim, int) + return not_none(_resolve_process_group(self._dim_group_names[mesh_dim])) + + def get_all_groups(self) -> list[ProcessGroup]: + """ + Returns a list of ProcessGroups for all mesh dimensions. + + Returns: + A list of :class:`ProcessGroup` object. + """ + return [self.get_group(i) for i in range(self.mesh.ndim)] + + @staticmethod + def from_group( + group: Union[ProcessGroup, list[ProcessGroup]], + device_type: str, + mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + *, + mesh_dim_names: Optional[tuple[str, ...]] = None, + ) -> "DeviceMesh": + """ + Constructs a :class:`DeviceMesh` with ``device_type`` from an + existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`. + + The constructed device mesh has number of dimensions equal to the + number of groups passed. For example, if a single process group is passed in, + the resulted DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, + the resulted DeviceMesh is a 2D mesh. + + If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments + are required. The order of the process groups passed in determines the topology of + the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. + The `mesh` tensor passed in must have the same number of dimensions as the number of process + groups passed in, and the order of the dimensions in the `mesh` tensor must match the order + in the process groups passed in. + + Args: + group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup + or a list of existing ProcessGroups. + device_type (str): The device type of the mesh. Currently supports: "cpu", + "cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0", + is not allowed. + mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an + integer tensor describing the layout of devices, where the IDs are global IDs + of the default process group. Default is None. + mesh_dim_names (tuple[str], optional): A tuple of mesh dimension names to assign + to each dimension of the multi-dimensional array describing the layout of devices. + Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names` + must be unique. Default is None. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. 
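        A minimal usage sketch (assuming the default process group is already initialized;
        ``dp_pg`` and ``tp_pg`` here stand in for subgroups created elsewhere whose layout
        matches the ``mesh`` argument)::

            >>> # xdoctest: +SKIP("no rank")
            >>> # 1D: wrap an existing ProcessGroup into a one-dimensional DeviceMesh.
            >>> world_mesh = DeviceMesh.from_group(
            ...     torch.distributed.group.WORLD, device_type="cuda"
            ... )
            >>> # nD: pass one ProcessGroup per mesh dimension, plus mesh and names.
            >>> mesh_2d = DeviceMesh.from_group(
            ...     [dp_pg, tp_pg],
            ...     device_type="cuda",
            ...     mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
            ...     mesh_dim_names=("dp", "tp"),
            ... )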
+ """ + + # 1D scenario + if isinstance(group, ProcessGroup): + group_ranks = get_process_group_ranks(group) + if ( + isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks + ) or ( + mesh is not None + and not isinstance(mesh, torch.Tensor) + and mesh != group_ranks + ): + raise ValueError( + f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" + ) + mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) + device_mesh = DeviceMesh( + device_type, + mesh, + mesh_dim_names=mesh_dim_names, + _init_backend=False, + ) + device_mesh._dim_group_names = [group.group_name] + return device_mesh + + # nD scenario + groups = list(group) + if len(groups) == 0: + raise ValueError("Expects at least one ProcessGroup to be passed") + if mesh is None: + raise ValueError("Must pass mesh if passing multiple ProcessGroups") + if mesh_dim_names is None: + raise ValueError( + "Must pass mesh_dim_names if passing multiple ProcessGroups" + ) + mesh = ( + mesh.detach().to(dtype=torch.int, device="cpu") + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + if mesh.ndim != len(groups): + raise ValueError( + "Expects mesh with ndim equal to number of ProcessGroups but got " + f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" + ) + device_mesh = DeviceMesh( + device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False + ) + device_mesh._dim_group_names = [group.group_name for group in groups] + return device_mesh + + def size(self, mesh_dim: Optional[int] = None) -> int: + return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) + + @property + def ndim(self) -> int: + return self.mesh.ndim + + @property + def shape(self) -> tuple[int, ...]: + return tuple(self.mesh.shape) + + def get_rank(self) -> int: + """ + Returns the current global rank. + """ + return get_rank() + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + """ + Returns the local rank of the given mesh_dim of the DeviceMesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + An integer denotes the local rank. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + mesh_dim_group = not_none(self.get_group(mesh_dim)) + assert isinstance(mesh_dim_group, ProcessGroup), ( + "We expect ProcessGroup before calling `get_rank`!" 
+ ) + return not_none(get_rank(mesh_dim_group)) + + def get_coordinate(self) -> Optional[list[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + return self._coordinate_on_dim if self._coordinate_on_dim else None + + def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": + """ + Returns a 1D DeviceMesh by flattening the current DeviceMesh. + + If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the + given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh + DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling + mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) + on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. + + After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the + existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. + """ + if not self.mesh_dim_names: + raise RuntimeError( + "Cannot flatten a DeviceMesh without mesh_dim_names!" + ) + + return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) + + def init_device_mesh( + device_type: str, + mesh_shape: tuple[int, ...], + *, + mesh_dim_names: Optional[tuple[str, ...]] = None, + ) -> DeviceMesh: + """ + Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. + + This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. + If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. + + .. note:: + `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program + runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array + describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging. + + .. note:: + If no process group is found, init_device_mesh will initialize distributed process group/groups + required for distributed communications behind the scene. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + Passing in a device type with a GPU index, such as "cuda:0", is not allowed. + mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array + describing the layout of devices. + mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension + of the multi-dimensional array describing the layout of devices. Its length must match the length + of `mesh_shape`. Each string in `mesh_dim_names` must be unique. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. 
+ + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import init_device_mesh + >>> + >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) + >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) + + """ + if mesh_dim_names is not None: + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError( + "Each mesh_dim_name must be unique.", + f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", + ) + + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names should have same length!", + f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", + ) + + # assume valid device types are all letters + if device_type and not device_type.isalpha(): + raise RuntimeError( + f"Device type with index is not supported but got {device_type}. ", + "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.", + ) + + # Always initialize the mesh's tensor on CPU, regardless of what the + # external device type has been set to be (e.g. meta) + with torch.device("cpu"): + mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape) + device_mesh = DeviceMesh( + device_type=device_type, + mesh=mesh, + mesh_dim_names=mesh_dim_names, + ) + + return device_mesh diff --git a/phivenv/Lib/site-packages/torch/distributed/distributed_c10d.py b/phivenv/Lib/site-packages/torch/distributed/distributed_c10d.py new file mode 100644 index 0000000000000000000000000000000000000000..a87bae5bd23ac1ce9f31429e49ae20cb85dbbc46 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/distributed_c10d.py @@ -0,0 +1,5644 @@ +# mypy: allow-untyped-defs +"""Distributed Collective Communication (c10d).""" + +import collections.abc +import contextlib +import ctypes +import hashlib +import io +import itertools +import logging +import os +import pickle +import sys +import time +import warnings +from collections import namedtuple +from datetime import timedelta +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +from torch._C import _DistStoreError as DistStoreError +from torch._C._distributed_c10d import ( + _DistributedBackendOptions, + _register_process_group, + _resolve_process_group, + _unregister_all_process_groups, + _unregister_process_group, + AllgatherOptions, + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + DebugLevel, + GatherOptions, + get_debug_level, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + Work, +) +from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs +from torch.monitor import _WaitCounter +from torch.overrides import handle_torch_function, has_torch_function +from torch.utils._typing_utils import not_none + +from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout +from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 + + +__all__ = [ + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + 
"gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_default_backend_for_device", + "get_rank", + "get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "is_xccl_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + "DebugLevel", + "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "reduce_op", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", + "split_group", +] + +_MPI_AVAILABLE = True +_NCCL_AVAILABLE = True +_GLOO_AVAILABLE = True +_UCC_AVAILABLE = True +_XCCL_AVAILABLE = True + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +# Change __module__ of all imported types from torch._C._distributed_c10d that are public +def _export_c_types() -> None: + _public_types_to_change_module = [ + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + GatherOptions, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + DebugLevel, + get_debug_level, + Work, + ] + for type in _public_types_to_change_module: + type.__module__ = "torch.distributed.distributed_c10d" + + +_export_c_types() + +try: + from torch._C._distributed_c10d import ProcessGroupMPI + + ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupMPI"] +except ImportError: + _MPI_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupNCCL + + ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupNCCL"] +except ImportError: + _NCCL_AVAILABLE = False + +try: + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + + ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupGloo"] +except ImportError: + _GLOO_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupUCC + + ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupUCC"] +except ImportError: + _UCC_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupXCCL"] +except ImportError: + _XCCL_AVAILABLE = False + +logger = logging.getLogger(__name__) + +PG_WRAPPER_STORE_PREFIX = "pg_wrapper" + + +# Some reduce ops are not supported by complex numbers and will result in an error. +# We currently provide complex support to the distributed API by viewing +# complex tensors as real (torch.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. 
+def supports_complex(reduceOp: ReduceOp) -> bool: + """Return true if reduce ops is supported. False otherwise.""" + denyList = [ + ReduceOp.MAX, + ReduceOp.MIN, + ReduceOp.PRODUCT, + ReduceOp.BAND, + ReduceOp.BOR, + ReduceOp.BXOR, + ] + return reduceOp not in denyList + + +# TODO refactor into enum/strenum +class Backend(str): # noqa: SLOT000 + """ + An enum-like class for backends. + + Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. + + The values of this class are lowercase strings, e.g., ``"gloo"``. They can + be accessed as attributes, e.g., ``Backend.NCCL``. + + This class can be directly called to parse the string, e.g., + ``Backend(backend_str)`` will check if ``backend_str`` is valid, and + return the parsed lowercase string if so. It also accepts uppercase strings, + e.g., ``Backend("GLOO")`` returns ``"gloo"``. + + .. note:: The entry ``Backend.UNDEFINED`` is present but only used as + initial value of some fields. Users should neither use it directly + nor assume its existence. + """ + + UNDEFINED = "undefined" + GLOO = "gloo" + NCCL = "nccl" + UCC = "ucc" + MPI = "mpi" + XCCL = "xccl" + + _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) + + _plugins: dict[str, _BackendPlugin] = {} + + backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] + + # 3rd-party devices can register the default backend support here + default_device_backend_map: dict[str, str] = { + "cpu": GLOO, + "cuda": NCCL, + "xpu": XCCL, + "mps": GLOO, + } + + backend_capability: dict[str, list[str]] = { + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + XCCL: ["xpu"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], + } + + backend_type_map: dict[str, ProcessGroup.BackendType] = { + UNDEFINED: ProcessGroup.BackendType.UNDEFINED, + GLOO: ProcessGroup.BackendType.GLOO, + NCCL: ProcessGroup.BackendType.NCCL, + XCCL: ProcessGroup.BackendType.XCCL, + UCC: ProcessGroup.BackendType.UCC, + MPI: ProcessGroup.BackendType.MPI, + } + + def __new__(cls, name: str): + """Create and return a new instance of the class.""" + if not isinstance(name, str): + raise ValueError("Backend constructor parameter must be string-ish") + value = getattr(Backend, name.upper(), Backend.UNDEFINED) + + if value == Backend.UNDEFINED: + value = name.lower() + return value + + @classmethod + def register_backend( + cls, + name, + func, + extended_api=False, + devices: Optional[Union[str, list[str]]] = None, + ) -> None: + """ + Register a new backend with the given name and instantiating function. + + This class method is used by 3rd party ``ProcessGroup`` extension to + register new backends. + + Args: + name (str): Backend name of the ``ProcessGroup`` extension. It + should match the one in ``init_process_group()``. + func (function): Function handler that instantiates the backend. + The function should be implemented in the backend + extension and takes four arguments, including + ``store``, ``rank``, ``world_size``, and ``timeout``. + extended_api (bool, optional): Whether the backend supports extended argument structure. + Default: ``False``. If set to ``True``, the backend + will get an instance of ``c10d::DistributedBackendOptions``, and + a process group options object as defined by the backend implementation. + device (str or list of str, optional): device type this backend + supports, e.g. "cpu", "cuda", etc. If `None`, + assuming both "cpu" and "cuda" + + .. note:: This support of 3rd party backend is experimental and subject to change. 
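        A minimal registration sketch (the backend name ``"dummy"`` and the creator
        function below are hypothetical; a real creator must build and return the
        backend from ``store``, ``rank``, ``world_size`` and ``timeout``)::

            >>> # xdoctest: +SKIP("illustrative only")
            >>> def _dummy_creator(store, rank, world_size, timeout):
            ...     raise NotImplementedError("construct and return the backend here")
            >>> Backend.register_backend("dummy", _dummy_creator, devices=["cpu"])
            >>> # The new name can then be passed as init_process_group(backend="dummy", ...).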
+ + """ + # This takes care of CUSTOM Out-of-tree backend types, update in backend_list indicates availability + if not hasattr(Backend, name.upper()): + setattr(Backend, name.upper(), name.lower()) + if name.lower() not in Backend.backend_list: + Backend.backend_list.append(name.lower()) + + if devices is not None: + for device in devices: + if device not in Backend.default_device_backend_map: + Backend.default_device_backend_map[device] = name.lower() + Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM + + # Update device capability matrix in Backend class + if devices is None: + # This is more of a backward support for groups like `threaded`: + # assume default devices "cpu" and "cuda", but warn + warnings.warn( + f"Device capability of {name} unspecified, assuming `cpu` and " + "`cuda`. Please specify it via the `devices` argument of " + "`register_backend`." + ) + Backend.backend_capability[name.lower()] = ["cpu", "cuda"] + elif isinstance(devices, str): + # Single device string specified. Simply convert to list. + Backend.backend_capability[name.lower()] = [devices] + else: + Backend.backend_capability[name.lower()] = devices + + Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + + +class BackendConfig: + """Backend configuration class.""" + + def __init__(self, backend: Backend): + """Init.""" + self.device_backend_map: dict[str, Backend] = {} + backend = str(backend) + + if backend == Backend.UNDEFINED: + # Detect the accelerator on the machine. If no accelerator is + # available, it returns CPU. + device_type = torch._C._get_accelerator().type + try: + backend_str = Backend.default_device_backend_map[device_type] + self.device_backend_map[device_type] = Backend(backend_str) + except KeyError: + raise ValueError( + f"We detected accelerator {device_type} on your machine. " + f"But we don't know which communication backend to use for this accelerator. " + f"Please specify the `backend` argument in the `init_process_group` call." + ) from None + elif backend.lower() in Backend.backend_list: + # Cases for when backend is a single string (without device types) + # e.g. "nccl", "gloo", "ucc", "mpi" + supported_devices = Backend.backend_capability[backend.lower()] + backend_val = Backend(backend) + self.device_backend_map = dict.fromkeys(supported_devices, backend_val) + elif ":" in backend.lower(): + # Backend specified in "device:backend" format + # make sure the backend string is in the correct format + # "{device_type1}:{backend1},{device_type2}:{backend2}" + # e.g. "cpu:gloo,cuda:nccl" + backend_str_error_message = f"""The custom backend string argument is invalid: {backend}. + Custom backend string is an experimental feature where the backend string must be in the format: + ":,:...". e.g. 'cpu:gloo,cuda:nccl'""" + + # parse the backend string and populate the device_backend_map + for device_backend_pair_str in backend.lower().split(","): + device_backend_pair = device_backend_pair_str.split(":") + if len(device_backend_pair) != 2: + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) + device, backend = device_backend_pair + if device in self.device_backend_map: + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. 
{backend_str_error_message}" + ) + self.device_backend_map[device] = Backend(backend) + else: + # User specified a single backend name whose device capability is + # unknown, assuming it can support the default devices of PyTorch + # (cpu and cuda) + warnings.warn( + f"Device capability of {backend} unknown, assuming `cpu` and " + "`cuda`. You can specify it in `device:backend` format in " + "`init_process_group` call." + ) + backend_val = Backend(backend) + self.device_backend_map = { + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, + } + + logger.info("Using backend config: %s", self.device_backend_map) + + def __repr__(self): + """Return all the device:backend pairs separated by commas.""" + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) + + def get_device_backend_map(self) -> dict[str, Backend]: + """Return backend map of the device.""" + return self.device_backend_map + + +class _reduce_op: + r""" + Deprecated enum-like class. + + For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``. + + :class:`~torch.distributed.ReduceOp` is recommended to use instead. + """ + + def __init__(self) -> None: + # __members__ is a dict storing key-value pairs for enum classes + for k, v in ReduceOp.RedOpType.__members__.items(): + setattr(self, k, v) + self.__members__ = ReduceOp.RedOpType.__members__ + + @deprecated( + "`torch.distributed.reduce_op` is deprecated, " + "please use `torch.distributed.ReduceOp` instead", + category=FutureWarning, + ) + def __getattribute__(self, key): + return object.__getattribute__(self, key) + + +reduce_op = _reduce_op() + + +class P2POp: + """ + A class to build point-to-point operations for ``batch_isend_irecv``. + + This class builds the type of P2P operation, communication buffer, peer rank, + Process Group, and tag. Instances of this class will be passed to + ``batch_isend_irecv`` for point-to-point communications. + + Args: + op (Callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``torch.distributed.isend`` or + ``torch.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int, optional): Destination or source rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with recv. + group_peer (int, optional): Destination or source rank. 
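    A typical usage sketch with ``batch_isend_irecv`` (assumes a two-rank job; the
    tensors are illustrative)::

        >>> # xdoctest: +SKIP("no rank")
        >>> import torch.distributed as dist
        >>> send_tensor = torch.ones(2)
        >>> recv_tensor = torch.zeros(2)
        >>> peer = 1 - dist.get_rank()
        >>> ops = [
        ...     dist.P2POp(dist.isend, send_tensor, peer),
        ...     dist.P2POp(dist.irecv, recv_tensor, peer),
        ... ]
        >>> reqs = dist.batch_isend_irecv(ops)
        >>> for req in reqs:
        ...     req.wait()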
+ """ + + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + peer: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_peer: Optional[int] = None, + ): + """Init.""" + self.op = op + self.tensor = tensor + self.group = _group_or_default_group(group) + self.peer = _canonicalize_group_rank( + self.group, peer, group_peer, return_global=True + ) + self.tag = tag + self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer) + + def __new__( + cls, + op: Callable, + tensor: torch.Tensor, + peer: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_peer: Optional[int] = None, + ): + """Create and return a new instance of the class.""" + _check_op(op) + _check_single_tensor(tensor, "tensor") + + return object.__new__(cls) + + def __repr__(self): + my_group_rank = get_rank(self.group) + op_name = self.op.__name__ + group_name = self.group.group_name if self.group else "default_pg" + if "send" in op_name: + s = my_group_rank + d = self.group_peer + elif "recv" in op_name: + s = self.group_peer + d = my_group_rank + else: + return super().__repr__() + + return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})" + + +class _CollOp: + """ + A class to capture collective operations. + + Args: + op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``. + tensor (Tensor): Tensor to operate on. + dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same. + redop (ReduceOp, optional): reduce operation. + root (int, optional): root of broadcast or reduce. + """ + + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + dst_tensor: Optional[torch.Tensor] = None, + redop: Optional[ReduceOp] = None, + root: Optional[int] = None, + ): + self.op = op + self.tensor = tensor + self.dst_tensor = dst_tensor + self.redop = redop + self.root = root + + +# DO NOT USE THESE FIELDS DIRECTLY. +# Use them through the _world object to make sure the _world override mechanism +_pg_map: dict[ProcessGroup, tuple[str, Store]] = {} +_pg_names: dict[ProcessGroup, str] = {} +_pg_group_ranks: dict[ProcessGroup, dict[int, int]] = {} +# For a pg, it is a map from ProcessGroup to BackendConfig +_pg_backend_config: dict[ProcessGroup, str] = {} +_group_count = 0 +_tags_to_pg: dict[str, list[ProcessGroup]] = {} +_pg_to_tag: dict[ProcessGroup, str] = {} +_backend: Optional[str] = None + + +class _World: + """ + Container class for c10d process group state. + + This is used during registration and lookup of PG state. + + .. warning:: This is an experimental API intended to expose the inner workings + of c10d and is subject to change.. + """ + + def __init__(self) -> None: + self._default_pg = None + self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {} + + @property + def default_pg(self) -> Optional[ProcessGroup]: + """ + Process group that includes all ranks of the cluster. + + This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed + but None is provided. + """ + return self._default_pg + + @default_pg.setter + def default_pg(self, value) -> None: + self._default_pg = value + + @property + def pg_map(self) -> dict[ProcessGroup, tuple[str, Store]]: + """ + Provide Mapping from ProcessGroup to backend name and store. 
+ + For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store) + For MPI pg, it is a map from ProcessGroup to (Backend, None) + + TODO don't expose the map, expose fine grained ops + """ + global _pg_map + return _pg_map + + @property + def pg_names(self) -> dict[ProcessGroup, str]: + """ + Process group's names, map from ProcessGroup to str. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_names + return _pg_names + + @property + def pg_group_ranks(self) -> dict[ProcessGroup, dict[int, int]]: + """ + Process group's global rank to local rank mapping. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_group_ranks + return _pg_group_ranks + + @property + def pg_backend_config(self) -> dict[ProcessGroup, str]: + """ + Process group's backend config. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_backend_config + return _pg_backend_config + + @property + def group_count(self) -> int: + """ + Process group count for default naming. + + TODO don't expose group_count, use something else instead + """ + global _group_count + return _group_count + + @group_count.setter + def group_count(self, value: int) -> None: + """Use to compute the name of ProcessGroups when using global synchronization.""" + global _group_count + _group_count = value + + @property + def tags_to_pg(self) -> dict[str, list[ProcessGroup]]: + global _tags_to_pg + return _tags_to_pg + + @property + def pg_to_tag(self) -> dict[ProcessGroup, str]: + global _pg_to_tag + return _pg_to_tag + + @property + def pg_coalesce_state(self) -> dict[ProcessGroup, list[_CollOp]]: + return self._pg_coalesce_state + + @property + def pg_config_info(self) -> list[dict[str, Any]]: + """ + Return a list of dict with process groups and backends. + + Along with their unique IDs and configurations (types and ranks). + """ + config_info: list[dict[str, Any]] = [] + default_pg_size = _get_group_size(None) + for pg in self.pg_map.keys(): + ranks = self.pg_group_ranks[pg] + config_info.append( + { + "pg_name": self.pg_names[pg], + "pg_desc": pg.group_desc, + "backend_config": self.pg_backend_config[pg], + "ranks": ( + list(ranks.keys()) if len(ranks) != default_pg_size else [] + ), # 'ranks' is an empty list when all ranks are involved in a pg + "group_size": len(ranks), + "group_count": self.group_count, + } + ) + return config_info + + +_world = _World() +"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" + + +class _WorldMeta(type): + """ + Meta class of ``group`` and ``GroupMember``. + + Allows them to have the class property ``WORLD``. + """ + + # Points to the default PG once initialized. + @property + def WORLD(cls) -> Optional[ProcessGroup]: + return _world.default_pg + + @WORLD.setter + def WORLD(cls, pg: Optional[ProcessGroup]): + _world.default_pg = pg + + +class group(metaclass=_WorldMeta): + """Group class. Placeholder.""" + + +class GroupMember(metaclass=_WorldMeta): + """Group member class.""" + + NON_GROUP_MEMBER = -100 + + +def _get_default_timeout(backend: Backend) -> timedelta: + # see note on nccl vs other backend timeout (constants.py) + if backend == Backend.NCCL: + if not isinstance(default_pg_nccl_timeout, timedelta): + # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was + # changed to be a warning. We should fix the moco model. 
+ warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" + ) + return default_pg_timeout + return default_pg_nccl_timeout + else: + return default_pg_timeout + + +def _check_valid_timeout(timeout: Any) -> None: + if not isinstance(timeout, timedelta): + raise TypeError( + f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" + ) + + +# Default process group state +_default_pg_init_method: Optional[str] = None + +STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + + +def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str: + """ + .. note:: This is an internal helper and does not have backward + compatibility, please use with caution. + + Return the device type to use with ``group`` for object collectives or + barrier. + + There are selection rules: + 1. If user specifies exactly one backend in ``init_process_group`` call: + use that backend + 2. Else if user specifies multiple "device:backend" pairs in init_process_group: + If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory); + Otherwise, use the first backend (sort of a random pick). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + str: The device type to use for object collective with ``group``. + + """ + group = group or _get_default_group() + + if not isinstance(group, ProcessGroup): + warnings.warn( + f"You are using a Backend {type(group)} as a ProcessGroup. " + "This usage is deprecated since PyTorch 2.0. Please use a public API " + "of PyTorch Distributed instead.", + ) + # Provide backward compatibility to cases where `group` passed in is + # actually a Backend (like `ProcessGroupGloo`) rather than a + # `ProcessGroup` in PT 2.0 sense + if isinstance(group, ProcessGroupGloo): + # RPC uses Gloo for object collectives + return "cpu" + else: + raise ValueError(f"Expecting a ProcessGroup, but got a {type(group)}.") + + """ + ``group._device_types`` is a property pybind that returns the devices + ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the + ``group`` supports multiple devices. + """ + devices = group._device_types + + if len(devices) == 1: + # User fixed exactly one backend in `init_process_group` + return devices[0].type + elif len(devices) == 0: + # No backend has been registered with this PG (maybe because no + # collective has been run?) We pick cpu as the default and hopefully + # this would lazily init Gloo or other available cpu backend. + return "cpu" + elif torch.device("cpu") in devices: + # There are multiple backends in this PG and cpu is among them. + # cpu is preferred as the object is in cpu memory. No need for device + # copy. + return "cpu" + else: + # No cpu in the backend list. Randomly pick the first backend + return devices[0].type + + +def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: + """ + .. note:: This method will be deprecated, it only stays for + backward-compatiblity reason. Alternatives: + + - If you need to find a device for object collectives, please use + `_get_object_coll_device(group)`. + + - If you need to query the device types supported by group, please use + `_device_capability(group)`. + + Return the device type registered with ``group``. + + For example, if `init_process_group("nccl", ...)` was called, the returned + value would be `torch.device("cuda")`. + + Errors out if no device has been registered. 
+ + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + torch.device: The device type registered with ``group``. + """ + + warnings.warn( + "`_get_pg_default_device` will be deprecated, it only stays for " + "backward-compatiblity reason. If you need to find a device for object " + "collectives, please use `_get_object_coll_device`. If you need to query " + "the device types supported by group, please use " + "`_device_capability(group)`. " + ) + group = group or _get_default_group() + + if not isinstance(group, ProcessGroup): + # Provide backward compatibility to cases where `group` passed in is + # actually a Backend (like `ProcessGroupGloo`) rather than a + # `ProcessGroup` in PT 2.0 sense + warnings.warn( + f"You are using a Backend {type(group)} as a ProcessGroup. " + "This usage is deprecated since PyTorch 2.0. Please use a public API " + "of PyTorch Distributed instead.", + FutureWarning, + stacklevel=3, + ) + # Most users create Gloo with private API for object collectives + return torch.device("cpu") + + """ + ``group._device_types`` is a property pybind that returns the devices + ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the + ``group`` supports multiple devices. + """ + devices = group._device_types + + if len(devices) == 1: + # User fixed exactly one backend in `init_process_group` + return devices[0] + elif len(devices) == 0: + raise RuntimeError( + "Default device not found, because no backend has been registered " + "with this ProcessGroup." + ) + else: + # There are multiple backends in this PG. + if torch.device("cpu") in devices: + rv = torch.device("cpu") + else: + rv = devices[0] + warnings.warn( + "Multiple backends are registered with this ProcessGroup. We cannot " + f"determine which one is the default. Returning {rv}. " + "Please consider using other APIs." + ) + return rv + + +def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]: + """ + Return the device type(s) supported by ``group``. + + Args: + group (ProcessGroup, optional): The process group to query. If None, + the default process group will be used. + + Returns: + List[str]: A list of device types supported by ``group``. + """ + group = group or _get_default_group() + return [device.type for device in group._device_types] + + +@_time_logger +def _store_based_barrier( + rank, + store, + group_name, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: + """ + Store based barrier for synchronizing processes. + + Barrier based on store which is used for synchronizing processes after + ``init_process_group`` or ``new_group``. Intended to be used only with + those two methods and is not a generic alternative to ``barrier()``. + """ + store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}" + store.add(store_key, 1) + logger.debug("Added key: %s to store for rank: %s", store_key, rank) + + # Now wait for all workers to check in with the store. + world_size = rendezvous_count + worker_count = store.add(store_key, 0) + + last_worker_key = f"{store_key}:last_worker" + if worker_count == world_size: + store.set(last_worker_key, "1") + + # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout + # this value was empirically found while scale testing. 
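    # For example, with world_size=8000 the interval below becomes at least 18 seconds
    # (10 + 8000/1000), so very large jobs print the waiting status less often.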
+ logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000)) + + start = time.time() + while True: + try: + # This will throw an exception after the logging_interval in which we print out + # the status of the group or time out officially, throwing runtime error + store.wait([last_worker_key], logging_interval) + break + except RuntimeError as e: + worker_count = store.add(store_key, 0) + # Print status periodically to keep track. + logger.debug( + "Waiting in store based barrier to initialize process group for %s seconds" + "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", + time.time() - start, + rank, + store_key, + world_size, + worker_count, + timeout, + e, + ) + + if timedelta(seconds=(time.time() - start)) > timeout: + raise DistStoreError( # noqa: B904 + "Timed out initializing process group in store based barrier on " + f"rank {rank}, for key: {store_key} (world_size={world_size}, " + f"num_workers_joined={worker_count}, timeout={timeout} error={e})" + ) + + logger.info( + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, + ) + + +def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank() + warnings.warn( + f"Running {op_name} on global rank {global_rank} which does not " + "belong to the given group." + ) + + +def get_group_rank(group: ProcessGroup, global_rank: int) -> int: + """ + Translate a global rank into a group rank. + + ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the relative rank. + global_rank (int): Global rank to query. + + Returns: + Group rank of ``global_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return global_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) + group_ranks = _world.pg_group_ranks[group] + if global_rank not in group_ranks: + raise ValueError(f"Global rank {global_rank} is not part of group {group}") + + return group_ranks[global_rank] + + +def get_global_rank(group: ProcessGroup, group_rank: int) -> int: + """ + Translate a group rank into a global rank. + + ``group_rank`` must be part of `group` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the global rank from. + group_rank (int): Group rank to query. + + Returns: + Global rank of ``group_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return group_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) + for rank, grp_rank in _world.pg_group_ranks[group].items(): + if grp_rank == group_rank: + return rank + raise ValueError(f"Group rank {group_rank} is not part of group {group}") + + +# TODO: remove this once the ecosystem moves away from it. 
+@deprecated( + "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, " + "please use `torch.distributed.distributed_c10d.get_global_rank` instead", + category=FutureWarning, +) +def _get_global_rank(group, rank) -> int: + """Use get_global_rank as this method is deprecated.""" + return get_global_rank(group, rank) + + +def get_process_group_ranks(group: Optional[ProcessGroup]) -> list[int]: + """ + Get all ranks associated with ``group``. + + Args: + group (Optional[ProcessGroup]): ProcessGroup to get all ranks from. + If None, the default process group will be used. + + Returns: + List of global ranks ordered by group rank. + """ + return list(_world.pg_group_ranks[group or _get_default_group()].keys()) + + +def _get_group_size(group) -> int: + """Get a given group's world size.""" + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() + return default_pg.size() + return group.size() + + +def _get_group_size_by_name(group_name: str) -> int: + group = _resolve_process_group(group_name) + return group.size() + + +def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> str: + # TODO(yifu): remove this function once ranks + tag is not a supported + # identifier for process group for functional collectives. + group = _find_pg_by_ranks_and_tag(tag, ranks) + if group is None: + raise ValueError("") + return group.group_name + + +def _check_single_tensor(param, param_name) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, torch.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, torch.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup: + if group is None or group is GroupMember.WORLD: + group = _get_default_group() + return group + + +def _canonicalize_group_rank( + group: ProcessGroup, + global_rank: Optional[int] = None, + group_rank: Optional[int] = None, + return_global: bool = False, +) -> int: + """ + Helper method to take _either_ a global rank or a group rank and produce a group rank. + + If 'return_global' is true, produce a global rank instead of a group rank. + """ + + if group_rank is not None: + if global_rank is not None: + raise ValueError("Can't specify both group_rank and global_rank") + if return_global: + return get_global_rank(group, group_rank) + else: + if global_rank is None: + raise ValueError("Must specify global_rank or group_rank") + if return_global: + return global_rank + group_rank = get_group_rank(group, global_rank) + return group_rank + + +def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str): + if group.rank() == rank: + raise ValueError( + f"Invalid {rank_type} rank: {rank_type} rank should not be the same as " + "the rank of the current process." 
+ ) + + +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + if tensor_dtype.is_complex: + tensor_dtype = ( + torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + ) + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" + f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _check_op(op) -> None: + """Check that the ``op`` is either isend or irecv.""" + if op not in [isend, irecv]: + raise ValueError( + "Invalid ``op``. Expected ``op`` " + "to be of type ``torch.distributed.isend`` or " + "``torch.distributed.irecv``." + ) + + +def _check_p2p_op_list(p2p_op_list) -> None: + """ + Check that the ``p2p_op_list`` is a list of P2POp instances. + + Also, check that all ops use the same group. + """ + if not isinstance(p2p_op_list, list) or not all( + isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list + ): + raise ValueError( + "Invalid ``p2p_op_list``. Each op is expected to " + "to be of type ``torch.distributed.P2POp``." + ) + + group = p2p_op_list[0].group + if not all(group == p2p_op.group for p2p_op in p2p_op_list): + raise ValueError("All ops need to use the same group.") + + +def is_mpi_available() -> bool: + """Check if the MPI backend is available.""" + return _MPI_AVAILABLE + + +def is_nccl_available() -> bool: + """Check if the NCCL backend is available.""" + return _NCCL_AVAILABLE + + +def is_gloo_available() -> bool: + """Check if the Gloo backend is available.""" + return _GLOO_AVAILABLE + + +def is_ucc_available() -> bool: + """Check if the UCC backend is available.""" + return _UCC_AVAILABLE + + +def is_xccl_available() -> bool: + """Check if the XCCL backend is available.""" + return _XCCL_AVAILABLE + + +def is_backend_available(backend: str) -> bool: + """ + Check backend availability. + + Checks if the given backend is available and supports the built-in backends or + third-party backends through function ``Backend.register_backend``. + + Args: + backend (str): Backend name. + Returns: + bool: Returns true if the backend is available otherwise false. + """ + # If the backend has an ``is_backend_available`` function, return the result of that function directly + available_func = getattr(torch.distributed, f"is_{backend.lower()}_available", None) + if available_func: + return available_func() + + return backend.lower() in Backend.backend_list + + +def is_initialized() -> bool: + """Check if the default process group has been initialized.""" + return GroupMember.WORLD is not None + + +def is_torchelastic_launched() -> bool: + """ + Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic). + + The existence of ``TORCHELASTIC_RUN_ID`` environment + variable is used as a proxy to determine whether the current process + was launched with torchelastic. This is a reasonable proxy since + ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a + non-null value indicating the job id for peer discovery purposes.. + """ + return os.getenv("TORCHELASTIC_RUN_ID") is not None + + +def _is_barrier_after_init() -> int: + # Environment variable to control whether process group should perform a + # barrier after its init. 
Default value is 0, i.e. no barrier. If you + # experience issue with this setting, you may set + # `TORCH_DIST_INIT_BARRIER=1` to add the barrier. + return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0")) + + +def _get_default_group() -> ProcessGroup: + """Get the default process group created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + if TYPE_CHECKING: + return not_none(GroupMember.WORLD) + else: + return GroupMember.WORLD + + +def _get_default_store() -> Store: + """Get the default store created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + default_pg = _get_default_group() + _, default_store = _world.pg_map[default_pg] + return default_store + + +def _update_default_pg(pg) -> None: + _world.default_pg = pg + rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 + torch._C._distributed_c10d._set_global_rank(rank) + + +def get_backend_config(group: Optional[ProcessGroup] = None) -> str: + """ + Return the backend configuration of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend configuration of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + backend_config = _world.pg_backend_config.get(pg) + return str(not_none(backend_config)) + + +def get_backend(group: Optional[ProcessGroup] = None) -> Backend: + """ + Return the backend of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + + pg_store = _world.pg_map.get(pg, None) + if pg_store is None: + raise ValueError( + f"Process group {pg} is not initialized in the world group map. Please initialize the group first." + ) + + return Backend(not_none(pg_store)[0]) + + +def get_default_backend_for_device(device: Union[str, torch.device]) -> str: + """ + Return the default backend for the given device. + + Args: + device (Union[str, torch.device]): The device to get the default backend for. + + Returns: + The default backend for the given device as a lower case string. 
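+
+    Example (illustrative, assuming the built-in default registrations of
+    ``nccl`` for ``cuda`` and ``gloo`` for ``cpu``):
+        >>> # xdoctest: +SKIP("depends on the installed backends")
+        >>> get_default_backend_for_device("cuda")
+        'nccl'
+        >>> get_default_backend_for_device("cpu")
+        'gloo'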
+ + """ + if isinstance(device, torch.device): + device_str = device.type + else: + device_str = torch.device(device).type + + backend = Backend.default_device_backend_map.get(device_str) + if backend is None: + raise ValueError(f"Default backend not registered for device : {device}") + + return backend + + +def _get_process_group_uid(pg: ProcessGroup) -> int: + backend = None + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + pass + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + return backend.uid + return -1 + + +def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]: + """ + Return the pg configuration of the given process group. + + """ + pg = group or _get_default_group() + return { + "pg_name": _get_process_group_name(pg), + "pg_desc": pg.group_desc, + "backend_config": get_backend_config(pg), + "pg_size": _get_group_size(pg), + "ranks": get_process_group_ranks(pg), + } + + +def _get_all_pg_configs() -> list[dict[str, Any]]: + """ + Return the pg configuration of all the process groups. + + """ + config_info: list[dict[str, Any]] = [ + _get_pg_config(pg) for pg in _world.pg_map.keys() + ] + return config_info + + +def get_pg_count() -> int: + """ + Return the number of process groups. + + """ + return _world.group_count + + +def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: + """ + Return the local rank of the current process relative to the node. + + Semantically, this is a useful concept for mapping processes to devices. + For example, on a node with 8 accelerator you could use the node local rank to decide + which accelerator device to bind the process to. + + In practice, the actual assignment of node local ranks is handled by the process launcher outside of pytorch, + and communicated via the `LOCAL_RANK` environment variable. + + Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not. If `LOCAL_RANK` is unspecified, + this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The + intent is to allow writing an application that runs either in single or multi device contexts without error. + + """ + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + elif fallback_rank is not None: + return int(fallback_rank) + raise RuntimeError( + "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, " + "assuming you are not running in a multi-device context and want the code to run locally instead." + ) + + +def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: + """ + This API adds an ephemeral timeout extension for all PGs locally + on one rank. The timeout gets reset when the first collective issued + after API called finished. + NOTE: We only support to set timeout for cuda backends for now. + NOTE: While this feature + provides flexibility in specific scenarios, it introduces statefulness + to timeout setting. Therefore, it is advisable to use this API sparingly + and consider alternative approaches, such as directly setting the timeout + or utilizing a barrier collective (one can set any timeout to the barrier), + whenever feasible. + + Args: + timeout (timedelta): The delta of timeout to extend. + + Returns: + None. 
+ """ + for pg in _world.pg_map.keys(): + devices = pg._device_types + if torch.device("cuda") in devices: + backend = pg._get_backend(torch.device("cuda")) + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._add_ephemeral_timeout(timeout) + + +def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None: + """ + Set the timeout for the given process group when users want to use a different timeout instead of + default values. + + Args: + timeout (timedelta): Timeout for operations executed against the process group which + users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends. + This is the duration after which collectives will be aborted asynchronously and the process will crash. + This is done since CUDA execution is async and it is no longer safe to continue executing user code since + failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. + When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. + + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + None + """ + if group is None: + group = _get_default_group() + if _rank_not_in_group(group): + raise ValueError("Invalid process group specified") + assert isinstance(group, ProcessGroup) + devices = group._device_types + backends = set() + if torch.device("cpu") in devices and is_gloo_available(): + backend = group._get_backend(torch.device("cpu")) + if isinstance(backend, ProcessGroupGloo): + backends.add(backend) + if torch.device("cuda") in devices: + backend = group._get_backend(torch.device("cuda")) + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backends.add(backend) # type: ignore[arg-type] + elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): + backends.add(backend) # type: ignore[arg-type] + if len(backends) == 0: + warnings.warn("Set timeout is now only supported for either nccl or gloo.") + for backend in backends: + backend._set_default_timeout(timeout) + + +@_exception_logger +@_time_logger +def init_process_group( + backend: Optional[str] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = "", + pg_options: Optional[Any] = None, + device_id: Optional[Union[torch.device, int]] = None, +) -> None: + """ + Initialize the default distributed process group. + + This will also initialize the distributed package. + + There are 2 main ways to initialize a process group: + 1. Specify ``store``, ``rank``, and ``world_size`` explicitly. + 2. Specify ``init_method`` (a URL string) which indicates where/how + to discover peers. Optionally specify ``rank`` and ``world_size``, + or encode all required parameters in the URL and omit them. + + If neither is specified, ``init_method`` is assumed to be "env://". + + + Args: + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values include ``mpi``, ``gloo``, + ``nccl``, ``ucc``, or one that is registered by a third-party + plugin. + Since 2.6, if ``backend`` is not provided, c10d will use a backend + registered for the device type indicated by the `device_id` kwarg + (if provided). 
The known default registrations today are: ``nccl`` + for ``cuda``, ``gloo`` for ``cpu``. + If neither ``backend`` nor ``device_id`` is provided, c10d will + detect the accelerator on the run-time machine and use a backend + registered for that detected accelerator (or ``cpu``). + This field can be given as a lowercase string (e.g., ``"gloo"``), + which can also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). + If using multiple processes per machine with ``nccl`` backend, each + process must have exclusive access to every GPU it uses, as sharing + GPUs between processes can result in deadlock or NCCL invalid usage. + ``ucc`` backend is experimental. + Default backend for the device can be queried with + :func:`get_default_backend_for_device`. + init_method (str, optional): URL specifying how to initialize the + process group. Default is "env://" if no + ``init_method`` or ``store`` is specified. + Mutually exclusive with ``store``. + world_size (int, optional): Number of processes participating in + the job. Required if ``store`` is specified. + rank (int, optional): Rank of the current process (it should be a + number between 0 and ``world_size``-1). + Required if ``store`` is specified. + store(Store, optional): Key/value store accessible to all workers, used + to exchange connection/address information. + Mutually exclusive with ``init_method``. + timeout (timedelta, optional): Timeout for operations executed against + the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. + This is the duration after which collectives will be aborted asynchronously and the process will crash. + This is done since CUDA execution is async and it is no longer safe to continue executing user code since + failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. + When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. + + group_name (str, optional, deprecated): Group name. This argument is ignored + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. As of now, the only + options we support is ``ProcessGroupNCCL.Options`` for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + the nccl backend can pick up high priority cuda streams when + there're compute kernels waiting. For other available options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + device_id (torch.device | int, optional): a single, specific device + this process will work on, allowing for backend-specific + optimizations. Currently this has two effects, only under + NCCL: the communicator is immediately formed (calling + ``ncclCommInit*`` immediately rather than the normal lazy + call) and sub-groups will use ``ncclCommSplit`` when + possible to avoid unnecessary overhead of group creation. If you + want to know NCCL initialization error early, you can also use this + field. If an `int` is provided, the API assumes that the accelerator + type at compile time will be used. + + .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source + on a system that supports MPI. + + .. note:: Support for multiple backends is experimental. Currently when no backend is + specified, both ``gloo`` and ``nccl`` backends will be created. 
The ``gloo`` backend + will be used for collectives with CPU tensors and the ``nccl`` backend will be used + for collectives with CUDA tensors. A custom backend can be specified by passing in + a string with format ":,:", e.g. + "cpu:gloo,cuda:custom_backend". + + """ + + global _world + + global _backend + global _default_pg_init_method + + if GroupMember.WORLD is not None: + raise ValueError("trying to initialize the default process group twice!") + + set_pytorch_distributed_envs_from_justknobs() + + # Depending on the import order, some trace_rules functions may be evaluated + # during the import phase. In such a case, these functions may not correctly + # add the distributed related rules due to import circular dependency. + # We need to clear the lru_cache during the runtime to ensure the correctness + # of these trace_rules. + # + # Since this API must be called before all distributed code being compiled, + # clearing the cache here should be safe. + if "torch._dynamo" in sys.modules: + torch._dynamo.trace_rules.clear_lru_cache() + + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store." + ) + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + # Get the compile-time accelerator type. + # None indicates no accelerator support. + acc = torch.accelerator.current_accelerator() + + # Auto complete device id + if isinstance(device_id, int): + if acc is None: + raise ValueError( + "device_id is an int, but no accelerator support is found from the current compilation. " + "Please use a different compiled version that supports your accelerator." + ) + device_id = torch.device(acc.type, device_id) + + # Sanity check device_id + if device_id is not None and device_id.type != "cpu": + # Type + if acc is None or device_id.type != acc.type: + raise ValueError( + f"device_id {device_id} does not match the current compilation's accelerator support: {acc}. " + "Please use a different compiled version that supports your accelerator." + ) + # Index + if device_id.index is None: + raise ValueError("Please use a device_id with index.") + # Range + if device_id.index >= torch.accelerator.device_count(): + raise ValueError( + f"device_id {device_id} is out of range. Please use a device index less than " + f"the number of accelerators available: {torch.accelerator.device_count()}." + ) + + logger.info("Using device: %s", device_id) + + # If user did not provide a backend string but provided a device id, e.g. + # >>> init_process_group(device_id=device) + # we try to figure out the backend name based on the device type. + if backend is None and device_id is not None: + # Note: 3rd-party devices can register default backend through the + # default map below. + backend = Backend.default_device_backend_map.get(device_id.type) + + # If we still cannot figure it out, e.g. + # >>> init_process_group() + # we set it to `undefined` and rely on lazy init. + if backend is None: + backend = "undefined" + + # Convert string into `Backend` type + backend = Backend(backend) + + if timeout is None: + timeout = _get_default_timeout(backend) + + _check_valid_timeout(timeout) + + """ + Group name is not visible to users unless they access + internals of c10d. This means we can ignore the value + they provide as it not exposed in a public way. 
+ """ + group_name = _process_group_name([], use_hashed_name=False) + if backend == Backend.MPI: + if world_size != -1 or rank != -1: + warnings.warn( + f"For MPI backend, world_size ({world_size}) and rank ({rank}) " + "are ignored since they are assigned by the " + "MPI runtime." + ) + + default_pg, _ = _new_process_group_helper( + -1, + -1, + [], + backend, + Store(), # Placeholder value since store cannot be None + group_name, + timeout=timeout, + group_desc="default_pg", + ) + _update_default_pg(default_pg) + else: + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous( + not_none(init_method), rank, world_size, timeout=timeout + ) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore("default_pg", store) + + default_pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name, + backend_options=pg_options, + timeout=timeout, + device_id=device_id, + group_desc="default_pg", + ) + _update_default_pg(default_pg) + + _world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index] + i: i + for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined] + } + _backend = _world.pg_map[not_none(GroupMember.WORLD)][0] + _default_pg_init_method = init_method + + old_hook = sys.excepthook + excepthook_prefix = f"[rank{get_rank()}]" + + def _distributed_excepthook(*args): + old_stderr = sys.stderr + sys.stderr = buf = io.StringIO() + try: + old_hook(*args) + finally: + sys.stderr = old_stderr + msg = buf.getvalue() + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) + sys.stderr.write(msg) + sys.stderr.flush() + + sys.excepthook = _distributed_excepthook + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.debug( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI backend doesn't use store. + barrier() + else: + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier(rank, store, group_name, world_size, timeout) + + +def _get_split_source(pg): + split_from = None + if pg.bound_device_id: + split_from = pg._get_backend(pg.bound_device_id) + elif pg is _world.default_pg: + try: + split_from = pg._get_backend(torch.device("cuda")) + except RuntimeError: + # no cuda device associated with this backend + pass + + if not split_from or not split_from.supports_splitting: + return None + + # If necessary, find a backend to split from by peeling process + # group wrappers from our potentially wrapped process group. 
+ while _GLOO_AVAILABLE and isinstance(split_from, _ProcessGroupWrapper): + split_from = split_from.wrapped_pg + + return split_from + + +def _new_process_group_helper( + group_size, + group_rank, + global_ranks_in_group, + backend, + store, + group_name, + backend_options=None, + timeout=None, + pg_tag=None, + device_id=None, + group_desc=None, +): + """ + Create a new distributed process group. + + This function must be called by ALL processes in the global group, even if + the calling process is not part of the newly created group. In that case, + this function returns GroupMember.NON_GROUP_MEMBER. + + This function is called with ``global_ranks_in_group == []`` for the default group. + """ + global _world + + if group_name in _world.pg_names.values(): + raise ValueError( + "The specified group name has already been " + "created, please use a different group name" + ) + + if device_id is not None and (device_id.index is None or device_id.type == "cpu"): + raise ValueError( + "init_process_group device_id parameter must be an accelerator with an index" + ) + + # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value + _check_valid_timeout(timeout) + + if pg_tag not in [None, ""]: + # creating with the same tag and rank set results in the same underlying PG + existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group) + if existing_group: + _, prefix_store = _world.pg_map[existing_group] + return existing_group, prefix_store + + group_desc = "undefined" if group_desc is None else group_desc + + # The list of group ranks is empty if we're creating the default group. + is_default_group = len(global_ranks_in_group) == 0 + + # nccl and potentially other backends allow creation of + # communicators based on pre-existing ones, which can save + # initialization time. Due to lazy initialization of + # communicators in some backends, we have to be careful and only + # split when we *know* the default PG has already started communicator initialization. + # We know this if we have bound a device id to the default pg (eager initialized). + if is_initialized() and _get_default_group().bound_device_id: + split_from = _get_split_source(_get_default_group()) + else: + split_from = None + + # If this is a subgroup (which means group_ranks is specified), + # we check if the current process is a member of the new group. + if not is_default_group: + global_rank = _get_default_group().rank() + if global_rank not in global_ranks_in_group: + # If we are using `ncclCommSplit` (or similar split from + # other APIs) to create the communicator, we will need to + # call `ncclCommSplit` on *all* ranks in this new group's + # parent group, even those not in the new group. This is + # a requirement of the NCCL API as otherwise we would get + # out of sync. + if split_from: + split_from.perform_nocolor_split(_get_default_group().bound_device_id) + return GroupMember.NON_GROUP_MEMBER, None + + prefix_store = PrefixStore(f"{group_name}/", store) + # The backend for PG will be set later based on what's inside BackendConfig + # and timeout are set in each backend's option. + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + backend_config = BackendConfig(backend) + # Set the default backend when single backend is passed in. 
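+    # (A "single backend" here is a plain name such as "nccl" or "gloo"; a
+    # multi-backend spec such as "cpu:gloo,cuda:nccl" contains ":" or "," and
+    # is handled by the else branch below.)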
+ if "," not in str(backend) and ":" not in str(backend): + assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" + if backend == Backend.UNDEFINED: + # Currently when backend is UNDEFINED, both ``gloo`` and ``nccl`` backends + # will be created, we use nccl(if cuda is available) or gloo as default + # backend so we can correctly call getDefaultBackend which in ProcessGroup. + if Backend.NCCL in backend_config.get_device_backend_map().values(): + pg._set_default_backend(ProcessGroup.BackendType.NCCL) + else: + pg._set_default_backend(ProcessGroup.BackendType.GLOO) + else: + pg._set_default_backend(Backend.backend_type_map[backend]) + # In order to correctly call pg._has_hooks(), we should set the default backend + # when multi backend is passed in + else: + if Backend.NCCL in backend_config.device_backend_map.values(): + pg._set_default_backend(ProcessGroup.BackendType.NCCL) + elif Backend._plugins.keys(): + custom_backend = next(iter(Backend._plugins.keys())) + if custom_backend in backend_config.device_backend_map.values(): + pg._set_default_backend(ProcessGroup.BackendType.CUSTOM) + else: + pg._set_default_backend(ProcessGroup.BackendType.GLOO) + + if device_id: + pg.bound_device_id = device_id + backend_class: torch._C._distributed_c10d.Backend + for device, backend_str in backend_config.get_device_backend_map().items(): + # Use the group name as prefix in the default store, such that + # a single store can be reused by multiple groups. + backend_prefix_store = PrefixStore(f"{device}/", prefix_store) + + if backend_str == Backend.MPI: + if not is_mpi_available(): + raise RuntimeError( + "Distributed package doesn't have MPI built in." + " MPI is only included if you build PyTorch from" + " source on a host that has MPI installed." + ) + backend_class = ProcessGroupMPI.create(global_ranks_in_group) + backend_type = ProcessGroup.BackendType.MPI + if not backend_class: + return GroupMember.NON_GROUP_MEMBER, None + # create new process group with accurate rank and size + if pg.rank() == -1 and pg.size() == -1: + pg = ProcessGroup( + backend_prefix_store, + backend_class.rank(), + backend_class.size(), + ) + pg._set_default_backend(backend_type) + elif backend_str == Backend.GLOO: + # TODO: remove this check after lazy initialization is supported + # if pg_options is not None: + # raise RuntimeError("GLOO options not supported") + if not is_gloo_available(): + raise RuntimeError("Distributed package doesn't have Gloo built in") + backend_class = ProcessGroupGloo( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) + backend_class.options.global_ranks_in_group = global_ranks_in_group + backend_class.options.group_name = group_name + backend_type = ProcessGroup.BackendType.GLOO + elif backend_str == Backend.NCCL: + if not is_nccl_available(): + raise RuntimeError("Distributed package doesn't have NCCL built in") + if backend_options is not None: + assert isinstance(backend_options, ProcessGroupNCCL.Options), ( + "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + ) + if backend_options._timeout != timeout: + warnings.warn( + "backend_options._timeout was specified, " + "but timeout kwarg has a default value that will always override it. 
" + ) + else: + # default backend_options for NCCL + backend_options = ProcessGroupNCCL.Options() + backend_options.is_high_priority_stream = False + backend_options._timeout = timeout + + if split_from: + backend_options.split_from = split_from + backend_options.split_color = _process_group_color( + global_ranks_in_group + ) + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name + backend_class = ProcessGroupNCCL( + backend_prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.NCCL + elif backend_str == Backend.UCC and is_ucc_available(): + # TODO: once UCC plugin is fully deprecated, remove + # is_ucc_available() from above elif-condition and raise + # RuntimeError if is_ucc_available() returns false. + + backend_class = ProcessGroupUCC( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) + backend_type = ProcessGroup.BackendType.UCC + elif backend_str == Backend.XCCL: + if not is_xccl_available(): + raise RuntimeError("Distributed package doesn't have XCCL built in") + backend_class = ProcessGroupXCCL( + backend_prefix_store, group_rank, group_size + ) + backend_type = ProcessGroup.BackendType.XCCL + else: + assert backend_str.upper() in Backend._plugins, ( + f"Unknown c10d backend type {backend_str.upper()}" + ) + + backend_plugin = Backend._plugins[backend_str.upper()] + creator_fn = backend_plugin.creator_fn + extended_api = backend_plugin.extended_api + backend_type = ProcessGroup.BackendType.CUSTOM + + if not extended_api: + backend_class = creator_fn( + backend_prefix_store, group_rank, group_size, timeout + ) + else: + dist_backend_opts = _DistributedBackendOptions() + dist_backend_opts.store = backend_prefix_store + dist_backend_opts.group_rank = group_rank + dist_backend_opts.group_size = group_size + dist_backend_opts.timeout = timeout + dist_backend_opts.group_id = group_name + dist_backend_opts.global_ranks_in_group = global_ranks_in_group + + backend_class = creator_fn(dist_backend_opts, backend_options) + + # Set sequence numbers for gloo and nccl backends. + if backend_str == Backend.GLOO: + assert isinstance(backend_class, ProcessGroupGloo) + backend_class._set_sequence_number_for_group() + elif backend_str == Backend.NCCL: + assert isinstance(backend_class, ProcessGroupNCCL) + backend_class._set_sequence_number_for_group() + + # If the type is a subclass of ProcessGroup then return this process group immediately + # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the + # ProcessGroup instance + if issubclass(type(backend_class), ProcessGroup): + pg = backend_class # type: ignore[assignment] + break + + # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set + if ( + backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] + or backend_str.upper() in Backend._plugins + ): + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debuggability. + if get_debug_level() == DebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info( + """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. 
Build with Gloo to + create a wrapper process group in debug mode + to aid collective desynchronization debugging.""" + ) + else: + backend_class = _create_process_group_wrapper( + wrapped_pg=backend_class, + store_prefix=group_name, + store=backend_prefix_store, + rank=group_rank, + world_size=group_size, + timeout=timeout, + ) + + # register only a single backend when all get_device_backend_map values are the same + if len(set(backend_config.get_device_backend_map().values())) == 1: + for device in backend_config.get_device_backend_map().keys(): + pg._register_backend(torch.device(device), backend_type, backend_class) + + # break out of outer loop to not create any more backends + break + + pg._register_backend(torch.device(device), backend_type, backend_class) + + # set group_name and group_dsec to backend + assert group_name is not None + assert group_desc is not None + pg._set_group_name(group_name) + pg._set_group_desc(group_desc) + + if device_id and pg._get_backend(device_id).supports_splitting: + eager_backend = pg._get_backend(device_id) + eager_backend.eager_connect_single_device(device_id) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + _register_process_group(group_name, pg) + + _world.pg_backend_config[pg] = str(backend_config) + # "" is the default tag for user PGs + if pg_tag in [None, ""]: + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault("", []).append(pg) + else: + pg_tag = f"user:{pg_tag}" + + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + return pg, prefix_store + + +def destroy_process_group(group: Optional[ProcessGroup] = None): + """ + Destroy a given process group, and deinitialize the distributed package. + + Args: + group (ProcessGroup, optional): The process group to be destroyed, if + group.WORLD is given, all process + groups including the default one will + be destroyed. + """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + if group is None: + pg = GroupMember.WORLD + else: + pg = group + + assert pg is not None + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified") + + # When users register Python onCompletion hooks, those hooks will run on a + # different thread than the main thread. Today, the ProcessGroup dtor does + # wait for that thread. However, the dtor might finish after the Python + # Interpreter exits. After that grabbing the GIL for the Python hook will crash. + # We can either revive the interpreter when running hooks or keep the main one + # alive until all works and hooks are done. The current implementation does the + # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait + # for the pending hooks to finish. + if type(pg) == ProcessGroup and pg._has_hooks(): + pg._wait_for_pending_works() + + if group is None or group == GroupMember.WORLD: + # shutdown all backends in the order of pg names. shutting down in order because + # ncclCommAbort() was a 'collective' call in some versions of NCCL. 
+ for pg_to_shutdown in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + pg_to_shutdown.shutdown() + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. + _world.group_count = 0 + else: + pg.shutdown() + del _world.pg_map[pg] + del _world.pg_names[pg] + del _world.pg_group_ranks[pg] + del _world.pg_backend_config[pg] + if pg in _world.pg_coalesce_state.keys(): + warnings.warn( + "Some coalesced collectives haven't been launched when " + "ProcessGroup is destroyed. They will be cleaned." + ) + del _world.pg_coalesce_state[pg] + + tag = _world.pg_to_tag.get(pg) + del _world.pg_to_tag[pg] + if tag is not None: + try: + _world.tags_to_pg[tag].remove(pg) + if tag.startswith("ptd:"): + _world.tags_to_pg[""].remove(pg) + except Exception: + pass + _unregister_process_group(pg.group_name) + + +def _abort_process_group(group: Optional[ProcessGroup] = None): + """ + Abort a given process group. If group.WORLD (i.e. `None`) is given, all + process groups including the default one will be aborted. + + Args: + group (ProcessGroup, optional): The process group to be aborted. + + .. note:: this API is experimental and currently only works with the NCCL + backend. + + .. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING` + turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may + automatically handle errors or timeouts for you including aborting the + ProcessGroup. + """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + pg = group or GroupMember.WORLD + + assert pg is not None + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified or has been destroyed.") + + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + backend = None + + if group is None or group == GroupMember.WORLD: + # Abort all backends within a ncclGroupStart|End semantic. + # This ensures that different NCCL communicators' abort calls won't + # deadlock each other. + # For details, please see: https://github.com/pytorch/pytorch/issues/119797 + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._group_start() + for pg_to_abort in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + pg_to_abort.abort() + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._group_end() + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. 
We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. + _world.group_count = 0 + else: + pg.abort() + del _world.pg_map[pg] + del _world.pg_names[pg] + del _world.pg_group_ranks[pg] + del _world.pg_backend_config[pg] + if pg in _world.pg_coalesce_state.keys(): + warnings.warn( + "Some coalesced collectives haven't been launched when " + "ProcessGroup is aborted. They will be cleaned." + ) + del _world.pg_coalesce_state[pg] + + tag = _world.pg_to_tag.get(pg) + del _world.pg_to_tag[pg] + if tag is not None: + try: + _world.tags_to_pg[tag].remove(pg) + if tag.startswith("ptd:"): + _world.tags_to_pg[""].remove(pg) + except Exception: + pass + _unregister_process_group(pg.group_name) + + +def get_rank(group: Optional[ProcessGroup] = None) -> int: + """ + Return the rank of the current process in the provided ``group``, default otherwise. + + Rank is a unique identifier assigned to each process within a distributed + process group. They are always consecutive integers ranging from 0 to + ``world_size``. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + The rank of the process group + -1, if not part of the group + + """ + if _rank_not_in_group(group): + return -1 + + default_pg = _get_default_group() + if group is None or group is GroupMember.WORLD: + return default_pg.rank() + + return get_group_rank(group, default_pg.rank()) + + +def get_world_size(group: Optional[ProcessGroup] = None) -> int: + """ + Return the number of processes in the current process group. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + The world size of the process group + -1, if not part of the group + + """ + if _rank_not_in_group(group): + return -1 + + return _get_group_size(group) + + +def isend( + tensor: torch.Tensor, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_dst: Optional[int] = None, +) -> Optional[Work]: + """ + Send a tensor asynchronously. + + .. warning:: + Modifying ``tensor`` before the request completes causes undefined + behavior. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self. + + Args: + tensor (Tensor): Tensor to send. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with remote recv + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + A distributed request object. 
+ None, if not part of the group + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("isend") + return None + + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) + + return group.send([tensor], group_dst, tag) + + +def irecv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_src: Optional[int] = None, +) -> Optional[Work]: + """ + Receives a tensor asynchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self. + + Args: + tensor (Tensor): Tensor to fill with received data. + src (int, optional): Source rank on global process group (regardless of ``group`` argument). + Will receive from any process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match recv with remote send + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + + Returns: + A distributed request object. + None, if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("irecv") + return None + + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) + + group = _group_or_default_group(group) + if src is None and group_src is None: + return group.recv_anysource([tensor], tag) + else: + group_src = _canonicalize_group_rank(group, src, group_src) + return group.recv([tensor], group_src, tag) + + +@_exception_logger +def send( + tensor: torch.Tensor, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_dst: Optional[int] = None, +) -> None: + """ + Send a tensor synchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to send. + dst (int): Destination rank on global process group (regardless of ``group`` argument). + Destination rank should not be the same as the rank of the current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with remote recv + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``. + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_not_self_rank(group, group_dst, "destination") + work = isend(tensor, group=group, tag=tag, group_dst=group_dst) + if work is not None: + work.wait() + + +@_exception_logger +def recv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_src: Optional[int] = None, +) -> int: + """ + Receives a tensor synchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to fill with received data. + src (int, optional): Source rank on global process group (regardless of ``group`` argument). + Will receive from any process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. 
+ tag (int, optional): Tag to match recv with remote send + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + Returns: + Sender rank + -1, if not part of the group + + """ + work = irecv(tensor, src=src, group=group, tag=tag, group_src=group_src) + if work is None: + return -1 + work.wait() + if src is None: + if group_src is None: + group_src = work._source_rank() + group = _group_or_default_group(group) + _check_not_self_rank(group, group_src, "source") + src = get_global_rank(group, group_src) + return src + + +class _IllegalWork(Work): + def __getattribute__(self, name): + if name in [ + "is_success", + "exception", + "wait", + "source_rank", + "_source_rank", + "result", + "synchronize", + ]: + raise ValueError(f"Illegal to call {name} on IllegalWork object") + + +class _CoalescingManager: + def __init__(self) -> None: + self.works: list[Work] = [] + + def append(self, work: Optional[Work] = None): + if work: + self.works.append(work) + + def wait(self): + for work in self.works: + work.wait() + + +@contextlib.contextmanager +def _coalescing_manager( + group: Optional[ProcessGroup] = None, + device: Optional[torch.device] = None, + async_ops: bool = False, +): + """ + Context manager used to coalesce collectives or P2P operations when possible. + + Args: + group (`ProcessGroup`, optional): The process group to work on. If None, + the default process group will be used. + device (`torch.device`, optional): Default is None, set to a device if + there isn't a `**_coalesced` implementation by the backend. + async_ops (`bool`, optional): whether the coalesced ops are async ops. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # Synchronous ops + >>> with _coalescing_manager(): + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> # Asynchronous ops + >>> with _coalescing_manager(async_ops=True) as cm: + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> cm.wait() + + .. warning:: + :func:`_coalescing_manager` currently do not support coalescing + all-reduces with different reduce operators, e.g. `ReduceOp.SUM` mixed + with `ReduceOp.PRODUCT`. + """ + group = group or _get_default_group() + op_list = _world.pg_coalesce_state.setdefault(group, []) + if op_list: + raise ValueError( + "ProcessGroup has non-empty op list at the start of coalescing" + ) + if device: + group._start_coalescing(device) + cm = _CoalescingManager() + yield cm + work = None + op_list = _world.pg_coalesce_state.pop(group) + if op_list: + # Collectives supporting "Fast Path" coalescing are captured. + # See implementation in corresponding collective APIs. 
+ # Currently supported: + # - coalesced `all_reduce` + # - coalesced `all_gather_into_tensor` + # - coalesced `reduce_scatter_tensor` + op0 = op_list[0].op + if op0 == all_reduce: + tensors = [op.tensor for op in op_list] + all_reduce_opts = AllreduceCoalescedOptions() + all_reduce_opts.reduceOp = not_none(op_list[0].redop) + all_reduce_opts.asyncOp = async_ops + work = group.allreduce_coalesced(tensors, all_reduce_opts) + elif op0 == all_gather_into_tensor: + inputs = [] + outputs = [] + for op in op_list: + inputs.append(op.tensor) + outputs.append(not_none(op.dst_tensor)) + all_gather_opts = AllgatherOptions() + all_gather_opts.asyncOp = async_ops + work = group.allgather_into_tensor_coalesced(outputs, inputs) + elif op0 == reduce_scatter_tensor: + inputs = [] + outputs = [] + for op in op_list: + inputs.append(op.tensor) + outputs.append(not_none(op.dst_tensor)) + reduce_opts = ReduceScatterOptions() + reduce_opts.reduceOp = not_none(op_list[0].redop) + reduce_opts.asyncOp = async_ops + work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) + else: + raise AssertionError( + f"Coalescing manager does not support fast-path coalescing of {op0}, " + f"yet {op0} is still recorded in op list. This is an internal error of c10d." + ) + + if device: + # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding + work = group._end_coalescing(device) + + if async_ops: + cm.append(work) + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +class _TimeEstimator: + def __init__(self) -> None: + self.estimated_time: Optional[float] = None + + +@contextlib.contextmanager +def _time_estimator( + group: Optional[ProcessGroup] = None, + device: Optional[torch.device] = None, +): + """ + Context manager used to estimate time of collectives. + Within the context manager, nothing is actually run and the backend just simulates + the collective time only. + + Args: + group (`ProcessGroup`, optional): The process group to work on. If None, + the default process group will be used. + device (`torch.device`, optional): Default is None, set to a device if + there isn't a `**_coalesced` implementation by the backend. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # Synchronous ops + >>> with _time_estimator() as cm: + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> # estimate time is stored in cm.estimated_time + + .. warning:: + :func:`_time_estimator` currently only support NCCL backend but it can + easily be extended to other backends. + + Also a NCCL communicator needs to be created because only with a real communicator can we do accurate estimation. + The communicator internally has knowledge about the links it runs on + (e.g. intra-node or inter-node, whether the links are NVLink or PCI-e or IB). + """ + # TODO: We need to also support torch inductor for the time estimator. 
+ group = group or _get_default_group() + device = device or _get_pg_default_device(group) + backend = group._get_backend(device) + if not backend.supports_time_estimate: + raise NotImplementedError( + f"collective time estimator is not supported in the current version of backend {backend}" + ) + backend._start_time_estimate() # type: ignore[attr-defined] + cm = _TimeEstimator() + yield cm + cm.estimated_time = backend._end_time_estimate() # type: ignore[attr-defined] + + +def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]: + """ + Send or Receive a batch of tensors asynchronously and return a list of requests. + + Process each of the operations in ``p2p_op_list`` and return the corresponding + requests. NCCL, Gloo, and UCC backend are currently supported. + + Args: + p2p_op_list: A list of point-to-point operations(type of each operator is + ``torch.distributed.P2POp``). The order of the isend/irecv in the list + matters and it needs to match with corresponding isend/irecv on the + remote end. + + Returns: + A list of distributed request objects returned by calling the corresponding + op in the op_list. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank + >>> recv_tensor = torch.randn(2, dtype=torch.float32) + >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size) + >>> recv_op = dist.P2POp( + ... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size + ... ) + >>> reqs = batch_isend_irecv([send_op, recv_op]) + >>> for req in reqs: + >>> req.wait() + >>> recv_tensor + tensor([2, 3]) # Rank 0 + tensor([0, 1]) # Rank 1 + + .. note:: Note that when this API is used with the NCCL PG backend, users must set + the current GPU device with `torch.cuda.set_device`, otherwise it will + lead to unexpected hang issues. + + In addition, if this API is the first collective call in the ``group`` + passed to ``dist.P2POp``, all ranks of the ``group`` must participate in + this API call; otherwise, the behavior is undefined. If this API call is + not the first collective call in the ``group``, batched P2P operations + involving only a subset of ranks of the ``group`` are allowed. + """ + _check_p2p_op_list(p2p_op_list) + group = p2p_op_list[0].group + if group is None: + group = _get_default_group() + device = p2p_op_list[0].tensor.device + + def peer_kwarg(op: P2POp) -> dict[str, int]: + key = "group_dst" if op.op == isend else "group_src" + return {key: op.group_peer} + + if type(group) == ProcessGroup and group._get_backend(device).supports_coalescing: + # NCCL style coalescing + with _coalescing_manager(group, device, async_ops=True) as cm: + for p2p_op in p2p_op_list: + p2p_op.op( + p2p_op.tensor, + group=p2p_op.group, + tag=p2p_op.tag, + **peer_kwarg(p2p_op), + ) + + return cm.works + else: + # backend not support coalescing + reqs = [] + for p2p_op in p2p_op_list: + work = p2p_op.op( + p2p_op.tensor, + group=p2p_op.group, + tag=p2p_op.tag, + **peer_kwarg(p2p_op), + ) + if work: + reqs.append(work) + return reqs + + +@_exception_logger +def broadcast( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: Optional[int] = None, +): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. 
+ + Args: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process, and tensor to be used to save received data otherwise. + src (int): Source rank on global process group (regardless of ``group`` argument). + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_src (int): Source rank on ``group``. Must specify one of ``group_src`` + and ``src`` but not both. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + group = _group_or_default_group(group) + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("broadcast") + return + + opts = BroadcastOptions() + opts.rootRank = group_src + opts.rootTensor = 0 + opts.asyncOp = async_op + work = group.broadcast([tensor], opts) + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces the tensor data across all machines in a way that all get the final result. + + After the call ``tensor`` is going to be bitwise identical in all processes. + + Complex tensors are supported. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> device = torch.device(f"cuda:{rank}") + >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4, 6], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + >>> # All tensors below are of torch.cfloat type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.tensor( + ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device + ... ) + 2 * rank * (1 + 1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0 + tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1 + + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). 
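+    # ``has_torch_function``/``handle_torch_function`` implement that redirect:
+    # if ``tensor`` (the only relevant argument here) carries a
+    # ``__torch_function__`` override, the call is dispatched to it instead of
+    # reaching the process-group backend.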
+ relevant_args = (tensor,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_reduce, + relevant_args, + tensor, + op=op, + group=group, + async_op=async_op, + ) + + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce") + return + + if tensor.is_complex(): + if not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + tensor = torch.view_as_real(tensor) + + opts = AllreduceOptions() + opts.reduceOp = op + opts.asyncOp = async_op + if group is None: + group = _get_default_group() + + if group in _world.pg_coalesce_state.keys(): + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_reduce, tensor, None, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group.allreduce([tensor], opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +@deprecated( + "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " + "use it, please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): + """ + WARNING: at this time individual shape checking is not implemented across nodes. + + For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the + rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce + operation will proceed without complaint and return erroneous outputs. This lack + of shape checking results in significant performance improvements but users of this + function should take extra care to ensure that each node passes in tensors whose + shapes match across nodes. + + Reduces each tensor in tensors (residing on the same device) across all machines + in such a way that all get the final result. + + After the call each tensor in tensors is going to bitwise identical + in all processes. + + Complex tensors are supported. + + Args: + tensors (Union[List[Tensor], Tensor]): Input and output of the collective. + The function operates in-place. + op (Optional[ReduceOp]): One of the values from + ``torch.distributed.ReduceOp`` enum. Specifies an operation used for + element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (Optional[bool]): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. 
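+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Illustrative sketch: assumes a default group of 2 ranks, with every
+        >>> # rank passing the same values.
+        >>> tensors = [torch.ones(2), torch.full((3,), 2.0)]
+        >>> dist.all_reduce_coalesced(tensors, op=ReduceOp.SUM)
+        >>> tensors
+        [tensor([2., 2.]), tensor([4., 4., 4.])]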
+ + """ + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + _check_tensor_list(tensors, "tensor") + _ensure_all_tensors_same_dtype(tensors) + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce_coalesced") + return + + if any(t.is_complex() for t in tensors) and not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + + tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] + + opts = AllreduceCoalescedOptions() + opts.reduceOp = op + opts.asyncOp = async_op + group = group or _get_default_group() + work = group.allreduce_coalesced(tensors, opts) + + if async_op: + return work.get_future() + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def reduce( + tensor: torch.Tensor, + dst: Optional[int] = None, + op=ReduceOp.SUM, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: Optional[int] = None, +): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst`` + and ``dst`` but not both. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("reduce") + return + + opts = ReduceOptions() + opts.reduceOp = op + opts.rootRank = group_dst + opts.asyncOp = async_op + work = group.reduce([tensor], opts) + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def _object_to_tensor(obj, device, group): + with _WaitCounter("pytorch.wait_counter.c10d._object_to_tensor").guard(): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will cause 100X slowdown. 
+ # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).to(device) + if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + backend = get_backend(group) + if backend == Backend.NCCL: + hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + logger.warning( + "_object_to_tensor size: %s hash value: %s", + byte_tensor.numel(), + hash, + ) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + with _WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard(): + if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + backend = get_backend(group) + if backend == Backend.NCCL: + hash = torch._C._distributed_c10d._hash_tensors([tensor]) + logger.warning( + "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + ) + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +@_exception_logger +def all_gather_object(object_list, obj, group=None): + """ + Gathers picklable objects from the whole group into a list. + + Similar to :func:`all_gather`, but Python objects can be passed in. + Note that the object must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as the + size of the group for this collective and will contain the output. + obj (Any): Pickable Python object to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` will + be unmodified. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`all_gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`all_gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`all_gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. 
+ >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_object") + return + + current_device = _get_object_coll_device(group) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def gather_object( + obj: Any, + object_gather_list: Optional[list[Any]] = None, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + group_dst: Optional[int] = None, +): + """ + Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. 
See :ref:`object_collectives` for details. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + ... gather_objects[dist.get_rank()], + ... output if dist.get_rank() == 0 else None, + ... dst=0 + ... ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + if dst is None and group_dst is None: + dst = 0 + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, object_gather_list) + current_device = _get_object_coll_device(group) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_group_rank == group_dst: + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. 
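+    # Only the destination rank allocated ``output_tensors`` above; every other
+    # rank passes ``gather_list=None`` in the call below, matching the contract
+    # of :func:`gather`.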
+ gather( + input_tensor, + gather_list=output_tensors if my_group_rank == group_dst else None, # type: ignore[possibly-undefined] + group_dst=group_dst, + group=group, + ) + if my_group_rank != group_dst: + return + + assert object_gather_list is not None, "Must provide object_gather_list on dst rank" + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def send_object_list( + object_list: list[Any], + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + device: Optional[torch.device] = None, + group_dst: Optional[int] = None, +): + """ + Sends picklable objects in ``object_list`` synchronously. + + Similar to :func:`send`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + sent. + + Args: + object_list (List[Any]): List of input objects to sent. + Each object must be picklable. Receiver must provide lists of equal sizes. + dst (int): Destination rank to send ``object_list`` to. + Destination rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before sending. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. + Must specify one of ``dst`` and ``group_dst`` but not both + Returns: + ``None``. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`send_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`send_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`send` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_not_self_rank(group, group_dst, "destination") + + if _rank_not_in_group(group): + _warn_not_in_group("send_object_list") + return + + # Current device selection. 
+ # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or _get_object_coll_device(group) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + send(object_sizes_tensor, group_dst=group_dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + send(object_tensor, group_dst=group_dst, group=group) + + +@_exception_logger +def recv_object_list( + object_list: list[Any], + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + device: Optional[torch.device] = None, + group_src: Optional[int] = None, +): + """ + Receives picklable objects in ``object_list`` synchronously. + + Similar to :func:`recv`, but can receive Python objects. + + Args: + object_list (List[Any]): List of objects to receive into. + Must provide a list of sizes equal to the size of the list being sent. + src (int, optional): Source rank from which to recv ``object_list``. + Source rank is based on global process group (regardless of ``group`` argument) + Will receive from any rank if set to None. Default is ``None``. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, receives on this device. + Default is ``None``. + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + + Returns: + Sender rank. -1 if rank is not part of the group. If rank is part of the group, + ``object_list`` will contain the sent objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`recv_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`recv_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`recv` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. 
+        >>> import torch.distributed as dist
+        >>> # Assumes backend is not NCCL
+        >>> device = torch.device("cpu")
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 2.
+        >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
+        >>>     dist.send_object_list(objects, dst=1, device=device)
+        >>> else:
+        >>>     objects = [None, None, None]
+        >>>     dist.recv_object_list(objects, src=0, device=device)
+        >>> objects
+        ['foo', 12, {1: 2}]
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("recv_object_list")
+        return -1
+
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` is default to ``None``
+    # in which case we run current logic of device selection, i.e.
+    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
+    # case it is not ``None`` we move the size and object tensors to be
+    # received to this device.
+    current_device = device or _get_object_coll_device(group)
+    object_sizes_tensor = torch.empty(
+        len(object_list), dtype=torch.long, device=current_device
+    )
+
+    # Receive object sizes
+    rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src)
+
+    # Tensor to receive serialized objects into.
+    object_tensor = torch.empty(  # type: ignore[call-overload]
+        torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
+        dtype=torch.uint8,
+        device=current_device,
+    )
+
+    rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
+    assert rank_sizes == rank_objects, (
+        "Mismatch in return ranks for object sizes and objects."
+    )
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    for i, obj_size in enumerate(object_sizes_tensor):
+        obj_view = object_tensor[offset : offset + obj_size]
+        obj_view = obj_view.type(torch.uint8)
+        offset += obj_size
+        object_list[i] = _tensor_to_object(obj_view, obj_size, group)
+    return rank_objects
+
+
+@_exception_logger
+def broadcast_object_list(
+    object_list: list[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
+    """
+    Broadcasts picklable objects in ``object_list`` to the whole group.
+
+    Similar to :func:`broadcast`, but Python objects can be passed in.
+    Note that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
+
+    Args:
+        object_list (List[Any]): List of input objects to broadcast.
+            Each object must be picklable. Only objects on the ``src`` rank will
+            be broadcast, but each rank must provide lists of equal sizes.
+        src (int): Source rank from which to broadcast ``object_list``.
+            Source rank is based on global process group (regardless of ``group`` argument)
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+        device (``torch.device``, optional): If not None, the objects are
+            serialized and converted to tensors which are moved to the
+            ``device`` before broadcasting. Default is ``None``.
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
+
+    Returns:
+        ``None``. If rank is part of the group, ``object_list`` will contain the
+        broadcasted objects from ``src`` rank.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place.
In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`broadcast` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`broadcast_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`broadcast_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`broadcast` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> objects = [None, None, None] + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> dist.broadcast_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("broadcast_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # broadcasted to this device. + current_device = device or _get_object_coll_device(group) + my_group_rank = group.rank() + # Serialize object_list elements to tensors on src rank. + if my_group_rank == group_src: + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) + + # Broadcast object sizes + broadcast(object_sizes_tensor, group_src=group_src, group=group) + + # Concatenate and broadcast serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if my_group_rank == group_src: + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device, + ) + + broadcast(object_tensor, group_src=group_src, group=group) + # Deserialize objects using their stored sizes. 
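+    # Only non-source ranks need to deserialize: the source rank already holds
+    # the original Python objects in ``object_list``.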
+ offset = 0 + if my_group_rank != group_src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + + +@_exception_logger +def scatter_object_list( + scatter_object_output_list: list[Any], + scatter_object_input_list: Optional[list[Any]] = None, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + group_src: Optional[int] = None, +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole group. + + Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. + + Args: + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any], optional): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + src (int): Source rank from which to scatter ``scatter_object_input_list``. + Source rank is based on global process group (regardless of ``group`` argument). + (If both ``src`` and ``group_src`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src`` + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`scatter_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`scatter` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. 
For example, on rank 2: + >>> output_list + [{1: 2}] + """ + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if ( + not isinstance(scatter_object_output_list, list) + or len(scatter_object_output_list) < 1 + ): + raise ValueError( + "Expected argument scatter_object_output_list to be a list of size at least 1." + ) + + my_group_rank = group.rank() + pg_device = _get_object_coll_device(group) + if my_group_rank == group_src: + if scatter_object_input_list is None: + raise ValueError( + "source rank must provide non-None scatter_object_input_list" + ) + tensor_list, tensor_sizes = zip( + *[ + _object_to_tensor(obj, pg_device, group) + for obj in scatter_object_input_list + ] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for tensor in tensor_list: # type: ignore[possibly-undefined] + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + broadcast(max_tensor_size, group_src=group_src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty( + max_tensor_size.item(), dtype=torch.uint8, device=pg_device + ) + scatter( + output_tensor, + scatter_list=None if my_group_rank != group_src else tensor_list, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + scatter( + obj_tensor_size, + scatter_list=None if my_group_rank != group_src else tensor_sizes, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) + + +@_exception_logger +def all_gather(tensor_list, tensor, group=None, async_op=False): + """ + Gathers tensors from the whole group in a list. + + Complex and uneven sized tensors are supported. + + Args: + tensor_list (list[Tensor]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> device = torch.device(f"cuda:{rank}") + >>> tensor_list = [ + ... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2) + ... 
] + >>> tensor_list + [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 + [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 + >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0 + [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1 + + >>> # All tensors below are of torch.cfloat dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [ + ... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2) + ... ] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 + [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 + >>> tensor = torch.tensor( + ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device + ... ) + 2 * rank * (1 + 1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0 + [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1 + + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). + relevant_args = (tensor,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_gather, + relevant_args, + tensor_list, + tensor, + group=group, + async_op=async_op, + ) + + _check_tensor_list(tensor_list, "tensor_list") + _check_single_tensor(tensor, "tensor") + _ensure_all_tensors_same_dtype(tensor_list, tensor) + if _rank_not_in_group(group): + _warn_not_in_group("all_gather") + return + + tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list + ] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + group = group or _get_default_group() + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather([tensor_list], [tensor], opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False): + """ + Gather tensors from all ranks and put them in a single output tensor. + + This function requires all tensors to be the same size on each process. + + Args: + output_tensor (Tensor): Output tensor to accommodate tensor elements + from all ranks. It must be correctly sized to have one of the + following forms: + (i) a concatenation of all the input tensors along the primary + dimension; for definition of "concatenation", see ``torch.cat()``; + (ii) a stack of all the input tensors along the primary dimension; + for definition of "stack", see ``torch.stack()``. + Examples below may better explain the supported output forms. + input_tensor (Tensor): Tensor to be gathered from current rank. 
+ Different from the ``all_gather`` API, the input tensors in this + API must have the same size across all ranks. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype and on CUDA devices. + >>> # We have two ranks. + >>> device = torch.device(f"cuda:{rank}") + >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor_in + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> # Output in concatenation form + >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([1, 2, 3, 4], device='cuda:0') # Rank 0 + tensor([1, 2, 3, 4], device='cuda:1') # Rank 1 + >>> # Output in stack form + >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out2, tensor_in) + >>> tensor_out2 + tensor([[1, 2], + [3, 4]], device='cuda:0') # Rank 0 + tensor([[1, 2], + [3, 4]], device='cuda:1') # Rank 1 + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). + relevant_args = (input_tensor,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_gather_into_tensor, + relevant_args, + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + _check_single_tensor(input_tensor, "input_tensor") + _check_single_tensor(output_tensor, "output_tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_into_tensor") + return + + output_tensor = ( + output_tensor + if not output_tensor.is_complex() + else torch.view_as_real(output_tensor) + ) + input_tensor = ( + input_tensor + if not input_tensor.is_complex() + else torch.view_as_real(input_tensor) + ) + + opts = AllgatherOptions() + opts.asyncOp = async_op + + group = group or _get_default_group() + + if group in _world.pg_coalesce_state.keys(): + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._allgather_base(output_tensor, input_tensor, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +@deprecated( + "`torch.distributed._all_gather_base` is a private function and will be deprecated. " + "Please use `torch.distributed.all_gather_into_tensor` instead.", + category=FutureWarning, +) +def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. 
+ input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. warning:: + `_all_gather_base` is a private function. Users should use + `all_gather_into_tensor` instead. + + """ + return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) + + +@_exception_logger +@deprecated( + "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, " + "please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_gather_coalesced( + output_tensor_lists, input_tensor_list, group=None, async_op=False +): + """ + Gathers input tensors from the whole group in a list in a coalesced manner. + + Complex tensors are supported. + + Args: + output_tensor_lists (list[list[Tensor]]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor_list (list[Tensor]): Tensors to be broadcast from + current process. At least one tensor has to be non empty. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Example: + we have 2 process groups, 2 ranks. + rank 0 passes: + input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + rank 1 passes: + input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + both rank 0 and 1 get: + output_tensor_lists = + [[[1, 1], [1, 1]], [2], [3, 3]], + [[3, 3], [3, 3]], [5], [1, 1]]]. + + WARNING: at this time individual shape checking is not implemented across nodes. + For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the + rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the + all_gather_coalesced operation will proceed without complaint and return + erroneous outputs. This lack of shape checking results in significant + performance improvements but users of this function should take extra care + to ensure that each node passes in tensors whose shapes match across nodes. + """ + # We only check basic compatibility with C++ params here, C++ code will + # do shape and type checking. 
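+    # The Python-side checks below only validate list structure and dtype
+    # consistency; cross-rank shape agreement is deliberately not verified
+    # (see the shape-checking warning in the docstring above).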
+ if _rank_not_in_group(group): + _warn_not_in_group("all_gather_coalesced") + return + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(input_tensor_list) + if not isinstance(output_tensor_lists, list): + raise TypeError( + "Invalid function argument: output_tensor_lists should be a list" + ) + for output_tensor_list in output_tensor_lists: + _check_tensor_list(output_tensor_list, "output_tensor_lists") + _ensure_all_tensors_same_dtype(output_tensor_list) + + output_tensor_lists = [ + [t if not t.is_complex() else torch.view_as_real(t) for t in l] + for l in output_tensor_lists + ] + input_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list + ] + + group = group or _get_default_group() + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts) + + if async_op: + return work.get_future() + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError( + "Argument ``gather_list`` must be specified on destination rank." + ) + elif gather_list: + raise ValueError( + "Argument ``gather_list`` must NOT be specified on non-destination ranks." + ) + + +@_exception_logger +def gather( + tensor: torch.Tensor, + gather_list: Optional[list[torch.Tensor]] = None, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: Optional[int] = None, +): + """ + Gathers a list of tensors in a single process. + + This function requires all tensors to be the same size on each process. + + Args: + tensor (Tensor): Input tensor. + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in gather_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> # We have 2 process groups, 2 ranks. + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> tensor = torch.ones(tensor_size, device=device) + rank + >>> if dist.get_rank() == 0: + >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)] + >>> else: + >>> gather_list = None + >>> dist.gather(tensor, gather_list, dst=0) + >>> # Rank 0 gets gathered data. + >>> gather_list + [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 + None # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. 
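+    # Normalize an omitted ``gather_list`` to an empty list so the dtype check
+    # below runs uniformly; whether the list must (or must not) be provided on
+    # this rank is enforced later by _validate_output_list_for_rank.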
+ if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + group = _group_or_default_group(group) + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + if dst is None and group_dst is None: + dst = 0 + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, gather_list) + output_tensors = [gather_list] if group_dst == my_group_rank else [] + input_tensors = [tensor] + + opts = GatherOptions() + opts.rootRank = group_dst + opts.asyncOp = async_op + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def scatter( + tensor: torch.Tensor, + scatter_list: Optional[list[torch.Tensor]] = None, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: Optional[int] = None, +): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Complex tensors are supported. + + Args: + tensor (Tensor): Output tensor. + scatter_list (list[Tensor]): List of tensors to scatter (default is + None, must be specified on the source rank) + src (int): Source rank on global process group (regardless of ``group`` argument). + (If both ``src`` and ``group_src`` are None, default is global rank 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src`` + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in scatter_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> output_tensor = torch.zeros(tensor_size, device=device) + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> # Only tensors, all of which must be the same size. + >>> t_ones = torch.ones(tensor_size, device=device) + >>> t_fives = torch.ones(tensor_size, device=device) * 5 + >>> scatter_list = [t_ones, t_fives] + >>> else: + >>> scatter_list = None + >>> dist.scatter(output_tensor, scatter_list, src=0) + >>> # Rank i gets scatter_list[i]. + >>> output_tensor + tensor([1., 1.], device='cuda:0') # Rank 0 + tensor([5., 5.], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. 
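+    # Normalize an omitted ``scatter_list`` to an empty list so the dtype check
+    # and the complex-to-real view conversion below can run on every rank.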
+ if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list + ] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + my_group_rank = group.rank() + if group_src == my_group_rank: + if not scatter_list: + raise ValueError( + "Argument ``scatter_list`` must be specified on source rank." + ) + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError( + "Argument ``scatter_list`` must NOT be specified on non-source ranks." + ) + input_tensors = [] + output_tensors = [tensor] + + opts = ScatterOptions() + opts.rootRank = group_src + opts.asyncOp = async_op + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Args: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + """ + _check_single_tensor(output, "output") + _check_tensor_list(input_list, "input_list") + _ensure_all_tensors_same_dtype(output, input_list) + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + opts.asyncOp = async_op + + group = group or _get_default_group() + work = group.reduce_scatter([output], [input_list], opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a tensor to all ranks in a group. + + Args: + output (Tensor): Output tensor. It should have the same size across all + ranks. + input (Tensor): Input tensor to be reduced and scattered. Its size + should be output tensor size times the world size. The input tensor + can have one of the following shapes: + (i) a concatenation of the output tensors along the primary + dimension, or + (ii) a stack of the output tensors along the primary dimension. + For definition of "concatenation", see ``torch.cat()``. + For definition of "stack", see ``torch.stack()``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. 
+ + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype and on CUDA devices. + >>> # We have two ranks. + >>> device = torch.device(f"cuda:{rank}") + >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) + >>> # Input in concatenation form + >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) + >>> tensor_in + tensor([0, 1, 2, 3], device='cuda:0') # Rank 0 + tensor([0, 1, 2, 3], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + >>> # Input in stack form + >>> tensor_in = torch.reshape(tensor_in, (world_size, 2)) + >>> tensor_in + tensor([[0, 1], + [2, 3]], device='cuda:0') # Rank 0 + tensor([[0, 1], + [2, 3]], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). + relevant_args = (input,) + if has_torch_function(relevant_args): + return handle_torch_function( + reduce_scatter_tensor, + relevant_args, + output, + input, + op=op, + group=group, + async_op=async_op, + ) + + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter_tensor") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + opts.asyncOp = async_op + + group = group or _get_default_group() + + # Check if we are in coalescing context + # If we are, do not issue single operation, just append a collective representation + if group in _world.pg_coalesce_state.keys(): + coll = _CollOp(reduce_scatter_tensor, input, output, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._reduce_scatter_base(output, input, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@deprecated( + "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. " + "Please use `torch.distributed.reduce_scatter_tensor` instead.", + category=FutureWarning, +) +def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a flattened tensor to all processes in a group. + + Args: + output (Tensor): Output tensor. + input (Tensor): Input tensor that is of size output tensor size times world size + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `_reduce_scatter_base` is a private function. Users should use + `reduce_scatter_tensor` instead. 
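# --- Editor's note: illustrative sketch, not part of the upstream file ---
# The two reduce-scatter variants above differ only in how the input is laid out:
# ``reduce_scatter`` takes a list of per-rank chunks, while ``reduce_scatter_tensor`` takes
# one tensor whose concatenated (or stacked) chunks play the same role. Assumes a backend
# that implements reduce-scatter (e.g. NCCL with CUDA tensors).
import torch
import torch.distributed as dist

def reduce_scatter_two_ways(input_list):
    # input_list: world_size tensors of identical shape, present on every rank.
    out_list_form = torch.empty_like(input_list[0])
    dist.reduce_scatter(out_list_form, list(input_list), op=dist.ReduceOp.SUM)

    out_tensor_form = torch.empty_like(input_list[0])
    dist.reduce_scatter_tensor(out_tensor_form, torch.cat(input_list), op=dist.ReduceOp.SUM)

    # Both calls should produce the same reduced shard on every rank.
    return torch.allclose(out_list_form.float(), out_tensor_form.float())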
+ + """ + return reduce_scatter_tensor(output, input, op, group, async_op) + + +@_exception_logger +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, +): + """ + Split input tensor and then scatter the split list to all processes in a group. + + Later the received tensors are concatenated from all the processes in the group + and returned as a single output tensor. + + Complex tensors are supported. + + Args: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all_single` is experimental and subject to change. + + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = torch.arange(4) + rank * 4 + >>> input + tensor([0, 1, 2, 3]) # Rank 0 + tensor([4, 5, 6, 7]) # Rank 1 + tensor([8, 9, 10, 11]) # Rank 2 + tensor([12, 13, 14, 15]) # Rank 3 + >>> output = torch.empty([4], dtype=torch.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([0, 4, 8, 12]) # Rank 0 + tensor([1, 5, 9, 13]) # Rank 1 + tensor([2, 6, 10, 14]) # Rank 2 + tensor([3, 7, 11, 15]) # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = list(input.chunk(world_size)) + >>> gather_list = list(output.chunk(world_size)) + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) + + >>> # Another example with uneven split + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> output = ... + >>> dist.all_to_all_single(output, input, output_splits, input_splits) + >>> output + tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0 + tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1 + tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2 + tensor([ 5, 17, 18, 24, 36]) # Rank 3 + + + >>> # Another example with tensors of torch.cfloat type. + >>> input = torch.tensor( + ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat + ... 
) + 4 * rank * (1 + 1j) + >>> input + tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 + tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 + tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2 + tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3 + >>> output = torch.empty([4], dtype=torch.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0 + tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1 + tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2 + tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3 + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). + relevant_args = (input,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_to_all_single, + relevant_args, + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all_single") + return + + opts = AllToAllOptions() + opts.asyncOp = async_op + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + _ensure_all_tensors_same_dtype(output, input) + + if input.is_complex(): + input = torch.view_as_real(input) + if output.is_complex(): + output = torch.view_as_real(output) + + output_split_sizes = [] if output_split_sizes is None else output_split_sizes + input_split_sizes = [] if input_split_sizes is None else input_split_sizes + + group = group or _get_default_group() + work = group.alltoall_base( + output, input, output_split_sizes, input_split_sizes, opts + ) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): + """ + Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Complex tensors are supported. + + Args: + output_tensor_list (list[Tensor]): List of tensors to be gathered one + per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all` is experimental and subject to change. 
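# --- Editor's note: illustrative sketch, not part of the upstream file ---
# In the uneven-split example above, every rank must know both its input and its output
# split sizes. When only the send sizes are known locally, a common pattern is to exchange
# the counts first with a fixed-size ``all_to_all_single`` and then run the variable-size
# exchange. Assumes an initialized default group and 1-D payloads on devices the backend
# supports (e.g. CPU tensors with Gloo).
import torch
import torch.distributed as dist

def uneven_all_to_all(send_buf, input_split_sizes):
    world = dist.get_world_size()
    in_counts = torch.tensor(input_split_sizes, dtype=torch.int64)
    out_counts = torch.empty(world, dtype=torch.int64)
    dist.all_to_all_single(out_counts, in_counts)  # one count per peer, equal splits
    recv_buf = torch.empty(int(out_counts.sum()), dtype=send_buf.dtype)
    dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=out_counts.tolist(),
        input_split_sizes=list(input_split_sizes),
    )
    return recv_buf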
+ + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = torch.arange(4) + rank * 4 + >>> input = list(input.chunk(4)) + >>> input + [tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0 + [tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1 + [tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2 + [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3 + >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0 + [tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1 + [tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2 + [tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = input + >>> gather_list = output + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) + + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> input = list(input.split(input_splits)) + >>> input + [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0 + [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1 + [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2 + [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3 + >>> output = ... + >>> dist.all_to_all(output, input) + >>> output + [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 + [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 + [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 + [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 + + >>> # Another example with tensors of torch.cfloat type. + >>> input = torch.tensor( + ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat + ... 
) + 4 * rank * (1 + 1j) + >>> input = list(input.chunk(4)) + >>> input + [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 + [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1 + [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2 + [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3 + >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0 + [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1 + [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2 + [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3 + + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all") + return + + opts = AllToAllOptions() + opts.asyncOp = async_op + _check_tensor_list(output_tensor_list, "output_tensor_list") + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) + + input_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list + ] + output_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list + ] + + group = group or _get_default_group() + work = group.alltoall(output_tensor_list, input_tensor_list, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def barrier( + group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None +): + """ + Synchronize all processes. + + This collective blocks processes until the whole group enters this function, + if async_op is False, or if async work handle is called on wait(). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + device_ids ([int], optional): List of device/GPU ids. Only one id is expected. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective. + """ + group = group or _get_default_group() + + if _rank_not_in_group(group): + _warn_not_in_group("barrier") + return + + opts = BarrierOptions() + opts.asyncOp = async_op + # Detect the accelerator on the machine. If no accelerator is available, it + # returns CPU. + device = torch._C._get_accelerator() + if isinstance(device_ids, list): + opts.device_ids = device_ids + # use only the first device id + opts.device = torch.device(device.type, device_ids[0]) + elif getattr(group, "bound_device_id", None) is not None: + # Use device id from `init_process_group(device_id=...)` + opts.device = group.bound_device_id # type: ignore[assignment] + elif device.type == "cpu" or _get_object_coll_device(group) == "cpu": + opts.device = torch.device("cpu") + else: + # Use the current device set by the user. If user did not set any, this + # may use default device 0, causing issues like hang or all processes + # creating context on device 0. 
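# --- Editor's note: illustrative sketch, not part of the upstream file ---
# The comment above describes the pitfall this sketch avoids: on a multi-GPU node, pin the
# barrier to the caller's device (via ``device_ids`` here, or by binding a device through
# ``init_process_group(device_id=...)``) so NCCL does not create a context on device 0.
# ``local_rank`` is assumed to come from the launcher (e.g. the LOCAL_RANK env var).
import torch
import torch.distributed as dist

def barrier_on_local_device(local_rank: int):
    torch.cuda.set_device(local_rank)       # make the current device explicit
    dist.barrier(device_ids=[local_rank])   # barrier pinned to this rank's GPU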
+ opts.device = device + warnings.warn( # warn only once + "No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. " + ) + + work = group.barrier(opts=opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def monitored_barrier( + group: Optional[ProcessGroup] = GroupMember.WORLD, + timeout=None, + wait_all_ranks=False, +): + """ + Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout. + + It is able to report ranks that did not pass this barrier within the provided timeout. + Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. + Rank 0 will block until all send /recv from other ranks are processed, and will report + failures for ranks that failed to respond in time. Note that if one rank does not reach the + monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. + + This collective will block all processes/ranks in the group, until the + whole group exits the function successfully, making it useful for debugging + and synchronizing. However, it can have a performance impact and should only + be used for debugging or scenarios that require full synchronization points + on the host-side. For debugging purposes, this barrier can be inserted + before the application's collective calls to check if any ranks are + desynchronized. + + .. note:: Note that this collective is only supported with the GLOO backend. + + Args: + group (ProcessGroup, optional): The process group to work on. If + ``None``, the default process group will be used. + timeout (datetime.timedelta, optional): Timeout for monitored_barrier. + If ``None``, the default process group timeout will be used. + wait_all_ranks (bool, optional): Whether to collect all failed ranks or + not. By default, this is ``False`` and ``monitored_barrier`` on rank 0 + will throw on the first failed rank it encounters in order to fail + fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will + collect all failed ranks and throw an error containing information + about all failed ranks. + + Returns: + ``None``. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() != 1: + >>> dist.monitored_barrier() # Raises exception indicating that + >>> # rank 1 did not call into monitored_barrier. + >>> # Example with wait_all_ranks=True + >>> if dist.get_rank() == 0: + >>> dist.monitored_barrier(wait_all_ranks=True) # Raises exception + >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into + >>> # monitored_barrier. + """ + # Need to call rank not in group before using the group, otherwise + # "Invalid process group" error is raised. + if _rank_not_in_group(group): + _warn_not_in_group("monitored_barrier") + return + + if get_backend(group) != Backend.GLOO: + raise ValueError("monitored_barrier is only implemented for GLOO backend.") + + if timeout is None: + timeout = _get_default_timeout(get_backend(group)) + elif isinstance(timeout, float): + # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format? + warnings.warn( + "Please specify timeout arg as a timedelta. 
" + f"Converting current value of {timeout} assuming it represents seconds", + ) + timeout = timedelta(seconds=timeout) + + _check_valid_timeout(timeout) + + group_to_use = _get_default_group() if group is None else group + return group_to_use.monitored_barrier( # type:ignore[attr-defined] + timeout, wait_all_ranks=wait_all_ranks + ) + + +def _create_process_group_wrapper( + wrapped_pg: torch._C._distributed_c10d.Backend, + store_prefix: str, + store: Store, + rank: int, + world_size: int, + timeout: timedelta = default_pg_timeout, +): + assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend." + + # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... + + # Create a separate prefix store for the helper process group. + prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" + store = PrefixStore(prefix, store) + helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout) + # Wrap the underlying pg with ProcessGroupWrapper. + wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) + return wrapped_pg + + +# helper function for deterministically hashing a list of ranks to a unique +# string +def _hash_ranks_to_str(ranks: list[int]) -> str: + rank_join: str = "_".join(map(str, ranks)) + # In case there is already a PG with the same rank composition + unique_str = "_".join([rank_join, str(len(_world.pg_names))]) + return hashlib.sha1(bytes(unique_str, "utf-8"), usedforsecurity=False).hexdigest() + + +# Takes a list of ranks and computes an integer color +def _process_group_color(ranks: list[int]) -> int: + # Convert list to tuple to make it hashable + ranks = tuple(ranks) + hash_value = hash(ranks) + # Split color must be: + # - a non-negative integer; + # - a type compatible with C's int because we are pybinding to the latter. + # Thus, we limit the hash value within c_int's max value. + max_c_int = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) + color = abs(hash_value) % max_c_int + return color + + +def _process_group_name(ranks, use_hashed_name): + # Create name for a process group. + global _world + if use_hashed_name: + pg_name = _hash_ranks_to_str(ranks) + else: + pg_name = str(_world.group_count) + _world.group_count += 1 + # TODO: why is group count incremented only in the else path? + return pg_name + + +def _get_backend_from_str(backend: Optional[str] = None) -> Backend: + # Default to the same backend as the global process group + # if backend is not specified. + if not backend: + backend = get_backend(_get_default_group()) + return Backend(backend) + + +def _is_safe_to_split() -> bool: + """ + Checks if it is safe to split the any process group in the world. + This is only safe if the default pg has a bound device id, otherwise + users must be aware that a pg is only splittable after the first collective is + issued. + """ + return False if _get_default_group().bound_device_id is None else True + + +@_time_logger +def split_group( + parent_pg: Optional[ProcessGroup] = None, + split_ranks: Optional[list] = None, + timeout: Optional[timedelta] = None, + pg_options: Optional[Any] = None, + group_desc: Optional[str] = None, +) -> Optional[ProcessGroup]: + """ + Create a new process group split from the given parent process group. + + warning:: This is an experimental API. Only the ``NCCL`` and custom plugin backends + are supported. Other backends will raise an error. 
+ Users of this API must guarantee that all ranks in the parent group enter this API call, + and the split of the sub groups is the same across all ranks in the parent group. + + Args: + parent_pg (ProcessGroup, optional): The parent process group. If None, + the default process group will be used. Users need to guarantee that + the parent group is fully initialized (e.g, communicators are initialized) + split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks. + Users need to make sure the validity of the split ranks such that one + split (represented by one inner list of ints) does not overlap with any other split. + Note that the ranks in each split is the group rank (instead of global rank) + in the parent pg. For example, if the parent group has 4 ranks, and split_ranks can be + [[0, 1], [2, 3]]. Note [[0,1]] is also a valid split, in which case ranks 2, 3 would + return a non-group member. + timeout (timedelta, optional): see `init_process_group` for details and default value. + pg_options (ProcessGroupOptions, optional): Additional options need to be passed in during + the construction of specific process groups. i.e.``is_high_priority_stream`` + can be specified so that process group can pick up high priority cuda streams. + group_desc (str, optional): a string to describe the process group. + + Returns: + ProcessGroup if the current rank is within one split/subgroup given by split_ranks, + or None if the current rank is not part of any split_ranks`. + + """ + # check inputs + if split_ranks is None: + raise ValueError("split_ranks cannot be None") + + global _world + default_pg = _get_default_group() + device_id = default_pg.bound_device_id + if not device_id: + raise RuntimeError( + "No device associated with the default pg, not safe to split any process groups" + ) + _default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + if not parent_pg: + parent_pg = default_pg + if parent_pg not in _world.pg_group_ranks: + raise ValueError(f"Group {parent_pg} is not registered") + + parent_global_to_group_ranks = _world.pg_group_ranks[parent_pg] + parent_group_to_global_ranks = { + group_rank: global_rank + for global_rank, group_rank in parent_global_to_group_ranks.items() + } + + if global_rank not in parent_global_to_group_ranks: + raise ValueError( + f"Global rank {global_rank} is not part of the parent group {parent_pg}" + ) + + parent_group_rank = parent_global_to_group_ranks[global_rank] + parent_backend = parent_pg._get_backend(torch.device("cuda")) + + # if the parent backend does not support splitting, raise error + # currently this API only support NCCL backend + if not parent_backend or not parent_backend.supports_splitting: + raise RuntimeError( + "No backend for the parent process group or its backend does not support splitting" + ) + + # set the group_desc before the color or no_cloor split + group_desc = ( + f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" # type: ignore[attr-defined] + if group_desc is None + else group_desc + ) + + parent_backend_str, _ = _world.pg_map[parent_pg] + # same type of backend as the parent process group + backend = Backend(parent_backend_str) + backend_config = BackendConfig(backend) + + if pg_options is None: + # default pg_options same as the parent process group + pg_options = parent_backend.options + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass 
their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + # find my group of ranks and my group local rank in split_ranks + my_group = None + group_rank = -1 + + for split_group in split_ranks: + if len(split_group) == 0: + raise ValueError("the split group cannot be empty") + if len(split_group) > global_world_size: + raise ValueError( + "the split group's size should be less or equal to the world_size set by init_process_group" + ) + if len(split_group) != len(set(split_group)): + raise ValueError("the split group cannot have duplicate ranks") + split_group = sorted(split_group) + if parent_group_rank in split_group: + my_group = split_group + group_rank = split_group.index(parent_group_rank) + break + # if my rank does not belong to any sub group, + # no_color split should be called + if my_group is None or group_rank == -1: + parent_backend.perform_nocolor_split(device_id) # type: ignore[attr-defined] + return None + + group_name = _process_group_name(my_group, use_hashed_name=False) + global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] + + prefix_store = PrefixStore(f"{group_name}/", default_store) + # We register the backend after initializing and timeout is set in pg_options. + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + len(my_group), + ) + pg.bound_device_id = device_id # type: ignore[union-attr] + pg_options._timeout = timeout # type: ignore[union-attr] + pg_options.split_from = parent_backend # type: ignore[union-attr] + pg_options.split_color = _process_group_color(my_group) # type: ignore[union-attr] + pg_options.global_ranks_in_group = global_ranks_in_my_group # type: ignore[union-attr] + pg_options.group_name = group_name # type: ignore[union-attr] + + if parent_backend_str == Backend.NCCL: + backend_type = ProcessGroup.BackendType.NCCL + if not isinstance(pg_options, ProcessGroupNCCL.Options): + raise RuntimeError( + "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + ) + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, len(my_group), pg_options + ) + else: + assert parent_backend_str.upper() in Backend._plugins, ( + f"Unknown c10d backend type {parent_backend_str.upper()}" + ) + backend_plugin = Backend._plugins[parent_backend_str.upper()] + creator_fn = backend_plugin.creator_fn + extended_api = backend_plugin.extended_api + backend_type = ProcessGroup.BackendType.CUSTOM + if not extended_api: + backend_class = creator_fn(prefix_store, group_rank, len(my_group), timeout) + else: + dist_backend_opts = _DistributedBackendOptions() + dist_backend_opts.store = prefix_store + dist_backend_opts.group_rank = group_rank + dist_backend_opts.group_size = len(my_group) + backend_class = creator_fn(dist_backend_opts, pg_options) + + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(torch.device("cuda"), backend_type, backend_class) + + # set group_name and group_desc to backend + assert group_name is not None + assert group_desc is not None + pg._set_group_name(group_name) + pg._set_group_desc(group_desc) + + # always eagerly initialize the backend in split_group + eager_backend = pg._get_backend(device_id) + eager_backend.eager_connect_single_device(device_id) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + _register_process_group(group_name, pg) + _world.pg_backend_config[pg] = str(backend_config) + 
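# --- Editor's note: illustrative sketch, not part of the upstream file ---
# A minimal end-to-end use of ``split_group`` as implemented above. It assumes 4 ranks, an
# NCCL build, and that ``init_process_group`` was given ``device_id`` so the default group
# has a bound device (the precondition checked at the top of ``split_group``). Ranks that
# fall outside every split receive ``None``.
import os
import torch
import torch.distributed as dist

def init_and_split_into_pairs():
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group("nccl", device_id=torch.device(f"cuda:{local_rank}"))
    pairs_pg = dist.split_group(split_ranks=[[0, 1], [2, 3]], group_desc="pairs")
    return pairs_pg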
pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank + for group_rank, global_rank in enumerate(global_ranks_in_my_group) + } + + return pg + + +@_time_logger +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, + device_id: Optional[torch.device] = None, +): + """ + Create a new distributed group. + + This function requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. Additionally, groups + should be created in the same order in all processes. + + .. warning:: + Safe concurrent usage: + When using multiple process groups with the ``NCCL`` backend, the user + must ensure a globally consistent execution order of collectives across + ranks. + + If multiple threads within a process issue collectives, explicit + synchronization is necessary to ensure consistent ordering. + + When using async variants of torch.distributed communication APIs, + a work object is returned and the communication kernel is + enqueued on a separate CUDA stream, allowing overlap of communication + and computation. Once one or more async ops have been issued on one process + group, they must be synchronized with other cuda streams by calling `work.wait()` + before using another process group. + + See `Using multiple NCCL communicators concurrently + ` + for more details. + + Args: + ranks (list[int]): List of ranks of group members. If ``None``, will be + set to all ranks. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. For other available options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-tuse_local_synchronization + (bool, optional): perform a group-local barrier at the end of the process group creation. + This is different in that non-member ranks don't need to call into API and don't + join the barrier. + group_desc (str, optional): a string to describe the process group. + device_id (torch.device, optional): a single, specific device + to "bind" this process to, The `new_group` call will try to initialize + a communication backend immediately for the device if this field is given. + + Returns: + A handle of distributed group that can be given to collective calls or + GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``. + + N.B. use_local_synchronization doesn't work with MPI. + + N.B. 
While use_local_synchronization=True can be significantly faster with larger + clusters and small process groups, care must be taken since it changes cluster behavior + as non-member ranks don't join the group barrier(). + + N.B. use_local_synchronization=True can lead to deadlocks when each rank creates + multiple overlapping process groups. To avoid that, make sure all ranks follow the + same global creation order. + """ + return _new_group_with_tag( + ranks, + timeout, + backend, + pg_options, + None, + use_local_synchronization=use_local_synchronization, + group_desc=group_desc, + device_id=device_id, + ) + + +def _new_group_with_tag( + ranks=None, + timeout=None, + backend=None, + backend_options=None, + pg_tag=None, + use_local_synchronization=False, + group_desc=None, + device_id: Optional[torch.device] = None, +): + """ + Variant of ``new_group`` that exposes tag creation. + + :: N.B. The mechanism is experimental and tied to the functional collectives effort, see + ``torch.distributed._functional_collectives`` for reference on how to use it. + """ + global _world + + default_pg = _get_default_group() + if device_id is None: + device_id = default_pg.bound_device_id + elif default_pg.bound_device_id is not None: + assert device_id == default_pg.bound_device_id, ( + "Mismatched bound device between new pg and the default pg." + ) + default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + # Default to the same backend as the global process group + # if the backend is not specified. + if not backend: + backend = default_backend + backend = Backend(backend) + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + if use_local_synchronization: + # MPI backend doesn't have have a way for us to perform a partial sync + if backend == Backend.MPI: + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) + if ranks is not None and get_rank() not in ranks: + return None + + # checks the input ranks + if ranks is not None: + ranks = sorted(ranks) + group_world_size = len(ranks) + if group_world_size > global_world_size: + raise ValueError( + "the new group's world size should be less or " + "equal to the world size set by " + "init_process_group" + ) + # check ranks' sanity + for rank in ranks: + if rank < 0 or rank >= global_world_size: + raise ValueError( + "The new group's rank should be within " + "the world_size set by init_process_group" + ) + if global_rank in ranks: + group_rank = ranks.index(global_rank) + else: + group_rank = None + else: + ranks = list(range(global_world_size)) + group_world_size = global_world_size + group_rank = global_rank + + group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization) + + pg, pg_store = _new_process_group_helper( + group_world_size, + group_rank, + ranks, + backend, + default_store, + group_name, + backend_options=backend_options, + timeout=timeout, + pg_tag=pg_tag, + device_id=device_id, + group_desc=group_desc, + ) + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank for group_rank, global_rank in enumerate(ranks) + } + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global 
variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.info( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI doesn't have store. + barrier() + else: + barrier_store = pg_store if use_local_synchronization else default_store + world_size = len(ranks) if use_local_synchronization else get_world_size() + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) + + return pg + + +def new_subgroups( + group_size=None, + group=None, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups of equal size. + + By default, it creates intra-machine subgroups, + where each of which contains all the ranks of a machine, based on the assumption + that each machine has the same number of devices. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + If ``group_size`` is passed in, the world size must be divisible by ``group_size``. + If no ``group_size`` is passed in, it believe that you are creating a group based + on CUDA and determining the group size by number of CUDA devices, and if not all + the machines have the same number of devices, the subgroup division will be + different across nodes and can cause unexpected behaviors. Therefore, if you are + creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please + pass in ``group_size`` correctly. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + group_size (int, optional): The size of each subgroup. If ``None``, + the default subgroup size is equal to the number of devices on each machine, + based on the assumption that each machine has exactly the same + number of devices. Default is ``None``. + group (ProcessGroup, optional): The process group to work on. If + ``None``, the default process group will be used. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. 
for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create intra-machine subgroups. + >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups() + >>> # Allreduce within the machine. + >>> rank = dist.get_rank() + >>> tensor = torch.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([28]) # Assume 8 CUDA devices per machine. 28 is sum(range(8)). + >>> # Cleanup. + >>> for subgroup in subgroups: + >>> dist.destroy_process_group(subgroup) + """ + if group_size is None: + if not torch.cuda.is_available(): + raise ValueError( + "Default group size only takes effect when CUDA is available." + "If your subgroup using a backend that does not depend on CUDA," + "please pass in 'group_size' correctly." + ) + group_size = torch.cuda.device_count() + if group_size <= 0: + raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") + + world_size = get_world_size(group=group) + if world_size < group_size: + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) + if world_size % group_size != 0: + raise ValueError( + f"The world size ({world_size}) must be divisible by '{group_size=}'" + ) + + # TODO: Use itertools.batched(get_process_group_ranks(group=group), group_size) instead when Python 3.12 is supported. + ranks = get_process_group_ranks(group=group) + ranks_per_subgroup_list = [ + ranks[i : i + group_size] for i in range(0, len(ranks), group_size) + ] + return new_subgroups_by_enumeration( + ranks_per_subgroup_list, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + + +def new_subgroups_by_enumeration( + ranks_per_subgroup_list, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups by dividing the global world. + + The division is specified by a nested list of ranks. The subgroups cannot have + overlap, and some ranks may not have to be in any subgroup. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of + group members. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. 
+ pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc. + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create two subgroups, where each has 2 processes. + >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups(ranks=[[0, 2], [1, 3]]) + >>> rank = dist.get_rank() + >>> tensor = torch.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([2]) # Subgroup 0: ranks 0 and 2 + tensor([4]) # Subgroup 1: ranks 1 and 3 + """ + if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0: + raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty") + + subgroups = [] + cur_subgroup = None + # Create a mapping from rank to subgroup to check if there is any subgroup overlap. + rank_to_ranks_dict = {} # type: ignore[var-annotated] + for ranks in ranks_per_subgroup_list: + subgroup = new_group( + ranks=ranks, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + subgroups.append(subgroup) + my_rank = get_rank() + for rank in ranks: + if rank in rank_to_ranks_dict: + raise ValueError( + f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}" + ) + rank_to_ranks_dict[rank] = ranks + if my_rank == rank: + cur_subgroup = subgroup + logger.info("Rank %s is assigned to subgroup %s", rank, ranks) + + return cur_subgroup, subgroups + + +def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]: + if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"): + tag = f"user:{tag}" + + for group in _world.tags_to_pg.get(tag, []): + if group.size() != len(ranks): + continue + + group_ranks = get_process_group_ranks(group) + good = all(r in group_ranks for r in ranks) + if good: + return group + return None + + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: list[int], stride: int +) -> ProcessGroup: + assert len(ranks) % stride == 0, ( + f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + ) + + my_rank = get_rank() + my_ranks = None + + if stride == len(ranks): + my_ranks = ranks.copy() + assert my_rank in my_ranks, "rankset doesn't include the current node" + else: + for i in range(0, len(ranks), stride): + rank_set = ranks[i : i + stride] + if my_rank in rank_set: + my_ranks = rank_set + assert my_ranks is not None, "rankset doesn't include the current node" + + my_ranks = sorted(my_ranks) + + pg = _find_pg_by_ranks_and_tag(tag, my_ranks) + if pg is not None: + return pg + if tag == "": + raise ValueError("Cannot automatically create PG with empty tag") + # TODO copy settings and timeout from default PG + return _new_group_with_tag(my_ranks, pg_tag=tag) + + +def _get_group_tag(pg: ProcessGroup) -> str: + """Return the tag associated with ``pg``.""" + tag = _world.pg_to_tag[pg] + tag = tag.removeprefix("user:") + return tag + + +def _get_process_group_name(pg: ProcessGroup) -> str: + return _world.pg_names.get(pg, "None") + + +def _get_process_group_store(pg: ProcessGroup) -> Store: + return _world.pg_map[pg][1] diff --git 
a/phivenv/Lib/site-packages/torch/distributed/elastic/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..427e1745c4a2631cd006e0c856c248d7e2968c11
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/distributed/elastic/__init__.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+
+Torchelastic agent and user worker failover contract:
+
+**TL;DR;**:
+
+* TE (torchelastic) expects user workers to finish within a 5 minute drift
+* It is better to design a DDP app to fail on all workers, rather than on a single one.
+* TE does not synchronize the number of restarts between agents
+* TE re-rendezvous does not trigger a restart decrease
+* When a single agent finishes its job (successfully or not), it will close rendezvous.
+  If other agents still have workers in progress, they will be terminated.
+* Based on the above, scale down does not work if at least one agent finishes the job.
+* When a scale up is detected by agents, they will not decrease ``max_restarts``
+
+
+In general TE (torchelastic) can launch arbitrary user code, but some
+clarification is needed around what failover mechanism torchelastic
+provides and what failover mechanism it expects from user workers.
+
+Torchelastic currently supports DDP style applications. That means that
+TE expects *ALL* workers to finish approximately at the same time. In practice,
+it is nearly impossible to guarantee that all workers in an arbitrary
+DDP application finish at the same time, so TE provides a finalization barrier
+that waits for TIMEOUT (5 minutes) for worker finalization.
+
+**Worker Failure**
+
+When a worker fails, TE will check the number of restarts
+available; if more than 0 restarts remain, TE will start a new rendezvous
+round and restart the worker process. The new rendezvous round will cause other
+TE agents to terminate their workers.
+
+.. note:: TE agents do not synchronize restarts among themselves.
+   When a single agent performs a restart, it will trigger a local ``max_restarts``
+   decrease; other agents will not decrease their ``max_restarts``.
+
+A single worker failure can cause the whole cluster to fail:
+If a single worker is constantly failing, it will cause the TE agent
+``max_restarts`` to go to zero. This will cause an agent to finish its
+work and close rendezvous. If there are any other workers on different
+agents, they will be terminated.
+
+
+**Re-Rendezvous**
+
+Re-rendezvous occurs when TE agents detect a new node
+trying to join a cluster. TE will not decrease ``max_restarts``. TE agents
+will terminate their workers and start a new rendezvous round.
+
+Note about DynamicRendezvous (etcd-v2, c10d-experimental): If the rendezvous
+already has max_nodes, the new node won't be added to the wait list right
+away since there is no need to tear down a rendezvous that is already fully
+utilized. The new node will wait until its timeout (600 secs by default)
+and periodically check the number of participants. If the number becomes
+less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs.
+
+*Scale up event*. When a scale up event happens, torchelastic rendezvous
+will detect that there are new nodes trying to join.
Torchelastic agent +will stop all workers and perform re-rendezvous. Note: when scale up event +happens, *``max_restarts``* will *not* decrease. + +*Scale down event*. When scale down event happens, rendezvous will not +notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` , +it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` , +TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*. + +""" diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..114522f41bd97ef69dd3dc0589d1ee035f610de5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21d5992f195597fd71aa809fe1b2dcedb4d6a8ef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +The elastic agent is the control plane of torchelastic. + +It is a process that launches and manages underlying worker processes. +The agent is responsible for: + +1. Working with distributed torch: the workers are started with all the + necessary information to successfully and trivially call + ``torch.distributed.init_process_group()``. + +2. Fault tolerance: monitors workers and upon detecting worker failures + or unhealthiness, tears down all workers and restarts everyone. + +3. Elasticity: Reacts to membership changes and restarts workers with the new + members. + +The simplest agents are deployed per node and works with local processes. +A more advanced agent can launch and manage workers remotely. Agents can +be completely decentralized, making decisions based on the workers it manages. +Or can be coordinated, communicating to other agents (that manage workers +in the same job) to make a collective decision. 
+""" + +from .api import ( # noqa: F401 + ElasticAgent, + RunResult, + SimpleElasticAgent, + Worker, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e679a1ef5c29920bcae1de38b20842ae342f4866 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bab66221a30a7ff37b43a51c9a542a5cee7dc9a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d217d9099818ef026c1a02cb1831a0f67f369d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c05dbfdc9c2bbd295f9ce90e325454885daf3bda Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/api.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f8f61ddd4d36460a7367e631c654cc3290f54e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/api.py @@ -0,0 +1,966 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import abc +import json +import os +import signal +import socket +import time +import traceback +import warnings +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Optional, Union + +import torch.distributed.elastic.rendezvous as rdzv +import torch.distributed.elastic.utils.store as store_util +from torch.distributed.elastic.events import Event, EventSource, record +from torch.distributed.elastic.metrics import prof, put_metric +from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException +from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError +from torch.distributed.elastic.utils.logging import get_logger + + +__all__ = [ + "WorkerSpec", + "Worker", + "WorkerState", + "WorkerGroup", + "RunResult", + "ElasticAgent", + "SimpleElasticAgent", +] +_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" + +DEFAULT_ROLE = "default" +logger = get_logger(__name__) + + +@dataclass +class WorkerSpec: + """Blueprint information about a particular type of worker. + + For a given role, there must only exist a single worker spec. + Worker spec is expected to be homogeneous across all nodes (machine), + that is each node runs the same number of workers for a particular spec. + + Args: + role: user-defined role for the workers with this spec + local_world_size: number local workers to run + fn: (deprecated use entrypoint instead) + entrypoint: worker function or command + args: arguments to pass to ``entrypoint`` + rdzv_handler: handles rdzv for this set of workers + max_restarts: number of max retries for the workers + monitor_interval: monitor status of workers every ``n`` seconds + master_port: fixed port to run the c10d store on rank 0 + if not specified then will chose a random free port + master_addr: fixed master_addr to run the c10d store on rank 0 + if not specified then will chose hostname on agent rank 0 + redirects: redirect std streams to a file, + selectively redirect for a particular + local rank by passing a map + tee: tees the specified std stream(s) to console + file, + selectively tee for a particular local rank by passing a map, + takes precedence over ``redirects`` settings. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + """ + + role: str + local_world_size: int + rdzv_handler: rdzv.RendezvousHandler + fn: Optional[Callable] = None + # TODO @kiuk - make entrypoint a required field + entrypoint: Union[Callable, str, None] = None + args: tuple = () + max_restarts: int = 3 + monitor_interval: float = 0.1 + master_port: Optional[int] = None + master_addr: Optional[str] = None + local_addr: Optional[str] = None + event_log_handler: str = "null" + + def __post_init__(self): + assert self.local_world_size > 0 + assert self.monitor_interval > 0 + + if self.fn: + warnings.warn( + "WorkerSpec.fn will be deprecated," + " please use WorkerSpec.entrypoint instead", + category=DeprecationWarning, + ) + self.entrypoint = self.fn + assert self.entrypoint + + def get_entrypoint_name(self): + """Get the entry point name. + + If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__`` + else if the entrypoint is a binary (e.g. ``str``), returns the binary name. 
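For example (an illustrative, hedged sketch; ``my_rdzv_handler`` is assumed to
be a ``RendezvousHandler`` built elsewhere)::

    spec = WorkerSpec(
        role="trainer",
        local_world_size=2,
        rdzv_handler=my_rdzv_handler,
        entrypoint="/usr/local/bin/trainer",
    )
    spec.get_entrypoint_name()  # -> "trainer" (basename of the binary)

A ``Callable`` entrypoint would instead report its ``__qualname__``.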
+ """ + if isinstance(self.entrypoint, str): + return os.path.basename(self.entrypoint) + else: + assert self.entrypoint is not None + return self.entrypoint.__qualname__ + + +class Worker: + """A worker instance. + + Contrast this with ``WorkerSpec`` that represents the specifications of a + worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to + a ``WorkerSpec`` as an object is to a class. + + The ``id`` of the worker is interpreted + by the specific implementation of ``ElasticAgent``. For a local + agent, it could be the ``pid (int)`` of the worker, for a remote + agent it could be encoded as ``host:port (string)``. + + Args: + id (Any): uniquely identifies a worker (interpreted by the agent) + local_rank (int): local rank of the worker + global_rank (int): global rank of the worker + role_rank (int): rank of the worker across all workers that have the same role + world_size (int): number of workers (globally) + role_world_size (int): number of workers that have the same role + """ + + __slots__ = [ + "id", + "local_rank", + "global_rank", + "role_rank", + "world_size", + "role_world_size", + ] + + def __init__( + self, + local_rank: int, + global_rank: int = -1, + role_rank: int = -1, + world_size: int = -1, + role_world_size: int = -1, + ): + # unique identifier for this worker + self.id: Any = None + + # rank of the worker among workers with the same role being monitored + # by the same ``agent`` instance. + self.local_rank: int = local_rank + + # rank of the worker among all the workers across all roles + # across all ``agent`` instances. + # Global rank is not stable between re-rendezvous. + self.global_rank: int = global_rank + + # rank of the worker among all the workers with the same role + # across all ``agent`` instances. + # Role rank is not stable between re-rendezvous. + self.role_rank: int = role_rank + + # total number of workers (globally). Due to elasticity + # the world size may change between re-rendezvous. + self.world_size: int = world_size + + # total number of workers that share the same role. Due to elasticity + # the role world size may change between re-rendezvous. + self.role_world_size: int = role_world_size + + def __str__(self): + return ( + f"local_rank={self.local_rank},global_rank={self.global_rank}" + f",role_rank={self.role_rank},world_size={self.world_size}" + f",role_world_size={self.role_world_size}" + ) + + def __repr__(self): + return str(self) + + +class WorkerState(str, Enum): + """A state of the ``WorkerGroup``. + + Workers in a worker group change state as a unit. If a single worker + in a worker group fails the entire set is considered failed:: + + UNKNOWN - agent lost track of worker group state, unrecoverable + INIT - worker group object created not yet started + HEALTHY - workers running and healthy + UNHEALTHY - workers running and unhealthy + STOPPED - workers stopped (interrupted) by the agent + SUCCEEDED - workers finished running (exit 0) + FAILED - workers failed to successfully finish (exit !0) + + + A worker group starts from an initial ``INIT`` state, + then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, + and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. + + Worker groups can be interrupted and temporarily put into ``STOPPED`` state + by the agent. Workers in ``STOPPED`` state are scheduled to be restarted + in the near future by the agent. Some examples of workers being put into + ``STOPPED`` state are: + + 1. Worker group failure|unhealthy observed + 2. 
Membership change detected + + When actions (start, stop, rdzv, retry, etc) on worker group fails + and results in the action being partially applied to the worker group + the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled + exceptions during state change events on the agent. The agent is not + expected to recover worker groups in ``UNKNOWN`` state and is better off + self terminating and allowing the job manager to retry the node. + """ + + UNKNOWN = "UNKNOWN" + INIT = "INIT" + HEALTHY = "HEALTHY" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + @staticmethod + def is_running(state: "WorkerState") -> bool: + """Return the state of the Worker. + + Returns: + True if the worker state represents workers still running + (e.g. that the process exists but not necessarily healthy). + """ + return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY} + + +class WorkerGroup: + """A set of ``Worker`` instances. + + The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker + group contains cross instance workers or not depends on the implementation of the agent. + """ + + __slots__ = [ + "spec", + "workers", + "store", + "group_rank", + "group_world_size", + "state", + "master_addr", + "master_port", + ] + + def __init__(self, spec: WorkerSpec): + self.spec = spec + self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] + + # assigned after rdzv + self.store = None + self.group_rank = None + self.group_world_size = None + self.master_addr = None + self.master_port = None + + self.state = WorkerState.INIT + + +class _RoleInstanceInfo: + """The class is used by the agent to exchange the information with other agents. + + The information is used to determine the rank of the workers that agent + manages in heterogeneous environments, where different agents can have + different number of workers. + """ + + __slots__ = ["role", "rank", "local_world_size"] + + def __init__(self, role: str, rank: int, local_world_size: int): + r"""Initialize the agent class instance. + + Args: + role (str): user-defined role for the workers with this spec + rank (int): the rank of the agent + local_world_size (int): number of local workers to run + """ + self.role = role + self.rank = rank + self.local_world_size = local_world_size + + def serialize(self) -> bytes: + dict_data = { + "role": self.role, + "rank": self.rank, + "local_world_size": self.local_world_size, + } + return json.dumps(dict_data).encode(encoding="UTF-8") + + @staticmethod + def deserialize(data: bytes): + dict_data = json.loads(data.decode(encoding="UTF-8")) + return _RoleInstanceInfo( + dict_data["role"], dict_data["rank"], dict_data["local_world_size"] + ) + + @staticmethod + def compare(obj1, obj2) -> int: + if obj1.role == obj2.role: + return obj1.rank - obj2.rank + elif obj1.role > obj2.role: + return 1 + else: + return -1 + + @staticmethod + def find_role_boundaries(roles_infos: list, role: str) -> tuple[int, int]: + start_idx, end_idx = -1, -1 + for idx, role_info in enumerate(roles_infos): + if role_info.role == role: + if start_idx == -1: + start_idx = idx + end_idx = idx + return (start_idx, end_idx) + + +@dataclass +class RunResult: + """Return results of the worker executions. + + Run results follow an "all-or-nothing" policy where the run is successful if and + only if ALL local workers managed by this agent complete successfully. + + If the result is successful (e.g. 
``is_failed() = False``) then the ``return_values`` + field contains the outputs (return values) of the workers managed by THIS agent mapped + by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of + global rank 0. + + .. note:: ``return_values`` are only meaningful for when the worker entrypoint + is a function. Workers specified as a binary entrypoint do not canonically + have a return value and the ``return_values`` field is meaningless and + may be empty. + + If ``is_failed()`` returns ``True`` then the ``failures`` field contains the + failure information, again, mapped by the GLOBAL rank of the worker that failed. + + The keys in ``return_values`` and ``failures`` are mutually exclusive, that is, + a worker's final state can only be one of: succeeded, failed. Workers intentionally + terminated by the agent according to the agent's restart policy, are not represented + in either ``return_values`` nor ``failures``. + """ + + state: WorkerState + return_values: dict[int, Any] = field(default_factory=dict) + failures: dict[int, ProcessFailure] = field(default_factory=dict) + + def is_failed(self) -> bool: + return self.state == WorkerState.FAILED + + +def _get_fq_hostname() -> str: + return socket.getfqdn(socket.gethostname()) + + +class ElasticAgent(abc.ABC): + """An agent process responsible for managing one or more worker processes. + + The worker processes are assumed to be regular distributed PyTorch scripts. + When the worker process is created by the agent, the agent provides the + necessary information for the worker processes to properly initialize + a torch process group. + + The exact deployment topology and ratio of agent-to-worker is dependent + on the specific implementation of the agent and the user's job placement + preferences. For instance, to run a distributed training job on GPU with + 8 trainers (one per GPU) one can: + + 1. Use 8 x single GPU instances, place an agent per instance, managing + 1 worker per agent. + 2. Use 4 x double GPU instances, place an agent per instance, managing + 2 workers per agent. + 3. Use 2 x quad GPU instances, place an agent per instance, managing + 4 workers per agent. + 4. Use 1 x 8 GPU instance, place an agent per instance, managing + 8 workers per agent. + + Usage + :: + + group_result = agent.run() + if group_result.is_failed(): + # workers failed + failure = group_result.failures[0] + logger.exception("worker 0 failed with exit code : %s", failure.exit_code) + else: + return group_result.return_values[0] # return rank 0's results + + """ + + @abc.abstractmethod + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + """Run the agent. + + Supports retrying the worker group on failures up to ``max_restarts``. + + Returns: + The result of the execution, containing the return values or + failure details for each worker mapped by the worker's global rank. + + Raises: + Exception - any other failures NOT related to worker process + """ + raise NotImplementedError + + @abc.abstractmethod + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + """Return the ``WorkerGroup`` for the given ``role``. + + Note that the worker group is a mutable object and hence in a + multi-threaded/process environment it may change state. + Implementers are encouraged (but not required) to return + a defensive read-only copy. + """ + raise NotImplementedError + + +class SimpleElasticAgent(ElasticAgent): + """An ``ElasticAgent`` that manages one particular type of worker role. 
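A hedged outline of what a concrete subclass supplies (the method names are the
abstract hooks declared below; the bodies here are placeholders only)::

    import signal

    from torch.distributed.elastic.agent.server.api import (
        RunResult,
        SimpleElasticAgent,
        WorkerState,
    )

    class MyAgent(SimpleElasticAgent):
        def _start_workers(self, worker_group):
            # launch worker_group.spec.local_world_size processes and
            # return a map of local_rank -> worker id
            return {}

        def _stop_workers(self, worker_group):
            pass

        def _monitor_workers(self, worker_group):
            return RunResult(state=WorkerState.HEALTHY)

        def _shutdown(self, death_sig=signal.SIGTERM):
            pass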
+ + An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` + such as one particular type of worker role. + """ + + def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300): + self._worker_group = WorkerGroup(spec) + self._remaining_restarts = self._worker_group.spec.max_restarts + self._store = None + self._exit_barrier_timeout = exit_barrier_timeout + self._total_execution_time = 0 + + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + return self._worker_group + + @abc.abstractmethod + def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: + r"""Start ``worker_group.spec.local_world_size`` number of workers. + + This is according to worker spec for the worker group . + Returns a map of ``local_rank`` to worker ``id``. + """ + raise NotImplementedError + + @abc.abstractmethod + def _stop_workers(self, worker_group: WorkerGroup) -> None: + r"""Stop all workers in the given worker group. + + Implementers must deal with workers in all states defined by + ``WorkerState``. That is, it must gracefully handle stopping + non-existent workers, unhealthy (stuck) workers, etc. + """ + raise NotImplementedError + + @abc.abstractmethod + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + r"""Check on the workers for the ``worker_group``. + + This function also returns the new state of the worker group. + """ + raise NotImplementedError + + @abc.abstractmethod + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: + """Clean up any resources that were allocated during the agent's work. + + Args: + death_sig: Signal to send to the child process, SIGTERM is default + """ + raise NotImplementedError + + @prof + def _rendezvous(self, worker_group: WorkerGroup) -> None: + r"""Run rendezvous for the workers specified by the worker spec. + + Assigns workers a new global rank and world size. + Updates the rendezvous store for the worker group. + """ + spec = worker_group.spec + + with self.record_duration("RENDEZVOUS"): + rdzv_info = spec.rdzv_handler.next_rendezvous() + store = rdzv_info.store + group_rank = rdzv_info.rank + group_world_size = rdzv_info.world_size + + # master_addr/master_port could be explicitly overridden + # TODO: BC - specific to static rdzv and can be simplified further + master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr + master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port + + self._store = store + + with self.record_duration("ASSIGN_WORKER_RANKS"): + workers = self._assign_worker_ranks( + store, group_rank, group_world_size, spec + ) + worker_group.workers = workers + worker_group.store = store + worker_group.group_rank = group_rank + worker_group.group_world_size = group_world_size + worker_group.master_addr = master_addr + worker_group.master_port = master_port + + restart_count = spec.max_restarts - self._remaining_restarts + + logger.info( + "[%(role)s] Rendezvous complete for workers. 
Result:\n" + " restart_count=%(restart_count)s\n" + " master_addr=%(master_addr)s\n" + " master_port=%(master_port)s\n" + " group_rank=%(group_rank)s\n" + " group_world_size=%(group_world_size)s\n" + " local_ranks=%(local_ranks)s\n" + " role_ranks=%(role_ranks)s\n" + " global_ranks=%(global_ranks)s\n" + " role_world_sizes=%(role_world_sizes)s\n" + " global_world_sizes=%(global_world_sizes)s\n" + " event_log_handler=%(event_log_handler)s\n", + { + "role": spec.role, + "restart_count": restart_count, + "master_addr": master_addr, + "master_port": master_port, + "group_rank": group_rank, + "group_world_size": group_world_size, + "local_ranks": [worker.local_rank for worker in workers], + "role_ranks": [worker.role_rank for worker in workers], + "global_ranks": [worker.global_rank for worker in workers], + "role_world_sizes": [worker.role_world_size for worker in workers], + "global_world_sizes": [worker.world_size for worker in workers], + "event_log_handler": spec.event_log_handler, + }, + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _assign_worker_ranks( + self, store, group_rank: int, group_world_size: int, spec: WorkerSpec + ) -> list[Worker]: + """Determine proper ranks for worker processes. + + Fast Path: when all workers have the same role and world size. We calculate + the global rank to be group_rank * group_world_size + local_rank. And the + `role_world_size` is the same as `global_world_size`. No TCP store is used in + this case. This is only enabled when users set the environment variable + `TORCH_ELASTIC_WORKER_IDENTICAL` to 1. + + Time complexity: each worker O(1), overall O(1) + + Slow Path: when workers have different roles and world sizes. We use the + the following algorithm: + + 1. Each agent writes its configuration(group_rank, group_world_size + , num_workers) to the common store. + 2. The rank 0 agent reads all the role_info from the store and + determines each agents worker ranks. + 3. Determine the global rank: the global rank of the workers is computed + by cumulative sum of the local_world_size for all workers in front of it. + For efficiency reasons each worker is assigned a base global rank + such that it's workers are in the range [base_global_rank, + base_global_rank + local_world_size). + 4. Determine the role rank: The role rank is determined using the algorithms + in the point 3 with the exception that the ranks are calculated with + respect to the role name. + 5. The rank 0 agent writes the assigned ranks to the store. + 6. Each agent reads the assigned ranks from the store. + + Time complexity: each worker O(1), rank0 O(n), overall O(n) + """ + + if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1": + global_world_size = group_world_size * spec.local_world_size + base_global_rank = group_rank * spec.local_world_size + base_role_rank = base_global_rank + role_world_size = global_world_size + else: + ROLE_INFO_PREFIX = "torchelastic/role_info/" + ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/" + + agent_role_info = _RoleInstanceInfo( + spec.role, group_rank, spec.local_world_size + ) + store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize()) + + # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations. 
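            # Illustrative worked example (explanatory comment only, not upstream
            # logic): suppose two agents share the role "trainer", agent 0 with
            # local_world_size=4 and agent 1 with local_world_size=2. Rank 0 reads
            # both role infos, computes global_size=6, and assigns base_global_rank
            # 0 to agent 0 and 4 to agent 1, so their workers receive global ranks
            # 0-3 and 4-5 respectively (role ranks match global ranks here since
            # there is only one role).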
+ if group_rank == 0: + role_infos_bytes = store.multi_get( + [f"torchelastic/role_info/{i}" for i in range(group_world_size)] + ) + role_infos = [ + _RoleInstanceInfo.deserialize(info_bytes) + for info_bytes in role_infos_bytes + ] + + role_sizes = defaultdict(lambda: 0) + global_size = 0 + for role_info in role_infos: + role_sizes[role_info.role] += role_info.local_world_size + global_size += role_info.local_world_size + + base_global_rank = 0 + role_ranks = defaultdict(lambda: 0) + + keys = [] + values = [] + for i, role_info in enumerate(role_infos): + keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}") + values.append( + json.dumps( + [ + base_global_rank, + global_size, + role_ranks[role_info.role], + role_sizes[role_info.role], + ] + ) + ) + + base_global_rank += role_info.local_world_size + role_ranks[role_info.role] += role_info.local_world_size + + store.multi_set(keys, values) + + # get will block until the data is available in the store. + ( + base_global_rank, + global_world_size, + base_role_rank, + role_world_size, + ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) + + workers = [] + for local_rank in range(spec.local_world_size): + worker = Worker( + local_rank=local_rank, + global_rank=base_global_rank + local_rank, + role_rank=base_role_rank + local_rank, + world_size=global_world_size, + role_world_size=role_world_size, + ) + workers.append(worker) + return workers + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _initialize_workers(self, worker_group: WorkerGroup) -> None: + r"""Start a fresh set of workers for the worker_group. + + Essentially, a rendezvous followed by a ``start_workers``. + The caller should first call ``_stop_workers()`` to stop running workers + prior to calling this method. + + Optimistically sets the state of the worker group that + just started as ``HEALTHY`` and delegates the actual monitoring + of state to ``_monitor_workers()`` method + """ + role = worker_group.spec.role + logger.info("[%s] Rendezvous'ing worker group", role) + + # TODO after stopping workers, wait at least monitor_interval*2 for + # workers on different nodes to fail on a collective op before waiting + # on the rdzv barrier, this way we ensure that nodes enter rdzv + # at around the same time and reduce false positive rdzv timeout errors + self._rendezvous(worker_group) + + logger.info("[%s] Starting worker group", role) + worker_ids = self._start_workers(worker_group) + for local_rank, w_id in worker_ids.items(): + worker = worker_group.workers[local_rank] + worker.id = w_id + record( + self._construct_event("START", EventSource.WORKER, worker), + worker_group.spec.event_log_handler, + ) + + worker_group.state = WorkerState.HEALTHY + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _restart_workers(self, worker_group: WorkerGroup) -> None: + """Restart (stops, rendezvous, starts) all local workers in the group.""" + role = worker_group.spec.role + logger.info("[%s] Stopping worker group", role) + self._stop_workers(worker_group) + worker_group.state = WorkerState.STOPPED + self._initialize_workers(worker_group) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. 
+ @prof + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + start_time = time.monotonic() + shutdown_called: bool = False + try: + result = self._invoke_run(role) + self._total_execution_time = int(time.monotonic() - start_time) + self._record_metrics(result) + self._record_worker_events(result) + return result + except RendezvousGracefulExitError as e: + logger.info("Rendezvous gracefully exited: %s", e) + except SignalException as e: + logger.warning("Received %s death signal, shutting down workers", e.sigval) + self._shutdown(e.sigval) + shutdown_called = True + raise + finally: + if not shutdown_called: + self._shutdown() + # record the execution time in case there were any exceptions during run. + self._total_execution_time = int(time.monotonic() - start_time) + + def get_event_failed(self) -> Event: + return self._construct_event( + state="FAILED", + source=EventSource.AGENT, + raw_error=traceback.format_exc(), + ) + + def get_event_succeeded(self) -> Event: + return self._construct_event( + state="SUCCEEDED", + source=EventSource.AGENT, + ) + + def _record_worker_events(self, result: RunResult) -> None: + for worker in self._worker_group.workers: + failure = result.failures.get(worker.global_rank) + state: str = self._get_worker_state(worker, result) + raw_error = json.dumps(failure.error_file_data) if failure else None + record( + self._construct_event(state, EventSource.WORKER, worker, raw_error), + self._worker_group.spec.event_log_handler, + ) + + def _get_worker_state(self, worker: Worker, result: RunResult) -> str: + failure = result.failures.get(worker.global_rank) + if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure: + # The worker got terminated by the torchelastic agent via SIGTERM signal + return "TERMINATED" + elif failure or worker.global_rank in result.return_values: + return result.state.value + else: + raise ValueError(f"Unknown worker: {worker.global_rank}") + + @contextmanager + def record_duration(self, state: str): + start_time = time.perf_counter() + try: + yield + finally: + end_time = time.perf_counter() + duration_ms = (end_time - start_time) * 1000 + record( + self._construct_event( + state=state, source=EventSource.AGENT, duration_ms=duration_ms + ), + self._worker_group.spec.event_log_handler, + ) + + def _construct_event( + self, + state: str, + source: EventSource, + worker: Optional[Worker] = None, + raw_error: Optional[str] = None, + duration_ms: Optional[float] = None, + ) -> Event: + wg = self._worker_group + spec = wg.spec + md = { + "group_world_size": wg.group_world_size, + "entry_point": spec.get_entrypoint_name(), + } + if worker: + md["local_rank"] = (worker.local_rank,) + md["role_rank"] = (worker.role_rank,) + md["role_world_size"] = (worker.role_world_size,) + global_rank = worker.global_rank + worker_id = str(worker.id) + else: + global_rank = None + worker_id = None + md_str = json.dumps(md) + metadata = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": global_rank, + "group_rank": wg.group_rank, + "worker_id": worker_id, + "role": spec.role, + "hostname": _get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": raw_error, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + "duration_ms": duration_ms, + } + + return Event( + f"torchelastic.worker.status.{state}", source=source, metadata=metadata + ) + + def _record_metrics(self, group_results: RunResult): + 
is_failed = group_results.is_failed() + self._record_flakiness_metric(is_failed) + spec = self._worker_group.spec + restarts_happened = self._remaining_restarts != spec.max_restarts + put_metric(f"workers.{spec.role}.run_total", 1) + self._record_metric_with_condition( + "run_success_with_retries", not is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_success_no_retries", not is_failed and not restarts_happened + ) + self._record_metric_with_condition( + "run_failed_with_retries", is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_failed_no_retries", is_failed and not restarts_happened + ) + + def _record_metric_with_condition(self, metric_name, condition): + spec = self._worker_group.spec + if condition: + put_metric(f"workers.{spec.role}.{metric_name}", 1) + else: + put_metric(f"workers.{spec.role}.{metric_name}", 0) + + def _record_flakiness_metric(self, is_failed: bool = False): + if is_failed: + flakiness = 100.0 + else: + spec = self._worker_group.spec + flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / ( + spec.max_restarts + 1 + ) + spec = self._worker_group.spec + + put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) + + def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: + # NOTE: currently only works for a single role + + spec = self._worker_group.spec + role = spec.role + + logger.info( + "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name() + ) + + self._initialize_workers(self._worker_group) + monitor_interval = spec.monitor_interval + rdzv_handler = spec.rdzv_handler + + while True: + assert self._worker_group.state != WorkerState.INIT + time.sleep(monitor_interval) + run_result = self._monitor_workers(self._worker_group) + state = run_result.state + self._worker_group.state = state + + put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) + put_metric(f"workers.{role}.{state.name.lower()}", 1) + + if state == WorkerState.SUCCEEDED: + logger.info( + "[%s] worker group successfully finished." + " Waiting %s seconds for other agents to finish.", + role, + self._exit_barrier_timeout, + ) + self._exit_barrier() + return run_result + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: + if self._remaining_restarts > 0: + logger.info( + "[%s] Worker group %s. " + "%s/%s attempts left;" + " will restart worker group", + role, + state.name, + self._remaining_restarts, + spec.max_restarts, + ) + self._remaining_restarts -= 1 + self._restart_workers(self._worker_group) + else: + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + return run_result + elif state == WorkerState.HEALTHY: + # membership changes do not count as retries + num_nodes_waiting = rdzv_handler.num_nodes_waiting() + group_rank = self._worker_group.group_rank + if num_nodes_waiting > 0: + logger.info( + "[%s] Detected %s " + "new nodes from group_rank=%s; " + "will restart worker group", + role, + num_nodes_waiting, + group_rank, + ) + self._restart_workers(self._worker_group) + else: + raise Exception( # noqa: TRY002 + f"[{role}] Worker group in {state.name} state" + ) + + def _exit_barrier(self): + """ + Define a barrier that keeps the agent process alive until all workers finish. + + Wait for ``exit_barrier_timeout`` seconds for all agents to finish + executing their local workers (either successfully or not). This + acts as a safety guard against user scripts that terminate at different + times. 
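For illustration (a hedged sketch; ``spec`` and ``logs_specs`` are assumed to be
built elsewhere, and the 900 second value is arbitrary), the window is chosen
once at construction time::

    from torch.distributed.elastic.agent.server.local_elastic_agent import (
        LocalElasticAgent,
    )

    agent = LocalElasticAgent(spec, logs_specs=logs_specs, exit_barrier_timeout=900)
    result = agent.run()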
+ """ + logger.info( + "Local worker group finished (%s). " + "Waiting %s seconds for other agents to finish", + self._worker_group.state, + self._exit_barrier_timeout, + ) + start = time.time() + try: + store_util.barrier( + store=self._store, + world_size=self._worker_group.group_world_size, + key_prefix=_TERMINAL_STATE_SYNC_ID, + barrier_timeout=self._exit_barrier_timeout, + ) + logger.info( + "Done waiting for other agents. Elapsed: %s seconds", + time.time() - start, + ) + except SignalException as e: + logger.warning("Got termination signal: %s", e.sigval) + raise + except Exception: + logger.exception( + "Error waiting on exit barrier. Elapsed: %s seconds", + time.time() - start, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/health_check_server.py b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/health_check_server.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b522e6e383f79a31b70b9fa04fc52af99f68cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/health_check_server.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +from torch.distributed.elastic.utils.logging import get_logger + + +log = get_logger(__name__) + +__all__ = ["HealthCheckServer", "create_healthcheck_server"] + + +class HealthCheckServer: + """ + Interface for health check monitoring server, which can be extended + by starting tcp/http server on the specified port. + + Args: + + alive_callback: Callable[[], int], callback to last progress time of agent + + port: int, port number to start tcp/http server + + timeout: int, timeout seconds to decide agent is alive/dead + """ + + _alive_callback: Callable[[], int] + _port: int + _timeout: int + + def __init__( + self, alive_callback: Callable[[], int], port: int, timeout: int + ) -> None: + self._alive_callback = alive_callback + self._port = port + self._timeout = timeout + + def start(self) -> None: + """ + Unsupported functionality for Pytorch, doesn't start any health check server + """ + log.warning("No health check server started") + + def stop(self) -> None: + """ + Function to stop health check server + """ + log.info("Stopping noop health check server.") + + +def create_healthcheck_server( + alive_callback: Callable[[], int], + port: int, + timeout: int, +) -> HealthCheckServer: + """ + creates health check server object + """ + return HealthCheckServer(alive_callback, port, timeout) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..452cdbaa71d5c99eed004d9f5f077e0088056b91 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
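A hedged usage sketch of the no-op health-check interface defined above (the
port value is arbitrary)::

    import time

    from torch.distributed.elastic.agent.server.health_check_server import (
        create_healthcheck_server,
    )

    server = create_healthcheck_server(
        alive_callback=lambda: int(time.time()),  # "last progress" timestamp
        port=8080,
        timeout=60,
    )
    server.start()  # the default implementation only logs that nothing was started
    server.stop()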
+ + +import json +import os +import signal +import socket +import time +import uuid +from string import Template +from typing import Any, Optional, TYPE_CHECKING + +import torch.distributed.elastic.timer as timer +from torch.distributed.elastic import events +from torch.distributed.elastic.agent.server.api import ( + RunResult, + SimpleElasticAgent, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from torch.distributed.elastic.agent.server.health_check_server import ( + create_healthcheck_server, + HealthCheckServer, +) +from torch.distributed.elastic.metrics.api import prof +from torch.distributed.elastic.multiprocessing import ( + LogsSpecs, + PContext, + start_processes, +) +from torch.distributed.elastic.utils import macros +from torch.distributed.elastic.utils.logging import get_logger + + +if TYPE_CHECKING: + from torch.distributed.elastic.events.api import EventMetadataValue + +logger = get_logger(__name__) + +__all__ = [ + "LocalElasticAgent", + "TORCHELASTIC_ENABLE_FILE_TIMER", + "TORCHELASTIC_TIMER_FILE", + "TORCHELASTIC_HEALTH_CHECK_PORT", +] + +TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER" +TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" +TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" + + +class LocalElasticAgent(SimpleElasticAgent): + """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. + + This agent is deployed per host and is configured to spawn ``n`` workers. + When using GPUs, ``n`` maps to the number of GPUs available on the host. + + The local agent does not communicate to other local agents deployed on + other hosts, even if the workers may communicate inter-host. The worker id + is interpreted to be a local process. The agent starts and stops all worker + processes as a single unit. + + + The worker function and argument passed to the worker function must be + python multiprocessing compatible. To pass multiprocessing data structures + to the workers you may create the data structure in the same multiprocessing + context as the specified ``start_method`` and pass it as a function argument. + + The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait + for other agents to finish. This acts as a safety net to handle cases where + workers finish at different times, to prevent agents from viewing workers + that finished early as a scale-down event. It is strongly advised that the + user code deal with ensuring that workers are terminated in a synchronous + manner rather than relying on the exit_barrier_timeout. + + A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an + environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has + been defined in the ```LocalElasticAgent``` process. + Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` + can be set with a unique file name for the named pipe. If the environment + variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` + will internally create a unique file name and set it to the environment + variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will + be propagated to the worker processes to allow them to connect to the same + named pipe that ```LocalElasticAgent``` uses. + + Logs are written to the specified log directory. Each log line will be by default + prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). 
+ Log prefixes can be customized by passing a `template string + `_ as the + ``log_line_prefix_template`` argument. + The following macros (identifiers) are substituted at runtime: + ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with + global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. + + + Example launching function + + :: + + def trainer(args) -> str: + return "do train" + + def main(): + start_method="spawn" + shared_queue= multiprocessing.get_context(start_method).Queue() + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint=trainer, + args=("foobar",), + ...) + agent = LocalElasticAgent(spec, start_method) + results = agent.run() + + if results.is_failed(): + print("trainer failed") + else: + print(f"rank 0 return value: {results.return_values[0]}") + # prints -> rank 0 return value: do train + + Example launching binary + + :: + + def main(): + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint="/usr/local/bin/trainer", + args=("--trainer-args", "foobar"), + ...) + agent = LocalElasticAgent(spec) + results = agent.run() + + if not results.is_failed(): + print("binary launches do not have return values") + + """ + + def __init__( + self, + spec: WorkerSpec, + logs_specs: LogsSpecs, + start_method="spawn", + exit_barrier_timeout: float = 300, + log_line_prefix_template: Optional[str] = None, + ): + super().__init__(spec, exit_barrier_timeout) + self._start_method = start_method + self._pcontext: Optional[PContext] = None + self._rdzv_handler = spec.rdzv_handler + self._log_line_prefix_template = log_line_prefix_template + self._worker_watchdog: Optional[timer.FileTimerServer] = None + self._logs_specs = logs_specs + self._health_check_server: Optional[HealthCheckServer] = None + + def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None: + enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER + watchdog_enabled = os.getenv(enable_watchdog_env_name) + watchdog_file_env_name = TORCHELASTIC_TIMER_FILE + watchdog_file_path = os.getenv(watchdog_file_env_name) + if watchdog_enabled is not None and str(watchdog_enabled) == "1": + if watchdog_file_path is None: + watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) + logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) + if not envs: + logger.warning( + "Empty envs variables, using empty run_id for FileTimerServer" + ) + run_id = "" + else: + run_id = envs[0]["TORCHELASTIC_RUN_ID"] + self._worker_watchdog = timer.FileTimerServer( + file_path=watchdog_file_path, + run_id=run_id, + max_interval=0.1, + daemon=True, + log_event=self._log_watchdog_event, + ) + self._worker_watchdog.start() + logger.info("FileTimerServer started") + else: + logger.info( + "Environment variable '%s' not found. 
Do not start FileTimerServer.", + enable_watchdog_env_name, + ) + # Propagate the watchdog file env to worker processes + if watchdog_file_path is not None: + for worker_env in envs.values(): + worker_env[watchdog_file_env_name] = watchdog_file_path + + @staticmethod + def _get_current_time_secs() -> int: + return int(time.time()) + + def _setup_healthcheck(self) -> None: + healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT + healthcheck_port = os.getenv(healthcheck_port_env_name) + if healthcheck_port is not None: + logger.info( + "Found healthcheck port %s: %s", + healthcheck_port_env_name, + healthcheck_port, + ) + if self._worker_watchdog is None: + logger.info( + "FileTimerServer doesn't exist, using current time as dummy callback" + ) + alive_callback = LocalElasticAgent._get_current_time_secs + else: + alive_callback = self._worker_watchdog.get_last_progress_time + + try: + healthcheck_port_as_int = int(healthcheck_port) + self._health_check_server = create_healthcheck_server( + alive_callback=alive_callback, + port=healthcheck_port_as_int, + timeout=60, + ) + self._health_check_server.start() + except ValueError: + logger.info( + "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.", + healthcheck_port, + ) + else: + logger.info( + "Environment variable '%s' not found. Do not start health check.", + healthcheck_port_env_name, + ) + + def _get_fq_hostname(self) -> str: + return socket.getfqdn(socket.gethostname()) + + def _log_watchdog_event( + self, + name: str, + request: Optional[timer.FileTimerRequest], + ) -> None: + wg = self._worker_group + spec = wg.spec + md = {"watchdog_event": name} + if request is not None: + md["worker_pid"] = str(request.worker_pid) + md["scope_id"] = request.scope_id + md["expiration_time"] = str(request.expiration_time) + md["signal"] = str(request.signal) + md_str = json.dumps(md) + state = "RUNNING" + metadata: dict[str, EventMetadataValue] = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": None, + "group_rank": wg.group_rank, + "worker_id": None, + "role": spec.role, + "hostname": self._get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": None, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + } + # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later. + # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry. + event = events.Event( + name=name, source=events.EventSource.AGENT, metadata=metadata + ) + events.record(event, self._worker_group.spec.event_log_handler) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _stop_workers(self, worker_group: WorkerGroup) -> None: + self._shutdown() + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. 
+ @prof + def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: + spec = worker_group.spec + store = worker_group.store + assert store is not None + restart_count = spec.max_restarts - self._remaining_restarts + + use_agent_store: bool = spec.rdzv_handler.use_agent_store + logger.info("use_agent_store: %s", use_agent_store) + + args: dict[int, tuple] = {} + envs: dict[int, dict[str, str]] = {} + log_line_prefixes: Optional[dict[int, str]] = ( + {} if self._log_line_prefix_template else None + ) + for worker in worker_group.workers: + local_rank = worker.local_rank + worker_env = { + "LOCAL_RANK": str(local_rank), + "RANK": str(worker.global_rank), + "GROUP_RANK": str(worker_group.group_rank), + "ROLE_RANK": str(worker.role_rank), + "ROLE_NAME": spec.role, + "LOCAL_WORLD_SIZE": str(spec.local_world_size), + "WORLD_SIZE": str(worker.world_size), + "GROUP_WORLD_SIZE": str(worker_group.group_world_size), + "ROLE_WORLD_SIZE": str(worker.role_world_size), + "MASTER_ADDR": worker_group.master_addr, + "MASTER_PORT": str(worker_group.master_port), + "TORCHELASTIC_RESTART_COUNT": str(restart_count), + "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), + "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), + "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), + "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv( + "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) + ), + } + if "OMP_NUM_THREADS" in os.environ: + worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] + + if self._log_line_prefix_template: + log_line_prefix = Template( + self._log_line_prefix_template + ).safe_substitute( + role_name=spec.role, + rank=worker.global_rank, + local_rank=local_rank, + ) + log_line_prefixes[local_rank] = log_line_prefix + + envs[local_rank] = worker_env + worker_args = list(spec.args) + worker_args = macros.substitute(worker_args, str(local_rank)) + args[local_rank] = tuple(worker_args) + + self._setup_local_watchdog(envs=envs) + self._setup_healthcheck() + + assert spec.entrypoint is not None + assert self._logs_specs is not None + self._pcontext = start_processes( + name=spec.role, + entrypoint=spec.entrypoint, + args=args, + envs=envs, + logs_specs=self._logs_specs, + log_line_prefixes=log_line_prefixes, + start_method=self._start_method, + ) + + return self._pcontext.pids() + + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: + if self._worker_watchdog is not None: + self._worker_watchdog.stop() + self._worker_watchdog = None + if self._health_check_server is not None: + self._health_check_server.stop() + self._health_check_server = None + if self._pcontext: + self._pcontext.close(death_sig) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + role = worker_group.spec.role + worker_pids = {w.id for w in worker_group.workers} + assert self._pcontext is not None + pc_pids = set(self._pcontext.pids().values()) + if worker_pids != pc_pids: + logger.error( + "[%s] worker pids do not match process_context pids." 
+ " Expected: %s, actual: %s", + role, + worker_pids, + pc_pids, + ) + return RunResult(state=WorkerState.UNKNOWN) + + result = self._pcontext.wait(0) + if result: + if result.is_failed(): + # map local rank failure to global rank + worker_failures = {} + for local_rank, failure in result.failures.items(): + worker = worker_group.workers[local_rank] + worker_failures[worker.global_rank] = failure + return RunResult( + state=WorkerState.FAILED, + failures=worker_failures, + ) + else: + # copy ret_val_queue into a map with a global ranks + workers_ret_vals = {} + for local_rank, ret_val in result.return_values.items(): + worker = worker_group.workers[local_rank] + workers_ret_vals[worker.global_rank] = ret_val + return RunResult( + state=WorkerState.SUCCEEDED, + return_values=workers_ret_vals, + ) + else: + return RunResult(state=WorkerState.HEALTHY) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/control_plane.py b/phivenv/Lib/site-packages/torch/distributed/elastic/control_plane.py new file mode 100644 index 0000000000000000000000000000000000000000..9c68f381f1c8962f52644b305f72564a624fddff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/control_plane.py @@ -0,0 +1,53 @@ +import os +from collections.abc import Generator +from contextlib import contextmanager, ExitStack + +from torch.distributed.elastic.multiprocessing.errors import record + + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@record +@contextmanager +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + + if __name__ == "__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/events/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7d985c7e633d978e3fdf6deb22420aa0c5d4ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__init__.py @@ -0,0 +1,173 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Module contains events processing mechanisms that are integrated with the standard python logging. 
+ +Example of usage: + +:: + + from torch.distributed.elastic import events + + event = events.Event( + name="test_event", source=events.EventSource.WORKER, metadata={...} + ) + events.get_logging_handler(destination="console").info(event) + +""" + +import inspect +import logging +import os +import socket +import traceback +from typing import Optional + +from torch.distributed.elastic.events.handlers import get_logging_handler + +from .api import ( # noqa: F401 + Event, + EventMetadataValue, + EventSource, + NodeState, + RdzvEvent, +) + + +_events_loggers: dict[str, logging.Logger] = {} + + +def _get_or_create_logger(destination: str = "null") -> logging.Logger: + """ + Construct python logger based on the destination type or extends if provided. + + Available destination could be found in ``handlers.py`` file. + The constructed logger does not propagate messages to the upper level loggers, + e.g. root logger. This makes sure that a single event can be processed once. + + Args: + destination: The string representation of the event handler. + Available handlers found in ``handlers`` module + """ + global _events_loggers + + if destination not in _events_loggers: + _events_logger = logging.getLogger(f"torchelastic-events-{destination}") + _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) + # Do not propagate message to the root logger + _events_logger.propagate = False + + logging_handler = get_logging_handler(destination) + _events_logger.addHandler(logging_handler) + + # Add the logger to the global dictionary + _events_loggers[destination] = _events_logger + + return _events_loggers[destination] + + +def record(event: Event, destination: str = "null") -> None: + _get_or_create_logger(destination).info(event.serialize()) + + +def record_rdzv_event(event: RdzvEvent) -> None: + _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) + + +def construct_and_record_rdzv_event( + run_id: str, + message: str, + node_state: NodeState, + name: str = "", + hostname: str = "", + pid: Optional[int] = None, + master_endpoint: str = "", + local_id: Optional[int] = None, + rank: Optional[int] = None, +) -> None: + """ + Initialize rendezvous event object and record its operations. + + Args: + run_id (str): The run id of the rendezvous. + message (str): The message describing the event. + node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). + name (str): Event name. (E.g. Current action being performed). + hostname (str): Hostname of the node. + pid (Optional[int]): The process id of the node. + master_endpoint (str): The master endpoint for the rendezvous store, if known. + local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py + rank (Optional[int]): The rank of the node, if known. + Returns: + None + Example: + >>> # See DynamicRendezvousHandler class + >>> def _record( + ... self, + ... message: str, + ... node_state: NodeState = NodeState.RUNNING, + ... rank: Optional[int] = None, + ... ) -> None: + ... construct_and_record_rdzv_event( + ... name=f"{self.__class__.__name__}.{get_method_name()}", + ... run_id=self._settings.run_id, + ... message=message, + ... node_state=node_state, + ... hostname=self._this_node.addr, + ... pid=self._this_node.pid, + ... local_id=self._this_node.local_id, + ... rank=rank, + ... ) + """ + # We don't want to perform an extra computation if not needed. + if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): + return + + # Set up parameters. 
+ if not hostname: + hostname = socket.getfqdn() + if not pid: + pid = os.getpid() + + # Determines which file called this function. + callstack = inspect.stack() + filename = "no_file" + if len(callstack) > 1: + stack_depth_1 = callstack[1] + filename = os.path.basename(stack_depth_1.filename) + if not name: + name = stack_depth_1.function + + # Delete the callstack variable. If kept, this can mess with python's + # garbage collector as we are holding on to stack frame information in + # the inspect module. + del callstack + + # Set up error trace if this is an exception + if node_state == NodeState.FAILED: + error_trace = traceback.format_exc() + else: + error_trace = "" + + # Initialize event object + event = RdzvEvent( + name=f"{filename}:{name}", + run_id=run_id, + message=message, + hostname=hostname, + pid=pid, + node_state=node_state, + master_endpoint=master_endpoint, + rank=rank, + local_id=local_id, + error_trace=error_trace, + ) + + # Finally, record the event. + record_rdzv_event(event) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3beb200596d74c3f5fa97b41b21f6304f1a161bf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..134f6f5489cf535d4af6e4a072623d208ff0fbcf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86d379a041bc82986e162a20b5b4803e575ad6d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/events/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/events/api.py new file mode 100644 index 0000000000000000000000000000000000000000..9c288039d07fcdfb723dfd921b64b6c1ec3432bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/events/api.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Optional, Union + + +__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] + +EventMetadataValue = Union[str, int, float, bool, None] + + +class EventSource(str, Enum): + """Known identifiers of the event producers.""" + + AGENT = "AGENT" + WORKER = "WORKER" + + +@dataclass +class Event: + """ + The class represents the generic event that occurs during the torchelastic job execution. + + The event can be any kind of meaningful action. + + Args: + name: event name. 
+ source: the event producer, e.g. agent or worker + timestamp: timestamp in milliseconds when event occurred. + metadata: additional data that is associated with the event. + """ + + name: str + source: EventSource + timestamp: int = 0 + metadata: dict[str, EventMetadataValue] = field(default_factory=dict) + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "Event"]) -> "Event": + if isinstance(data, Event): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined] + return Event(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) + + +class NodeState(str, Enum): + """The states that a node can be in rendezvous.""" + + INIT = "INIT" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + +@dataclass +class RdzvEvent: + """ + Dataclass to represent any rendezvous event. + + Args: + name: Event name. (E.g. Current action being performed) + run_id: The run id of the rendezvous + message: The message describing the event + hostname: Hostname of the node + pid: The process id of the node + node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED) + master_endpoint: The master endpoint for the rendezvous store, if known + rank: The rank of the node, if known + local_id: The local_id of the node, if defined in dynamic_rendezvous.py + error_trace: Error stack trace, if this is an error event. + """ + + name: str + run_id: str + message: str + hostname: str + pid: int + node_state: NodeState + master_endpoint: str = "" + rank: Optional[int] = None + local_id: Optional[int] = None + error_trace: str = "" + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent": + if isinstance(data, RdzvEvent): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined] + return RdzvEvent(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/events/handlers.py b/phivenv/Lib/site-packages/torch/distributed/elastic/events/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..95a34b85095209e73a0fb4d3f5e8ed3fdaacab1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/events/handlers.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
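A hedged round-trip sketch for the ``Event`` dataclass defined in ``api.py``
above::

    from torch.distributed.elastic.events.api import Event, EventSource

    ev = Event(name="test_event", source=EventSource.WORKER, metadata={"foo": "bar"})
    payload = ev.serialize()            # JSON string
    restored = Event.deserialize(payload)
    assert restored.source == EventSource.WORKER
    assert restored.metadata["foo"] == "bar"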
+ +import logging + + +_log_handlers: dict[str, logging.Handler] = { + "console": logging.StreamHandler(), + "dynamic_rendezvous": logging.NullHandler(), + "null": logging.NullHandler(), +} + + +def get_logging_handler(destination: str = "null") -> logging.Handler: + global _log_handlers + return _log_handlers[destination] diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3f1c8c8a2f176ca6b0fdb2a1b46cb5eba39457 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__init__.py @@ -0,0 +1,168 @@ +#!/usr/bin/env/python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Metrics API. + +**Overview**: + +The metrics API in torchelastic is used to publish telemetry metrics. +It is designed to be used by torchelastic's internal modules to +publish metrics for the end user with the goal of increasing visibility +and helping with debugging. However you may use the same API in your +jobs to publish metrics to the same metrics ``sink``. + +A ``metric`` can be thought of as timeseries data +and is uniquely identified by the string-valued tuple +``(metric_group, metric_name)``. + +torchelastic makes no assumptions about what a ``metric_group`` is +and what relationship it has with ``metric_name``. It is totally up +to the user to use these two fields to uniquely identify a metric. + +.. note:: The metric group ``torchelastic`` is reserved by torchelastic for + platform level metrics that it produces. + For instance torchelastic may output the latency (in milliseconds) + of a re-rendezvous operation from the agent as + ``(torchelastic, agent.rendezvous.duration.ms)`` + +A sensible way to use metric groups is to map them to a stage or module +in your job. You may also encode certain high level properties +the job such as the region or stage (dev vs prod). + +**Publish Metrics**: + +Using torchelastic's metrics API is similar to using python's logging +framework. You first have to configure a metrics handler before +trying to add metric data. + +The example below measures the latency for the ``calculate()`` function. 
+ +:: + + import time + import torch.distributed.elastic.metrics as metrics + + # makes all metrics other than the one from "my_module" to go /dev/null + metrics.configure(metrics.NullMetricsHandler()) + metrics.configure(metrics.ConsoleMetricsHandler(), "my_module") + + + def my_method(): + start = time.time() + calculate() + end = time.time() + metrics.put_metric("calculate_latency", int(end - start), "my_module") + +You may also use the torch.distributed.elastic.metrics.prof` decorator +to conveniently and succinctly profile functions + +:: + + # -- in module examples.foobar -- + + import torch.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricsHandler(), "foobar") + metrics.configure(metrics.ConsoleMetricsHandler(), "Bar") + + + @metrics.prof + def foo(): + pass + + + class Bar: + @metrics.prof + def baz(): + pass + +``@metrics.prof`` will publish the following metrics +:: + + .success - 1 if the function finished successfully + .failure - 1 if the function threw an exception + .duration.ms - function duration in milliseconds + +**Configuring Metrics Handler**: + +`torch.distributed.elastic.metrics.MetricHandler` is responsible for emitting +the added metric values to a particular destination. Metric groups can be +configured with different metric handlers. + +By default torchelastic emits all metrics to ``/dev/null``. +By adding the following configuration metrics, +``torchelastic`` and ``my_app`` metric groups will be printed out to +console. + +:: + + import torch.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic") + metrics.configure(metrics.ConsoleMetricHandler(), group="my_app") + +**Writing a Custom Metric Handler**: + +If you want your metrics to be emitted to a custom location, implement +the `torch.distributed.elastic.metrics.MetricHandler` interface +and configure your job to use your custom metric handler. 
+ +Below is a toy example that prints the metrics to ``stdout`` + +:: + + import torch.distributed.elastic.metrics as metrics + + + class StdoutMetricHandler(metrics.MetricHandler): + def emit(self, metric_data): + ts = metric_data.timestamp + group = metric_data.group_name + name = metric_data.name + value = metric_data.value + print(f"[{ts}][{group}]: {name}={value}") + + + metrics.configure(StdoutMetricHandler(), group="my_app") + +Now all metrics in the group ``my_app`` will be printed to stdout as: + +:: + + [1574213883.4182858][my_app]: my_metric= + [1574213940.5237644][my_app]: my_metric= + +""" + +from typing import Optional + +from .api import ( # noqa: F401 + configure, + ConsoleMetricHandler, + get_elapsed_time_ms, + getStream, + MetricData, + MetricHandler, + MetricsConfig, + NullMetricHandler, + prof, + profile, + publish_metric, + put_metric, +) + + +def initialize_metrics(cfg: Optional[MetricsConfig] = None): + pass + + +try: + from torch.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403 +except ModuleNotFoundError: + pass diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c41762405b6e34bad74c5aaecf23b1476eaa27c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c9475445548635b5de8bc452489d11eac77f84a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/api.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4b68e66262a33e62415f8aa6a807ebb27d33da --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/metrics/api.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
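Before the implementation that follows in ``metrics/api.py``, a hedged end-to-end sketch of how the pieces fit together (the ``train_step`` function, group name, and timing are made up):

::

    import time
    import torch.distributed.elastic.metrics as metrics

    metrics.configure(metrics.ConsoleMetricHandler(), group="toy_app")

    @metrics.prof(group="toy_app")
    def train_step():
        time.sleep(0.01)

    train_step()
    # ConsoleMetricHandler prints lines of the form:
    #   [<timestamp>][toy_app]: <leaf_module>.train_step.success=1
    #   [<timestamp>][toy_app]: <leaf_module>.train_step.duration.ms=10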
+ +import abc +import time +from collections import namedtuple +from functools import wraps +from typing import Optional +from typing_extensions import deprecated + + +__all__ = [ + "MetricsConfig", + "MetricHandler", + "ConsoleMetricHandler", + "NullMetricHandler", + "MetricStream", + "configure", + "getStream", + "prof", + "profile", + "put_metric", + "publish_metric", + "get_elapsed_time_ms", + "MetricData", +] + +MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) + + +class MetricsConfig: + __slots__ = ["params"] + + def __init__(self, params: Optional[dict[str, str]] = None): + self.params = params + if self.params is None: + self.params = {} + + +class MetricHandler(abc.ABC): + @abc.abstractmethod + def emit(self, metric_data: MetricData): + pass + + +class ConsoleMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + print( + f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}" + ) + + +class NullMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + pass + + +class MetricStream: + def __init__(self, group_name: str, handler: MetricHandler): + self.group_name = group_name + self.handler = handler + + def add_value(self, metric_name: str, metric_value: int): + self.handler.emit( + MetricData(time.time(), self.group_name, metric_name, metric_value) + ) + + +_metrics_map: dict[str, MetricHandler] = {} +_default_metrics_handler: MetricHandler = NullMetricHandler() + + +# pyre-fixme[9]: group has type `str`; used as `None`. +def configure(handler: MetricHandler, group: Optional[str] = None): + if group is None: + global _default_metrics_handler + # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used + # as `MetricHandler`. + _default_metrics_handler = handler + else: + _metrics_map[group] = handler + + +def getStream(group: str): + if group in _metrics_map: + handler = _metrics_map[group] + else: + handler = _default_metrics_handler + return MetricStream(group, handler) + + +def _get_metric_name(fn): + qualname = fn.__qualname__ + split = qualname.split(".") + if len(split) == 1: + module = fn.__module__ + if module: + return module.split(".")[-1] + "." + split[0] + else: + return split[0] + else: + return qualname + + +def prof(fn=None, group: str = "torchelastic"): + r""" + @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates. + + The metric name defaults to the qualified name (``class_name.def_name``) of the function. + If the function does not belong to a class, it uses the leaf module name instead. + + Usage + + :: + + @metrics.prof + def x(): + pass + + + @metrics.prof(group="agent") + def y(): + pass + """ + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + key = _get_metric_name(f) + try: + start = time.time() + result = f(*args, **kwargs) + put_metric(f"{key}.success", 1, group) + except Exception: + put_metric(f"{key}.failure", 1, group) + raise + finally: + put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined] + return result + + return wrapper + + if fn: + return wrap(fn) + else: + return wrap + + +@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) +def profile(group=None): + """ + @profile decorator adds latency and success/failure metrics to any given function. 
+ + Usage + + :: + + @metrics.profile("my_metric_group") + def some_function(): + """ + + def wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + start_time = time.time() + result = func(*args, **kwargs) + publish_metric(group, f"{func.__name__}.success", 1) + except Exception: + publish_metric(group, f"{func.__name__}.failure", 1) + raise + finally: + publish_metric( + group, + f"{func.__name__}.duration.ms", + get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] + ) + return result + + return wrapper + + return wrap + + +def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"): + """ + Publish a metric data point. + + Usage + + :: + + put_metric("metric_name", 1) + put_metric("metric_name", 1, "metric_group_name") + """ + getStream(metric_group).add_value(metric_name, metric_value) + + +@deprecated( + "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", + category=FutureWarning, +) +def publish_metric(metric_group: str, metric_name: str, metric_value: int): + metric_stream = getStream(metric_group) + metric_stream.add_value(metric_name, metric_value) + + +def get_elapsed_time_ms(start_time_in_seconds: float): + """Return the elapsed time in millis from the given start time.""" + end_time = time.time() + return int((end_time - start_time_in_seconds) * 1000) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f614876d997180ecdacb74f09d397d9a30448bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__init__.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary. + +For functions, it uses ``torch.multiprocessing`` (and therefore python +``multiprocessing``) to spawn/fork worker processes. For binaries it uses python +``subprocessing.Popen`` to create worker processes. + + +Usage 1: Launching two trainers as a function + +:: + + from torch.distributed.elastic.multiprocessing import Std, start_processes + + + def trainer(a, b, c): + pass # train + + + # runs two trainers + # LOCAL_RANK=0 trainer(1,2,3) + # LOCAL_RANK=1 trainer(4,5,6) + ctx = start_processes( + name="trainer", + entrypoint=trainer, + args={0: (1, 2, 3), 1: (4, 5, 6)}, + envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, + log_dir="/tmp/foobar", + redirects=Std.ALL, # write all worker stdout/stderr to a log file + tee={0: Std.ERR}, # tee only local rank 0's stderr to console + ) + + # waits for all copies of trainer to finish + ctx.wait() + +Usage 2: Launching 2 echo workers as a binary + +:: + + # same as invoking + # echo hello + # echo world > stdout.log + ctx = start_processes( + name="echo" + entrypoint="echo", + log_dir="/tmp/foobar", + args={0: "hello", 1: "world"}, + redirects={1: Std.OUT}, + ) + +Just like ``torch.multiprocessing``, the return value of the function +:func:`start_processes` is a process context (:class:`api.PContext`). If a function +was launched, a :class:`api.MultiprocessContext` is returned and if a binary +was launched a :class:`api.SubprocessContext` is returned. 
Both are specific +implementations of the parent :class:`api.PContext` class. +""" + +from typing import Callable, Optional, Union + +from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 + _validate_full_rank, + DefaultLogsSpecs, + LogsDest, + LogsSpecs, + MultiprocessContext, + PContext, + ProcessFailure, + RunProcsResult, + SignalException, + Std, + SubprocessContext, + to_map, +) +from torch.distributed.elastic.utils.logging import get_logger + + +__all__ = [ + "start_processes", + "MultiprocessContext", + "PContext", + "ProcessFailure", + "RunProcsResult", + "SignalException", + "Std", + "LogsDest", + "LogsSpecs", + "DefaultLogsSpecs", + "SubprocessContext", + "to_map", +] + + +def start_processes( + name: str, + entrypoint: Union[Callable, str], + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: Optional[dict[int, str]] = None, + start_method: str = "spawn", +) -> PContext: + """ + Start ``n`` copies of ``entrypoint`` processes with the provided options. + + ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary). + The number of copies is determined by the number of entries for ``args`` and + ``envs`` arguments, which need to have the same key set. + + ``args`` and ``env`` parameters are the arguments and environment variables + to pass down to the entrypoint mapped by the replica index (local rank). + All local ranks must be accounted for. + That is, the keyset should be ``{0,1,...,(nprocs-1)}``. + + .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings. + If any other type is given, then it is casted to a string representation + (e.g. ``str(arg1)``). Furthermore, a binary failure will only write + an ``error.json`` error file if the main function is annotated with + ``torch.distributed.elastic.multiprocessing.errors.record``. For function launches, + this is done by default and there is no need to manually annotate + with the ``@record`` annotation. + + ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect + to a log file in the ``log_dir``. Valid mask values are defined in ``Std``. + To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as + the local rank to specify the redirect behavior for. + Any missing local ranks will default to ``Std.NONE``. + + ``tee`` acts like the unix "tee" command in that it redirects + prints to console. + To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter. + + For each process, the ``log_dir`` will contain: + + #. ``{local_rank}/error.json``: if the process failed, a file with the error info + #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT`` + #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR`` + + .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. 
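    In addition to the examples that follow, a hedged sketch of the same call using the
    ``logs_specs`` parameter from the signature above (the ``trainer`` function and log
    path are illustrative only):

    ::

        from torch.distributed.elastic.multiprocessing import (
            DefaultLogsSpecs,
            Std,
            start_processes,
        )

        def trainer(x):
            return x * 2

        # in a real script, guard this with ``if __name__ == "__main__"`` for the
        # default "spawn" start method
        ctx = start_processes(
            name="trainer",
            entrypoint=trainer,
            args={0: (1,), 1: (2,)},
            envs={0: {}, 1: {}},
            logs_specs=DefaultLogsSpecs(log_dir="/tmp/toy_logs", redirects=Std.ALL),
        )
        result = ctx.wait()  # RunProcsResult; return_values maps local_rank -> value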
+ + Example: + :: + + log_dir = "/tmp/test" + + # ok; two copies of foo: foo("bar0"), foo("bar1") + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # invalid; envs missing for local rank 1 + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}}, + log_dir=log_dir + ) + + # ok; two copies of /usr/bin/touch: touch file1, touch file2 + start_processes( + name="trainer", + entrypoint="/usr/bin/touch", + args:{0:("file1",), 1:("file2",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # caution; arguments casted to string, runs: + # echo "1" "2" "3" and echo "[1, 2, 3]" + start_processes( + name="trainer", + entrypoint="/usr/bin/echo", + args:{0:(1,2,3), 1:([1,2,3],), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + Args: + name: a human readable short name that describes what the processes are + (used as header when tee'ing stdout/stderr outputs) + entrypoint: either a ``Callable`` (function) or ``cmd`` (binary) + args: arguments to each replica + envs: env vars to each replica + log_dir: directory used to write log files + start_method: multiprocessing start method (spawn, fork, forkserver) + ignored for binaries + redirects: which std streams to redirect to a log file + tee: which std streams to redirect + print to console + local_ranks_filter: which ranks' logs to print to console + + """ + + nprocs = len(args) + _validate_full_rank(args, nprocs, "args") + _validate_full_rank(envs, nprocs, "envs") + + context: PContext + if isinstance(entrypoint, str): + context = SubprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + logs_specs=logs_specs, + log_line_prefixes=log_line_prefixes, + ) + else: + context = MultiprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + log_line_prefixes=log_line_prefixes, + start_method=start_method, + logs_specs=logs_specs, + ) + + try: + context.start() + return context + except Exception: + context.close() + raise diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea4adf10479317d810cf30828fa04cc5a43a315d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c271039b400e5d0272b5ae25ac5b270355505f9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75a36dfcc8b7c4aa56cba3f2e185c901e2ea49fb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6990fb7c390b9df2d2f262edee5fe9a871acadb0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/api.py new file mode 100644 index 0000000000000000000000000000000000000000..beb1c794ec15cbaa5d0c97cd0ebdbcf9fa93e752 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/api.py @@ -0,0 +1,926 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import logging +import os +import re +import shutil +import signal +import subprocess +import sys +import tempfile +import threading +import time +from abc import ABC, abstractmethod +from contextlib import nullcontext +from dataclasses import dataclass, field +from enum import IntFlag +from multiprocessing import synchronize +from types import FrameType +from typing import Any, Callable, Optional, Union + +import torch.multiprocessing as mp +from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record +from torch.distributed.elastic.multiprocessing.redirects import ( + redirect_stderr, + redirect_stdout, +) +from torch.distributed.elastic.multiprocessing.subprocess_handler import ( + get_subprocess_handler, + SubprocessHandler, +) +from torch.distributed.elastic.multiprocessing.tail_log import TailLog + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + +__all__ = [ + "DefaultLogsSpecs", + "SignalException", + "Std", + "to_map", + "RunProcsResult", + "PContext", + "get_std_cm", + "MultiprocessContext", + "SubprocessContext", + "LogsDest", + "LogsSpecs", +] + + +class SignalException(Exception): + """ + Exception is raised inside the torchelastic agent process by the termination handler + if the death signal got received by the process. + """ + + def __init__(self, msg: str, sigval: signal.Signals) -> None: + super().__init__(msg) + self.sigval = sigval + + +def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None: + """Termination handler that raises exceptions on the main process. + + When the process receives death signal(SIGTERM, SIGINT), this termination handler will + be invoked. It raises the ``SignalException`` exception that should be processed by the + user code. Python does not terminate process after the termination handler is finished, + so the exception should not be silently ignored, otherwise the process will never + be terminated. + """ + sigval = signal.Signals(signum) + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) + + +def _get_kill_signal() -> signal.Signals: + """Get the kill signal. 
SIGKILL for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGKILL + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str): + actual_keys = set(d.keys()) + expected_keys = set(range(nprocs)) + + if actual_keys != expected_keys: + raise RuntimeError( + f"{what}, local rank mapping mismatch," + f" expected: {expected_keys}, actual: {actual_keys}" + ) + + +_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$" +_VALUE_REGEX = r"^[0123]$" + + +class Std(IntFlag): + NONE = 0 + OUT = 1 + ERR = 2 + ALL = OUT | ERR + + @classmethod + def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]: + """ + Example: + :: + + from_str("0") -> Std.NONE + from_str("1") -> Std.OUT + from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR} + + Any other input raises an exception + """ + + def to_std(v: str) -> Std: # type: ignore[return] + s = Std(int(v)) + if s in Std: + return s + # return None -> should NEVER reach here since we regex check input + + if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0) + return to_std(vm) + elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2) + d: dict[int, Std] = {} + for m in vm.split(","): + i, v = m.split(":") + d[int(i)] = to_std(v) + return d + else: + raise ValueError( + f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>" + ) + + +def to_map( + val_or_map: Union[Std, dict[int, Std]], local_world_size: int +) -> dict[int, Std]: + """ + Certain APIs take redirect settings either as a single value (e.g. apply to all + local ranks) or as an explicit user-provided mapping. This method is a convenience + method that converts a value or mapping into a mapping. + + Example: + :: + + to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} + to_map( + {0: Std.OUT, 1: Std.OUT}, local_world_size=2 + ) # returns: {0: Std.OUT, 1: Std.OUT} + """ + if isinstance(val_or_map, Std): + return dict.fromkeys(range(local_world_size), val_or_map) + else: + map = {} + for i in range(local_world_size): + map[i] = val_or_map.get(i, Std.NONE) + return map + + +@dataclass +class LogsDest: + """ + For each log type, holds mapping of local rank ids to file paths. + """ + + stdouts: dict[int, str] = field(default_factory=dict) + stderrs: dict[int, str] = field(default_factory=dict) + tee_stdouts: dict[int, str] = field(default_factory=dict) + tee_stderrs: dict[int, str] = field(default_factory=dict) + error_files: dict[int, str] = field(default_factory=dict) + + +class LogsSpecs(ABC): + """ + Defines logs processing and redirection for each worker process. + + Args: + log_dir: + Base directory where logs will be written. + redirects: + Streams to redirect to files. Pass a single ``Std`` + enum to redirect for all workers, or a mapping keyed + by local_rank to selectively redirect. + tee: + Streams to duplicate to stdout/stderr. + Pass a single ``Std`` enum to duplicate streams for all workers, + or a mapping keyed by local_rank to selectively duplicate. 
+ """ + + def __init__( + self, + log_dir: Optional[str] = None, + redirects: Union[Std, dict[int, Std]] = Std.NONE, + tee: Union[Std, dict[int, Std]] = Std.NONE, + local_ranks_filter: Optional[set[int]] = None, + ) -> None: + self._root_log_dir = log_dir + self._redirects = redirects + self._tee = tee + self._local_ranks_filter = local_ranks_filter + + @abstractmethod + def reify( + self, + envs: dict[int, dict[str, str]], + ) -> LogsDest: + """ + Given the environment variables, builds destination of log files for each of the local ranks. + + Envs parameter contains env variables dict for each of the local ranks, where entries are defined in: + :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`. + """ + + @property + @abstractmethod + def root_log_dir(self) -> str: + pass + + +class DefaultLogsSpecs(LogsSpecs): + """ + Default LogsSpecs implementation: + + - `log_dir` will be created if it doesn't exist + - Generates nested folders for each attempt and rank. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + redirects: Union[Std, dict[int, Std]] = Std.NONE, + tee: Union[Std, dict[int, Std]] = Std.NONE, + local_ranks_filter: Optional[set[int]] = None, + ) -> None: + if log_dir != os.devnull: + if not log_dir: + log_dir = tempfile.mkdtemp(prefix="torchelastic_") + elif not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + else: + if os.path.isfile(log_dir): + raise NotADirectoryError(f"log_dir: {log_dir} is a file") + super().__init__(log_dir, redirects, tee, local_ranks_filter) + # initialized only once + self._run_log_dir = None + + @property + def root_log_dir(self) -> str: + return str(self._root_log_dir) + + def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str): + base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_") + os.makedirs(base_log_dir, exist_ok=True) + dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir) + logger.info("log directory set to: %s", dir) + return dir + + def reify( + self, + envs: dict[int, dict[str, str]], + ) -> LogsDest: + """ + Uses following scheme to build log destination paths: + + - `//attempt_//stdout.log` + - `//attempt_//stderr.log` + - `//attempt_//error.json` + """ + nprocs = len(envs) + global_env = {} # use only to query properties that are not dependent on a rank + if nprocs > 0: + global_env = envs[0] + else: + logger.warning( + "Empty envs map provided when defining logging destinations." + ) + # Keys are always defined, but values can be missing in unit tests + run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") + restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") + + attempt_log_dir: str = "" + if self._root_log_dir != os.devnull: + if not self._run_log_dir: + self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) + + attempt_log_dir = os.path.join( + self._run_log_dir, f"attempt_{restart_count}" + ) # type: ignore[call-overload] + shutil.rmtree(attempt_log_dir, ignore_errors=True) + os.makedirs(attempt_log_dir) + + if self._root_log_dir == os.devnull: + attempt_log_dir = os.devnull + + # create subdirs for each local rank in the logs_dir + # logs_dir + # |- 0 + # |- error.json + # |- stdout.log + # |- stderr.log + # |- ... 
+ # |- (nprocs-1) + redirs = to_map(self._redirects, nprocs) + ts = to_map(self._tee, nprocs) + + # to tee stdout/stderr we first redirect into a file + # then tail -f stdout.log/stderr.log so add tee settings to redirects + for local_rank, tee_std in ts.items(): + redirect_std = redirs[local_rank] + redirs[local_rank] = redirect_std | tee_std + + SYS_STREAM = "" # special case to indicate to output to console + stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) + stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) + tee_stdouts: dict[int, str] = {} + tee_stderrs: dict[int, str] = {} + error_files = {} + + for local_rank in range(nprocs): + if attempt_log_dir == os.devnull: + tee_stdouts[local_rank] = os.devnull + tee_stderrs[local_rank] = os.devnull + error_files[local_rank] = os.devnull + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = "" + else: + clogdir = os.path.join(attempt_log_dir, str(local_rank)) + os.mkdir(clogdir) + + rd = redirs[local_rank] + if (rd & Std.OUT) == Std.OUT: + stdouts[local_rank] = os.path.join(clogdir, "stdout.log") + if (rd & Std.ERR) == Std.ERR: + stderrs[local_rank] = os.path.join(clogdir, "stderr.log") + + t = ts[local_rank] + if t & Std.OUT == Std.OUT: + tee_stdouts[local_rank] = stdouts[local_rank] + if t & Std.ERR == Std.ERR: + tee_stderrs[local_rank] = stderrs[local_rank] + + if ( + self._local_ranks_filter + and local_rank not in self._local_ranks_filter + ): + # If stream is tee'd, only write to file, but don't tail + if local_rank in tee_stdouts: + tee_stdouts.pop(local_rank, None) + if local_rank in tee_stderrs: + tee_stderrs.pop(local_rank, None) + + # If stream is not redirected, don't print + if stdouts[local_rank] == SYS_STREAM: + stdouts[local_rank] = os.devnull + if stderrs[local_rank] == SYS_STREAM: + stderrs[local_rank] = os.devnull + + error_file = os.path.join(clogdir, "error.json") + error_files[local_rank] = error_file + logger.info( + "Setting worker%s reply file to: %s", local_rank, error_file + ) + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file + + return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files) + + def __repr__(self) -> str: + return ( + f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, " + f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DefaultLogsSpecs): + return False + + return ( + self._root_log_dir == other._root_log_dir + and self._redirects == other._redirects + and self._tee == other._tee + and self._local_ranks_filter == other._local_ranks_filter + ) + + +@dataclass +class RunProcsResult: + """ + Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``. + + Note the following: + + 1. All fields are mapped by local rank + 2. ``return_values`` - only populated for functions (not the binaries). + 3. ``stdouts`` - path to stdout.log (empty string if no redirect) + 4. ``stderrs`` - path to stderr.log (empty string if no redirect) + + """ + + return_values: dict[int, Any] = field(default_factory=dict) + failures: dict[int, ProcessFailure] = field(default_factory=dict) + stdouts: dict[int, str] = field(default_factory=dict) + stderrs: dict[int, str] = field(default_factory=dict) + + def is_failed(self) -> bool: + return len(self.failures) > 0 + + +class PContext(abc.ABC): + """ + The base class that standardizes operations over a set of processes that are launched via different mechanisms. 
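    To make the redirect plumbing above concrete, a small hypothetical illustration of
    how ``Std.from_str`` and ``to_map`` normalize user-supplied settings (the values are
    chosen arbitrarily):

    ::

        from torch.distributed.elastic.multiprocessing.api import Std, to_map

        assert Std.from_str("3") == Std.ALL
        assert Std.from_str("0:1,1:2") == {0: Std.OUT, 1: Std.ERR}

        # a single value fans out to every local rank; ranks missing from a
        # mapping default to Std.NONE
        assert to_map(Std.OUT, local_world_size=2) == {0: Std.OUT, 1: Std.OUT}
        assert to_map({1: Std.ERR}, local_world_size=2) == {0: Std.NONE, 1: Std.ERR}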
+ + The name ``PContext`` is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``. + + .. warning:: stdouts and stderrs should ALWAYS be a superset of + tee_stdouts and tee_stderrs (respectively) this is b/c + tee is implemented as a redirect + tail -f + """ + + def __init__( + self, + name: str, + entrypoint: Union[Callable, str], + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: Optional[dict[int, str]] = None, + ): + self.name = name + # validate that all mappings have the same number of keys and + # all local ranks are accounted for + nprocs = len(args) + + # TODO log_line_prefixes can be expanded too + logs_dest = logs_specs.reify(envs) + + _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts") + _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs") + + self.entrypoint = entrypoint + self.args = args + self.envs = envs + self.stdouts = logs_dest.stdouts + self.stderrs = logs_dest.stderrs + self.error_files = logs_dest.error_files + self.nprocs = nprocs + + self._stdout_tail = TailLog( + name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes + ) + self._stderr_tail = TailLog( + name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes + ) + + def start(self) -> None: + """Start processes using parameters defined in the constructor.""" + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGTERM, _terminate_process_handler) + signal.signal(signal.SIGINT, _terminate_process_handler) + if not IS_WINDOWS: + signal.signal(signal.SIGHUP, _terminate_process_handler) + signal.signal(signal.SIGQUIT, _terminate_process_handler) + else: + logger.warning( + "Failed to register signal handlers since torchelastic is running on a child thread. " + "This could lead to orphaned worker processes if the torchrun is terminated." + ) + self._start() + self._stdout_tail.start() + self._stderr_tail.start() + + @abc.abstractmethod + def _start(self) -> None: + """Start processes using strategy defined in a particular context.""" + raise NotImplementedError + + @abc.abstractmethod + def _poll(self) -> Optional[RunProcsResult]: + """ + Poll the run status of the processes running under this context. + This method follows an "all-or-nothing" policy and returns + a ``RunProcessResults`` object if either all processes complete + successfully or any process fails. Returns ``None`` if + all processes are still running. + """ + raise NotImplementedError + + def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]: + """ + Wait for the specified ``timeout`` seconds, polling every ``period`` seconds + for the processes to be done. Returns ``None`` if the processes are still running + on timeout expiry. Negative timeout values are interpreted as "wait-forever". + A timeout value of zero simply queries the status of the processes (e.g. equivalent + to a poll). + + .. note:: + Multiprocessing library registers SIGTERM and SIGINT signal handlers that raise + ``SignalException`` when the signals received. It is up to the consumer of the code + to properly handle the exception. It is important not to swallow the exception otherwise + the process would not terminate. Example of the typical workflow can be: + + .. code-block:: python + pc = start_processes(...) + try: + pc.wait(1) + .. 
do some other work + except SignalException as e: + pc.shutdown(e.sigval, timeout=30) + + If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating + received signal. If child processes will not terminate in the timeout time, the process will send + the SIGKILL. + """ + if timeout == 0: + return self._poll() + + if timeout < 0: + timeout = sys.maxsize + + expiry = time.time() + timeout + while time.time() < expiry: + pr = self._poll() + if pr: + return pr + time.sleep(period) + + return None + + @abc.abstractmethod + def pids(self) -> dict[int, int]: + """Return pids of processes mapped by their respective local_ranks.""" + raise NotImplementedError + + @abc.abstractmethod + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + """ + raise NotImplementedError + + def close( + self, death_sig: Optional[signal.Signals] = None, timeout: int = 30 + ) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + + Args: + death_sig: Death signal to terminate processes. + timeout: Time to wait for processes to finish, if process is + still alive after this time, it will be terminated via SIGKILL. + """ + if not death_sig: + death_sig = _get_default_signal() + self._close(death_sig=death_sig, timeout=timeout) + if self._stdout_tail: + self._stdout_tail.stop() + if self._stderr_tail: + self._stderr_tail.stop() + + +def get_std_cm(std_rd: str, redirect_fn): + if IS_WINDOWS or IS_MACOS or not std_rd: + return nullcontext() + else: + return redirect_fn(std_rd) + + +def _wrap( + local_rank: int, + fn: Callable, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + stdout_redirects: dict[int, str], # redirect file for stdout (to console if None) + stderr_redirects: dict[int, str], # redirect file for stderr (to console if None) + ret_vals: dict[int, mp.SimpleQueue], + queue_finished_reading_event: synchronize.Event, +) -> None: + # get the per-rank params up front so we fail fast if no mapping is found + args_ = args[local_rank] + env_ = envs[local_rank] + ret_val_ = ret_vals[local_rank] + + stdout_rd = stdout_redirects[local_rank] + stderr_rd = stderr_redirects[local_rank] + + stdout_cm = get_std_cm(stdout_rd, redirect_stdout) + stderr_cm = get_std_cm(stderr_rd, redirect_stderr) + + for k, v in env_.items(): + os.environ[k] = v + + with stdout_cm, stderr_cm: + ret = record(fn)(*args_) + ret_val_.put(ret) + queue_finished_reading_event.wait() + + +class MultiprocessContext(PContext): + """``PContext`` holding worker processes invoked as a function.""" + + def __init__( + self, + name: str, + entrypoint: Callable, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + start_method: str, + logs_specs: LogsSpecs, + log_line_prefixes: Optional[dict[int, str]] = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + ) + + self.start_method = start_method + # each ret_val queue will always contain a single element. + self._ret_vals = { + local_rank: mp.get_context(self.start_method).SimpleQueue() + for local_rank in range(self.nprocs) + } + + # see comments in ``join()`` for what this is + self._return_values: dict[int, Any] = {} + self._pc: Optional[mp.ProcessContext] = None + # Note: set method should ONLY be invoked for the use case when all processes finished + # successfully. 
If any process died on event.wait() calling set() method will deadlock. + self._worker_finished_event = mp.get_context(self.start_method).Event() + + def _start(self): + if self._pc: + raise ValueError( + "The process context already initialized." + " Most likely the start method got called twice." + ) + self._pc = mp.start_processes( + fn=_wrap, + args=( + self.entrypoint, + self.args, + self.envs, + self.stdouts, + self.stderrs, + self._ret_vals, + self._worker_finished_event, + ), + nprocs=self.nprocs, + join=False, + daemon=False, + start_method=self.start_method, + ) + + def _is_done(self) -> bool: + return len(self._return_values) == self.nprocs + + def _poll(self) -> Optional[RunProcsResult]: + assert self._pc is not None # assertion for mypy type checker + + try: + # torch.mp.ProcessContext Throws an Exception if some/all of + # worker processes failed + # timeout < 0 checks worker status and return immediately + # Join will never return success since we use synchronize.Event to wait + # for all processes to finish. + self._pc.join(-1) + + # IMPORTANT: we use multiprocessing.Queue to carry worker return values + # back to the parent, the worker process will wait before terminating + # until all the buffered items are fed by the feeder thread to the underlying + # pipe. Hence to prevent deadlocks on large return values, + # we opportunistically try queue.get on each join call + # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms + for local_rank in range(0, self.nprocs): + return_queue = self._ret_vals[local_rank] + if not return_queue.empty(): + # save the return values temporarily into a member var + self._return_values[local_rank] = return_queue.get() + + if self._is_done(): + # we should ALWAYS have ALL the return values when all the processes are done + self._worker_finished_event.set() + + # At this point workers finished running the user function + # But the child process might still have not exited. Wait for them. + # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. + while not self._pc.join(): + logger.debug( + "entrypoint fn finished, waiting for all child procs to exit..." 
+ ) + + _validate_full_rank( + self._return_values, self.nprocs, "return_value queue" + ) + self.close() + return RunProcsResult( + return_values=self._return_values, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + else: + return None + except (mp.ProcessRaisedException, mp.ProcessExitedException) as e: + failed_local_rank = e.error_index + + # entrypoint for MultiprocessContext will always be a Callable + fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr] + failed_proc = self._pc.processes[failed_local_rank] + error_filepath = self.error_files[failed_local_rank] + + logger.exception( + "failed (exitcode: %s)" + " local_rank: %s (pid: %s)" + " of fn: %s (start_method: %s)", + failed_proc.exitcode, + failed_local_rank, + e.pid, + fn_name, + self.start_method, + ) + + self.close() + return RunProcsResult( + failures={ + failed_local_rank: ProcessFailure( + local_rank=failed_local_rank, + pid=e.pid, + exitcode=failed_proc.exitcode, + error_file=error_filepath, + ) + }, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + + def pids(self) -> dict[int, int]: + assert self._pc is not None # assertion for mypy type checking + return dict(enumerate(self._pc.pids())) + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self._pc: + return + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Closing process %s via signal %s", proc.pid, death_sig.name + ) + try: + os.kill(proc.pid, death_sig) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + end = time.monotonic() + timeout + for proc in self._pc.processes: + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + proc.join(time_to_wait) + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + proc.pid, + death_sig, + _get_kill_signal(), + ) + try: + os.kill(proc.pid, _get_kill_signal()) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + proc.join() + + +class SubprocessContext(PContext): + """``PContext`` holding worker processes invoked as a binary.""" + + def __init__( + self, + name: str, + entrypoint: str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: Optional[dict[int, str]] = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + ) + + # state vector; _vdone[local_rank] -> is local_rank finished or not + self._running_local_ranks: set[int] = set(range(self.nprocs)) + self._failures: dict[int, ProcessFailure] = {} + self.subprocess_handlers: dict[int, SubprocessHandler] = {} + + def _start(self): + if self.subprocess_handlers: + raise ValueError( + "The subprocess handlers already initialized. Most likely the start method got called twice." 
+ ) + self.subprocess_handlers = { + local_rank: get_subprocess_handler( + entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str + args=self.args[local_rank], + env=self.envs[local_rank], + stdout=self.stdouts[local_rank], + stderr=self.stderrs[local_rank], + local_rank_id=local_rank, + ) + for local_rank in range(self.nprocs) + } + + def _poll(self) -> Optional[RunProcsResult]: + done_local_ranks = set() + for local_rank in self._running_local_ranks: + handler = self.subprocess_handlers[local_rank] + exitcode = handler.proc.poll() + if exitcode is not None: + done_local_ranks.add(local_rank) + if exitcode != 0: # failed or signaled + self._failures[local_rank] = ProcessFailure( + local_rank=local_rank, + pid=handler.proc.pid, + exitcode=exitcode, + error_file=self.error_files[local_rank], + ) + # else: --> succeeded; nothing to do + + self._running_local_ranks.difference_update(done_local_ranks) + + # if ALL procs are finished or ANY have failed + if not self._running_local_ranks or self._failures: + self.close() # terminate all running procs + result = RunProcsResult( + failures=self._failures, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + if result.is_failed(): + first_failure = min(result.failures.values(), key=lambda f: f.timestamp) + logger.error( + "failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s", + first_failure.exitcode, + first_failure.local_rank, + first_failure.pid, + self.entrypoint, + ) + else: + # Populate return with dummy values. This provides consistency with MultiprocessingHandler + result.return_values = dict.fromkeys(range(self.nprocs)) + + return result + else: # there are no failures and procs still running + return None + + def pids(self) -> dict[int, int]: + return { + local_rank: sh.proc.pid + for local_rank, sh in self.subprocess_handlers.items() + } + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self.subprocess_handlers: + return + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Sending process %s closing signal %s", + handler.proc.pid, + death_sig.name, + ) + handler.close(death_sig=death_sig) + end = time.monotonic() + timeout + for handler in self.subprocess_handlers.values(): + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + try: + handler.proc.wait(time_to_wait) + except subprocess.TimeoutExpired: + # Ignore the timeout expired exception, since + # the child process will be forcefully terminated via SIGKILL + pass + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + handler.proc.pid, + death_sig, + _get_kill_signal(), + ) + handler.close(death_sig=_get_kill_signal()) + handler.proc.wait() diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..acb63ef2e64a09a1bb2fcf16389ea8c08e179ea2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
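Rounding out the module above, a hedged sketch of the binary code path (a ``str`` entrypoint yields a ``SubprocessContext``); the command and log directory are illustrative:

::

    from torch.distributed.elastic.multiprocessing import (
        DefaultLogsSpecs,
        SignalException,
        Std,
        start_processes,
    )

    ctx = start_processes(
        name="echo",
        entrypoint="/bin/echo",
        args={0: ("hello",), 1: ("world",)},   # binary args must be strings
        envs={0: {}, 1: {}},
        logs_specs=DefaultLogsSpecs(log_dir="/tmp/echo_logs", redirects=Std.OUT),
    )
    try:
        result = ctx.wait(timeout=30)  # None if workers are still running after 30s
    except SignalException as e:
        ctx.close(e.sigval)            # propagate SIGTERM/SIGINT to the children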
+ +""" +Each host in a distributed PyTorch job runs with a single TorchElastic agent, +and multiple workers (as children processes of the TorchElastic agent). +Since the workers are user-provided (your PyTorch script/job), TorchElastic +has a way to propagate errors on the trainers through the agent and up to the +scheduler, which ultimately informs the end-user about the state of the job +and applies any retry policies. + +TorchElastic categorizes errors into 3 categories: + ++----------------+----------------+--------------------------------------------------------------+ +| Category | Sub-Category | Description | ++================+================+==============================================================+ +| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) | +| +----------------+--------------------------------------------------------------+ +| | Worker Failure | any failures on the worker child process | ++----------------+----------------+--------------------------------------------------------------+ +| Platform Error | n/a | failures caused by the agent | ++----------------+----------------+--------------------------------------------------------------+ +| Infra Error | n/a | failures outside the domain of the agent and workers | +| | | (e.g. host failures) | ++----------------+----------------+--------------------------------------------------------------+ + +All errors other than "Worker Failure" are either raised canonically from the +agent process or implicitly or explicitly crash the agent process. So the +standard language (python) provided exception handling strategies apply. + +Worker Failures are special because the exception/failure originates on a different +process from the agent so the error needs to be propagated inter-process +(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process). + +TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes` +to launch the workers which has a simple file based inter-process error propagation +built-in. + +Any function or binary entrypoint decorated with :func:`record` +will write uncaught exceptions (with the trace information) to a file specified by the +environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent) +sets this env var on each child it launches, then aggregates the error files for all +children, and propagates the one with the **smallest** timestamp (e.g. the **first** error). +""" + +import json +import os +import signal +import socket +import time +from dataclasses import dataclass, field +from datetime import datetime +from functools import wraps +from string import Template +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +from torch.distributed.elastic.utils.logging import get_logger + +from .error_handler import ErrorHandler # noqa: F401 +from .handlers import get_error_handler # noqa: F401 + + +__all__ = [ + "ProcessFailure", + "ChildFailedError", + "record", + "ErrorHandler", + "get_error_handler", +] + +logger = get_logger(__name__) + + +JSON = dict + +_EMPTY_ERROR_DATA = {"message": ""} +_NOT_AVAILABLE = "" + +_R = TypeVar("_R") +_P = ParamSpec("_P") + + +@dataclass +class ProcessFailure: + """ + Represent the failed process result. When the worker process fails, it may record failure root cause into the file. 
+ + Tries to read the failure timestamp from the provided ``error_file``, + if the ``error_file`` does not exist, the timestamp is the current + timestamp (seconds since epoch). + + The ``message`` field is a concise explanation of the failure. If + the error file exists then the message is obtained from the error file. + Otherwise one is generated based on the failure signature. + + .. note:: It is assumed that the ``error_file`` is written by + ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``. + Otherwise the behavior is undefined. + + """ + + local_rank: int + pid: int + exitcode: int + error_file: str + error_file_data: JSON = field(init=False) + message: str = field(init=False) + timestamp: int = field(init=False) + + def __post_init__(self): + self.error_file_data = _EMPTY_ERROR_DATA + if os.path.isfile(self.error_file): + try: + with open(self.error_file) as fp: + self.error_file_data = json.load(fp) + logger.debug( + "User process failed with error data: %s", + json.dumps(self.error_file_data, indent=2), + ) + self.message, self.timestamp = self._get_error_data( + self.error_file_data + ) + except Exception: + logger.exception("Failed to parse reply file: %s", self.error_file) + raise + else: + self._set_no_reply_file() + + # make up an informative message if not already present + if not self.message: + # signals typically do not generate an error file message + if self.exitcode < 0: + self.message = ( + f"Signal {-self.exitcode} ({self.signal_name()})" + f" received by PID {self.pid}" + ) + else: + self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html" + + def _get_error_data(self, error_file_data: dict[str, Any]) -> tuple[str, int]: + message = error_file_data["message"] + if isinstance(message, str): + timestamp = int(error_file_data.get("timestamp", 0)) + else: + timestamp = int(message["extraInfo"]["timestamp"]) + return (message, timestamp) + + def _set_no_reply_file(self): + self.error_file = _NOT_AVAILABLE + self.error_file_data = _EMPTY_ERROR_DATA + self.message = "" + self.timestamp = int(time.time()) + + def signal_name(self) -> str: + if self.exitcode < 0: + # We don't want to kill the parent process trying to find the signal name. + # if the signal doesn't map to a known name, use not available. + try: + return signal.Signals(-self.exitcode).name + except Exception: + return _NOT_AVAILABLE + else: + return _NOT_AVAILABLE + + def timestamp_isoformat(self): + """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS).""" + return datetime.fromtimestamp(self.timestamp).isoformat(sep="_") + + +GlobalRank = int + +_FAILURE_FORMAT_TEMPLATE = """[${idx}]: + time : ${time} + host : ${hostname} + rank : ${rank} (local_rank: ${local_rank}) + exitcode : ${exitcode} (pid: ${pid}) + error_file: ${error_file} + traceback : ${message}""" + +# extra new lines before and after are intentional +_MSG_FORMAT_TEMPLATE = """ +${boarder} +${title} +${section} +Failures: +${other_failures} +${section} +Root Cause (first observed failure): +${root_failure} +${boarder}""" + + +class ChildFailedError(Exception): + """ + Special exception type that can be raised from a function annotated with the + ``@record`` decorator to have the child process' (root exception) propagate + up the stack as-is (e.g. without being wrapped in the parent's traceback). + + Useful in cases where the parent is a simple nanny process + and the child (worker) processes are actually doing meaningful compute. 
+ In this case, errors typically occur on the child process as the parent + is not doing anything non-trivial, and child errors should be propagated + to the scheduler for accurate root cause diagnostics. + + .. note:: The propagation relies on error files rather than exception handling to + support both function and binary launches. + + Example: + :: + + # process tree on a host (container) + 0: scheduler-init-process: + |- 1: torchelastic_agent: + |- 2: trainer_0 (ok) + |- 3: trainer_1 (fail) -> error.json + |- ... + |- n+2: trainer_n (ok) + |- n+3: other processes + |- ... + + In the example above, trainer 1's failure (written into error.json) is + the root cause and should be reported to the scheduler's init process. + The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})`` + upon detecting trainer 1's failure which would propagate the contents + of trainer 1's error file to the scheduler's init process. + """ + + def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]): + self.name = name + self.failures = failures + assert ( + self.failures + ) # does not make sense to create a ChildFaileError with no failures + super().__init__(self.format_msg()) + + def get_first_failure(self) -> tuple[GlobalRank, ProcessFailure]: + rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) + return rank, self.failures[rank] + + def format_msg(self, boarder_delim="=", section_delim="-"): + title = f"{self.name} FAILED" + root_rank, _root_failure = self.get_first_failure() + + root_failure_fmt: str = "" + other_failures_fmt: list[str] = [] + width = len(title) + for idx, (rank, failure) in enumerate(self.failures.items()): + fmt, w = self._format_failure(idx, rank, failure) + width = max(width, w) + if rank == root_rank: + root_failure_fmt = fmt + else: + other_failures_fmt.append(fmt) + + # upper boundary on width + width = min(width, 60) + + return Template(_MSG_FORMAT_TEMPLATE).substitute( + boarder=boarder_delim * width, + title=title, + section=section_delim * width, + root_failure=root_failure_fmt, + other_failures="\n".join(other_failures_fmt or [" "]), + ) + + def _format_failure( + self, idx: int, rank: int, failure: ProcessFailure + ) -> tuple[str, int]: + # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) + # or a dict (json) of the form + # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} + # so the display logic is: + # 1. if failure.message is not a dict (it is a str) just show it as is + # 2. else try to get the traceback (py_callstack) + # 3. if the traceback is not there, use the message + # 4. 
if the message is not there show + msg = failure.message + if isinstance(failure.message, dict): + msg = ( + failure.message.get("extraInfo", {}) + .get("py_callstack", failure.message.get("message", "")) + .replace("\n", "\n ") # to properly indent the traceback + ) + + fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( + idx=idx, + time=failure.timestamp_isoformat(), + hostname=socket.getfqdn(), + rank=rank, + local_rank=failure.local_rank, + exitcode=failure.exitcode, + pid=failure.pid, + error_file=failure.error_file, + message=msg, + ) + width = 0 + for line in fmt.split("\n"): + width = max(width, len(line)) + return fmt, width + + +def record( + fn: Callable[_P, _R], error_handler: Optional[ErrorHandler] = None +) -> Callable[_P, Union[_R, None]]: + """ + Syntactic sugar to record errors/exceptions that happened in the decorated + function using the provided ``error_handler``. + + Using this decorator is equivalent to: + + :: + + error_handler = get_error_handler() + error_handler.initialize() + try: + foobar() + except ChildFailedError as e: + _, failure = e.get_first_failure() + error_handler.dump_error_file(failure.error_file, failure.exitcode) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + .. important:: use this decorator once per process at the top level method, + typically this is the main method. + + Example + + :: + + @record + def main(): + pass + + + if __name__ == "__main__": + main() + + """ + if not error_handler: + error_handler = get_error_handler() + + def wrap(f: Callable[_P, _R]) -> Callable[_P, Union[_R, None]]: + @wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs): + assert error_handler is not None # assertion for mypy type checker + error_handler.initialize() + try: + return f(*args, **kwargs) + except SystemExit as se: + # For run_path based entrypoints, SystemExit with code = 0 will never exit. + # Handling it here by returning a value: + if se.code == 0: + return None + else: + raise + except ChildFailedError as e: + rank, failure = e.get_first_failure() + if failure.error_file != _NOT_AVAILABLE: + error_handler.dump_error_file(failure.error_file, failure.exitcode) + else: + logger.info( + ( + "local_rank %s FAILED with no error file." + " Decorate your entrypoint fn with @record for traceback info." 
+ " See: https://pytorch.org/docs/stable/elastic/errors.html", + rank, + ) + ) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + return wrapper + + return wrap(fn) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f22711bae23996547e3b74c2a58a9a08935d30 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb9cb25f829c880f4ce313396ca67984dc3c884c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06140179417e2618b3317d40c78a737358b6db5e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..cd15a7e61e113bc5f3d5f7ff1b4cd8daa3c5b10d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import faulthandler +import json +import logging +import os +import time +import traceback +import warnings +from typing import Any, Optional + + +__all__ = ["ErrorHandler"] + +logger = logging.getLogger(__name__) + + +class ErrorHandler: + """ + Write the provided exception object along with some other metadata about + the error in a structured way in JSON format to an error file specified by the + environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment + variable is not set, then simply logs the contents of what would have been + written to the error file. + + This handler may be subclassed to customize the handling of the error. + Subclasses should override ``initialize()`` and ``record_exception()``. + """ + + def _get_error_file_path(self) -> Optional[str]: + """ + Return the error file path. + + May return ``None`` to have the structured error be logged only. + """ + return os.environ.get("TORCHELASTIC_ERROR_FILE", None) + + def initialize(self) -> None: + """ + Call prior to running code that we wish to capture errors/exceptions. + + Typically registers signal/fault handlers. 
Users can override this + function to add custom initialization/registrations that aid in + propagation/information of errors/signals/exceptions/faults. + """ + try: + faulthandler.enable(all_threads=True) + except Exception as e: + warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}") + + def _write_error_file(self, file_path: str, error_msg: str) -> None: + """Write error message to the file.""" + try: + with open(file_path, "w") as fp: + fp.write(error_msg) + except Exception as e: + warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}") + + def record_exception(self, e: BaseException) -> None: + """ + Write a structured information about the exception into an error file in JSON format. + + If the error file cannot be determined, then logs the content + that would have been written to the error file. + """ + file = self._get_error_file_path() + if file: + data = { + "message": { + "message": f"{type(e).__name__}: {e}", + "extraInfo": { + "py_callstack": traceback.format_exc(), + "timestamp": str(int(time.time())), + }, + } + } + with open(file, "w") as fp: + json.dump(data, fp) + + def override_error_code_in_rootcause_data( + self, + rootcause_error_file: str, + rootcause_error: dict[str, Any], + error_code: int = 0, + ): + """Modify the rootcause_error read from the file, to correctly set the exit code.""" + if "message" not in rootcause_error: + logger.warning( + "child error file (%s) does not have field `message`. \n" + "cannot override error code: %s", + rootcause_error_file, + error_code, + ) + elif isinstance(rootcause_error["message"], str): + logger.warning( + "child error file (%s) has a new message format. \n" + "skipping error code override", + rootcause_error_file, + ) + else: + rootcause_error["message"]["errorCode"] = error_code + + def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): + """Dump parent error file from child process's root cause error and error code.""" + with open(rootcause_error_file) as fp: + rootcause_error = json.load(fp) + # Override error code since the child process cannot capture the error code if it + # is terminated by signals like SIGSEGV. + if error_code: + self.override_error_code_in_rootcause_data( + rootcause_error_file, rootcause_error, error_code + ) + logger.debug( + "child error file (%s) contents:\n%s", + rootcause_error_file, + json.dumps(rootcause_error, indent=2), + ) + + my_error_file = self._get_error_file_path() + if my_error_file: + # Guard against existing error files + # This can happen when the child is created using multiprocessing + # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the + # parent and child to specify the error files (respectively) + # because the env vars on the child is set in the wrapper function + # and by default the child inherits the parent's env vars, if the child + # process receives a signal before the wrapper function kicks in + # and the signal handler writes to the error file, then the child + # will write to the parent's error file. In this case just log the + # original error file contents and overwrite the error file. + self._rm(my_error_file) + self._write_error_file(my_error_file, json.dumps(rootcause_error)) + logger.info("dumped error file to parent's %s", my_error_file) + else: + logger.error( + "no error file defined for parent, to copy child error file (%s)", + rootcause_error_file, + ) + + def _rm(self, my_error_file): + if os.path.isfile(my_error_file): + # Log the contents of the original file. 
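            # For reference, a pre-existing file here was typically written by
            # record_exception() above, so its contents have roughly this
            # (illustrative) shape:
            #
            #   {"message": {"message": "ValueError: bad input",
            #                "extraInfo": {"py_callstack": "Traceback ...",
            #                              "timestamp": "1700000000"}}}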
+ with open(my_error_file) as fp: + try: + original = json.dumps(json.load(fp), indent=2) + logger.warning( + "%s already exists" + " and will be overwritten." + " Original contents:\n%s", + my_error_file, + original, + ) + except json.decoder.JSONDecodeError: + logger.warning( + "%s already exists" + " and will be overwritten." + " Unable to load original contents:\n", + my_error_file, + ) + os.remove(my_error_file) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..7d582ae50fd9312154f50d3f760634b600c11fbc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# Multiprocessing error-reporting module + + +from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler + + +__all__ = ["get_error_handler"] + + +def get_error_handler() -> ErrorHandler: + return ErrorHandler() diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/redirects.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/redirects.py new file mode 100644 index 0000000000000000000000000000000000000000..4553fbebdec8278328d205cccf4ce749abf08926 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/redirects.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +# !/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Taken and modified from original source: +# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ +import ctypes +import logging +import os +import sys +from contextlib import contextmanager +from functools import partial + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + + +def get_libc(): + if IS_WINDOWS or IS_MACOS: + logger.warning( + "NOTE: Redirects are currently not supported in Windows or MacOs." + ) + return None + else: + return ctypes.CDLL("libc.so.6") + + +libc = get_libc() + + +def _c_std(stream: str): + return ctypes.c_void_p.in_dll(libc, stream) + + +def _python_std(stream: str): + return {"stdout": sys.stdout, "stderr": sys.stderr}[stream] + + +_VALID_STD = {"stdout", "stderr"} + + +@contextmanager +def redirect(std: str, to_file: str): + """ + Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. + + This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). + See usage for details. + + Directory of ``dst_filename`` is assumed to exist and the destination file + is overwritten if it already exists. + + .. note:: Due to buffering cross source writes are not guaranteed to + appear in wall-clock order. For instance in the example below + it is possible for the C-outputs to appear before the python + outputs in the log file. 
+ + Usage: + + :: + + # syntactic-sugar for redirect("stdout", "tmp/stdout.log") + with redirect_stdout("/tmp/stdout.log"): + print("python stdouts are redirected") + libc = ctypes.CDLL("libc.so.6") + libc.printf(b"c stdouts are also redirected" + os.system("echo system stdouts are also redirected") + + print("stdout restored") + + """ + if std not in _VALID_STD: + raise ValueError( + f"unknown standard stream <{std}>, must be one of {_VALID_STD}" + ) + + c_std = _c_std(std) + python_std = _python_std(std) + std_fd = python_std.fileno() + + def _redirect(dst): + libc.fflush(c_std) + python_std.flush() + os.dup2(dst.fileno(), std_fd) + + with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: + _redirect(dst) + try: + yield + finally: + _redirect(orig_std) + + +redirect_stdout = partial(redirect, "stdout") +redirect_stderr = partial(redirect, "stderr") diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21322118e77c9dc6891b3ffaae464e092ef5fbe4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.elastic.multiprocessing.subprocess_handler.handlers import ( + get_subprocess_handler, +) +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + + +__all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd5567d171d15ac22e5c255f10caa956d2f03898 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b44370217eae3f500c628e86891be7040b30182 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16cc69ef870b150f29e7baa1a19ee500e1f0abea Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..cc510fa32f7dada66305e59c9fbc67303262bec2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + + +__all__ = ["get_subprocess_handler"] + + +def get_subprocess_handler( + entrypoint: str, + args: tuple, + env: dict[str, str], + stdout: str, + stderr: str, + local_rank_id: int, +) -> SubprocessHandler: + return SubprocessHandler( + entrypoint=entrypoint, + args=args, + env=env, + stdout=stdout, + stderr=stderr, + local_rank_id=local_rank_id, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ff79ca383e3b7e313fcf002eb99ba2679369fc84 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import signal +import subprocess +import sys +from typing import Any, Optional + + +__all__ = ["SubprocessHandler"] + +IS_WINDOWS = sys.platform == "win32" + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +class SubprocessHandler: + """ + Convenience wrapper around python's ``subprocess.Popen``. Keeps track of + meta-objects associated to the process (e.g. stdout and stderr redirect fds). + """ + + def __init__( + self, + entrypoint: str, + args: tuple, + env: dict[str, str], + stdout: Optional[str], + stderr: Optional[str], + local_rank_id: int, + ): + self._stdout = open(stdout, "w") if stdout else None + self._stderr = open(stderr, "w") if stderr else None + # inherit parent environment vars + env_vars = os.environ.copy() + env_vars.update(env) + + args_str = (entrypoint, *[str(e) for e in args]) + self.local_rank_id = local_rank_id + self.proc: subprocess.Popen = self._popen(args_str, env_vars) + + def _popen(self, args: tuple, env: dict[str, str]) -> subprocess.Popen: + kwargs: dict[str, Any] = {} + if not IS_WINDOWS: + kwargs["start_new_session"] = True + return subprocess.Popen( + # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], + # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got + # `Tuple[str, *Tuple[Any, ...]]`. 
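            # Note: on POSIX, start_new_session=True (set in kwargs above) runs the
            # child in its own session/process group, which is what lets close()
            # below signal the whole group via os.killpg(self.proc.pid, death_sig).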
+ args=args, + env=env, + stdout=self._stdout, + stderr=self._stderr, + **kwargs, + ) + + def close(self, death_sig: Optional[signal.Signals] = None) -> None: + if not death_sig: + death_sig = _get_default_signal() + if IS_WINDOWS: + self.proc.send_signal(death_sig) + else: + os.killpg(self.proc.pid, death_sig) + if self._stdout: + self._stdout.close() + if self._stderr: + self._stderr.close() diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py new file mode 100644 index 0000000000000000000000000000000000000000..f72b8b68340139c728707a350b3790f40d2b09ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import time +from concurrent.futures.thread import ThreadPoolExecutor +from threading import Event +from typing import Optional, TextIO, TYPE_CHECKING + + +if TYPE_CHECKING: + from concurrent.futures._base import Future + +__all__ = ["tail_logfile", "TailLog"] + +logger = logging.getLogger(__name__) + + +def tail_logfile( + header: str, file: str, dst: TextIO, finished: Event, interval_sec: float +): + while not os.path.exists(file): + if finished.is_set(): + return + time.sleep(interval_sec) + + with open(file, errors="replace") as fp: + while True: + line = fp.readline() + + if line: + dst.write(f"{header}{line}") + else: # reached EOF + if finished.is_set(): + # log line producer is finished + break + else: + # log line producer is still going + # wait for a bit before looping again + time.sleep(interval_sec) + + +class TailLog: + """ + Tail the given log files. + + The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until + the log files are created by the producer and will tail the contents of the + log files until the ``stop()`` method is called. + + .. warning:: ``TailLog`` will wait indefinitely for the log file to be created! + + Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``, + where the ``name`` is user-provided and ``idx`` is the index of the log file + in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the + header for each log file. + + Usage: + + :: + + log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"} + tailer = TailLog("trainer", log_files, sys.stdout).start() + # actually run the trainers to produce 0_stdout.log and 1_stdout.log + run_trainers() + tailer.stop() + + # once run_trainers() start writing the ##_stdout.log files + # the tailer will print to sys.stdout: + # >>> [trainer0]:log_line1 + # >>> [trainer1]:log_line1 + # >>> [trainer0]:log_line2 + # >>> [trainer0]:log_line3 + # >>> [trainer1]:log_line2 + + .. note:: Due to buffering log lines between files may not necessarily + be printed out in order. You should configure your application's + logger to suffix each log line with a proper timestamp. 
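    For example (illustrative), passing ``log_line_prefixes={0: "[worker0]:"}``
    would replace the default ``[trainer0]:`` header for local rank 0, while the
    remaining ranks keep the ``[{name}{idx}]:`` form.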
+ + """ + + def __init__( + self, + name: str, + log_files: dict[int, str], + dst: TextIO, + log_line_prefixes: Optional[dict[int, str]] = None, + interval_sec: float = 0.1, + ): + n = len(log_files) + self._threadpool = None + if n > 0: + self._threadpool = ThreadPoolExecutor( + max_workers=n, + thread_name_prefix=f"{self.__class__.__qualname__}_{name}", + ) + + self._name = name + self._dst = dst + self._log_files = log_files + self._log_line_prefixes = log_line_prefixes + self._finished_events: dict[int, Event] = { + local_rank: Event() for local_rank in log_files.keys() + } + self._futs: list[Future] = [] + self._interval_sec = interval_sec + self._stopped = False + + def start(self) -> "TailLog": + if not self._threadpool: + return self + + for local_rank, file in self._log_files.items(): + header = f"[{self._name}{local_rank}]:" + if self._log_line_prefixes and local_rank in self._log_line_prefixes: + header = self._log_line_prefixes[local_rank] + self._futs.append( + self._threadpool.submit( + tail_logfile, + header=header, + file=file, + dst=self._dst, + finished=self._finished_events[local_rank], + interval_sec=self._interval_sec, + ) + ) + return self + + def stop(self) -> None: + for finished in self._finished_events.values(): + finished.set() + + for local_rank, f in enumerate(self._futs): + try: + f.result() + except Exception as e: + logger.error( + "error in log tailor for %s%s. %s: %s", + self._name, + local_rank, + e.__class__.__qualname__, + e, + ) + + if self._threadpool: + self._threadpool.shutdown(wait=True) + + self._stopped = True + + def stopped(self) -> bool: + return self._stopped diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py new file mode 100644 index 0000000000000000000000000000000000000000..ccdc7f1127a19e6473bd8e0ab4899a04f89a0668 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + + +""" +This file is not meant to be used directly. It serves as a stub to allow +other files to be safely imported without requiring the installation of +the 'etcd' library. The classes and methods here raise exceptions to +indicate that the real 'etcd' module is needed. 
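The real implementations come from an etcd client library (commonly the
``python-etcd`` package, which provides the ``etcd`` module); when that is
installed, the stubs here are not needed.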
+""" + + +class EtcdStubError(ImportError): + """Custom exception to indicate that the real etcd module is required.""" + + def __init__(self) -> None: + super().__init__("The 'etcd' module is required but not installed.") + + +class EtcdAlreadyExist(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdCompareFailed(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdKeyNotFound(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdWatchTimedOut(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdEventIndexCleared(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdException(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdResult: + def __init__(self) -> None: + raise EtcdStubError + + +class Client: + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + def read(self, key: str) -> None: + raise EtcdStubError + + def write( + self, key: str, value: Any, ttl: Optional[int] = None, **kwargs: Any + ) -> None: + raise EtcdStubError + + def test_and_set( + self, key: str, value: Any, prev_value: Any, ttl: Optional[int] = None + ) -> None: + raise EtcdStubError diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/api.py new file mode 100644 index 0000000000000000000000000000000000000000..f36264ea06be6090ac096fb56173bcd645e3cc06 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/api.py @@ -0,0 +1,390 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import socket +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, ClassVar, Optional + +from torch.distributed import Store +from torch.distributed.elastic.utils.distributed import get_free_port + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] + + +class RendezvousError(Exception): + """Represents the base type for rendezvous errors.""" + + +class RendezvousClosedError(RendezvousError): + """Raised when a rendezvous is closed.""" + + +class RendezvousTimeoutError(RendezvousError): + """Raised when a rendezvous did not complete on time.""" + + +class RendezvousConnectionError(RendezvousError): + """Raised when the connection to a rendezvous backend has failed.""" + + +class RendezvousStateError(RendezvousError): + """Raised when the state of a rendezvous is corrupt.""" + + +class RendezvousGracefulExitError(RendezvousError): + """Raised when node wasn't not included in rendezvous and gracefully exits. + + Exception is a mechanism to exit the stack, however does not mean a failure. 
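    In other words, raising it only unwinds the call stack of a node that was
    left out of the current rendezvous; it should not be reported as a failure.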
+ """ + + +@dataclass +class RendezvousStoreInfo: + """Store address and port that can be used to bootstrap trainer distributed comms""" + + MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" + MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" + master_addr: str + master_port: int + + @staticmethod + def build( + rank: int, + store: Store, + local_addr: Optional[str], + server_port: Optional[int] = None, + ) -> "RendezvousStoreInfo": + """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. + + If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. + + Args: + rank: rank of the current node + store: store to use for rendezvous + local_addr: address of the current node, if not provided will be resolved from hostname + server_port: port of the TCPStore server, when the TCPStore is shared. + """ + # TODO swap to collectives comms API + if rank == 0: + addr = local_addr or socket.getfqdn() + # When TCPStore is not shared, we fallback to get_free_port. + port = server_port or get_free_port() + store.set( + RendezvousStoreInfo.MASTER_ADDR_KEY, + addr.encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + store.set( + RendezvousStoreInfo.MASTER_PORT_KEY, + str(port).encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + + addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") + port = int( + store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") + ) + return RendezvousStoreInfo(master_addr=addr, master_port=port) + + +class RendezvousInfo: + """Holds the information about the rendezvous.""" + + def __init__( + self, + store: Store, + rank: int, + world_size: int, + bootstrap_store_info: RendezvousStoreInfo, + ): + self._store = store + self._rank = rank + self._world_size = world_size + self._bootstrap_store_info = bootstrap_store_info + + @property + def store(self) -> Store: + """Store used by torchelastic control plane""" + return self._store + + @property + def rank(self) -> int: + """Rank within a group""" + return self._rank + + @property + def world_size(self) -> int: + """Global group size""" + return self._world_size + + @property + def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]: + """Store information that can used by trainer code to bootstrap distributed comms.""" + return self._bootstrap_store_info + + +class RendezvousHandler(ABC): + """Main rendezvous interface. + + Note: + Distributed Torch users normally **do not** need to implement their own + ``RendezvousHandler``. An implementation based on C10d Store is already + provided, and is recommended for most users. + """ + + @abstractmethod + def get_backend(self) -> str: + """Return the name of the rendezvous backend.""" + + @property + def use_agent_store(self) -> bool: + """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user + applications and will be available during application lifecycle. + + Rendezvous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. + Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store. + """ + return False + + @abstractmethod + def next_rendezvous(self) -> RendezvousInfo: + """Main entry-point into the rendezvous barrier. + + Blocks until the rendezvous is complete and the current process is + included in the formed worker group, or a timeout occurs, or the + rendezvous was marked closed. + + Returns: + Instance of :py:class:`RendezvousInfo`. 
+ + Raises: + RendezvousClosedError: + The rendezvous is closed. + RendezvousConnectionError: + The connection to the rendezvous backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + RendezvousTimeoutError: + The rendezvous did not complete on time. + """ + + @abstractmethod + def is_closed(self) -> bool: + """Check whether the rendezvous has been closed. + + A closed rendezvous means all future attempts to re-rendezvous within + same job will fail. + + ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual + propagation and should not be used for synchronization. The intention is + that if at least one node decides the job is finished, it will close the + rendezvous, and other nodes will soon observe this and stop running as + well. + """ + + @abstractmethod + def set_closed(self): + """Mark the rendezvous as closed.""" + + @abstractmethod + def num_nodes_waiting(self) -> int: + """Return the number of nodes who arrived late at the rendezvous + barrier, hence were not included in the current worker group. + + Callers should periodically call this method to check whether new + nodes are waiting to join the job and if so admit them by calling + :py:meth:`next_rendezvous()` (re-rendezvous). + """ + + @abstractmethod + def get_run_id(self) -> str: + """Return the run id of the rendezvous. + + The run id is a user-defined id that uniquely identifies an instance of + a distributed application. It typically maps to a job id and is used to + allow nodes to join the correct distributed application. + """ + + @abstractmethod + def shutdown(self) -> bool: + """Close all resources that were open for the rendezvous. + + Example:: + + rdzv_handler = ... + try: + store, rank, world_size = rdzv_handler.next_rendezvous() + finally: + rdzv_handler.shutdown() + """ + + +class RendezvousParameters: + """Hold the parameters to construct a :py:class:`RendezvousHandler`. + + Args: + backend: + The name of the backend to use to handle the rendezvous. + endpoint: + The endpoint of the rendezvous, usually in form [:]. + run_id: + The id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The address of the local node. + **kwargs: + Additional parameters for the specified backend. + """ + + def __init__( + self, + backend: str, + endpoint: str, + run_id: str, + min_nodes: int, + max_nodes: int, + local_addr: Optional[str] = None, + **kwargs, + ): + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + if min_nodes < 1: + raise ValueError( + f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero." + ) + if max_nodes < min_nodes: + raise ValueError( + f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or " + f"equal to the minimum number of rendezvous nodes ({min_nodes})." 
+ ) + + self.backend = backend + self.endpoint = endpoint + self.run_id = run_id + self.min_nodes = min_nodes + self.max_nodes = max_nodes + self.config = kwargs + self.local_addr = local_addr + + def get(self, key: str, default: Any = None) -> Any: + """Return the value for ``key`` if ``key`` exists, else ``default``.""" + return self.config.get(key, default) + + def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]: + """Return the value for ``key`` as a ``bool``.""" + value = self.get(key, default) + if value is None or isinstance(value, bool): + return value + if isinstance(value, int): + if value == 1: + return True + if value == 0: + return False + elif isinstance(value, str): + if value.lower() in ["1", "true", "t", "yes", "y"]: + return True + if value.lower() in ["0", "false", "f", "no", "n"]: + return False + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid boolean value." + ) + + def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]: + """Return the value for ``key`` as an ``int``.""" + value = self.get(key, default) + if value is None: + return value + try: + return int(value) + except ValueError as e: + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid integer " + "value." + ) from e + + +RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler] + + +class RendezvousHandlerRegistry: + """Represent a registry of :py:class:`RendezvousHandler` backends.""" + + _registry: dict[str, RendezvousHandlerCreator] + + def __init__(self) -> None: + self._registry = {} + + def register(self, backend: str, creator: RendezvousHandlerCreator) -> None: + """Register a new rendezvous backend. + + Args: + backend: + The name of the backend. + creator: + The callback to invoke to construct the + :py:class:`RendezvousHandler`. + """ + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + current_creator: Optional[RendezvousHandlerCreator] + try: + current_creator = self._registry[backend] + except KeyError: + current_creator = None + + if current_creator is not None and current_creator != creator: + raise ValueError( + f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it " + f"is already registered with '{current_creator}'." + ) + + self._registry[backend] = creator + + def create_handler(self, params: RendezvousParameters) -> RendezvousHandler: + """Create a new :py:class:`RendezvousHandler`.""" + try: + creator = self._registry[params.backend] + except KeyError as e: + raise ValueError( + f"The rendezvous backend '{params.backend}' is not registered. Did you forget " + f"to call `{self.register.__name__}`?" + ) from e + + handler = creator(params) + + # Do some sanity check. + if handler.get_backend() != params.backend: + raise RuntimeError( + f"The rendezvous backend '{handler.get_backend()}' does not match the requested " + f"backend '{params.backend}'." + ) + + return handler + + +# The default global registry instance used by launcher scripts to instantiate +# rendezvous handlers. 
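# Illustrative usage of the registry with a hypothetical "my_backend" creator
# (a sketch only; a creator must return a handler whose get_backend() matches
# the name it was registered under):
#
#   def _create_my_handler(params: RendezvousParameters) -> RendezvousHandler:
#       ...  # construct and return the concrete handler
#
#   rendezvous_handler_registry.register("my_backend", _create_my_handler)
#   params = RendezvousParameters(
#       backend="my_backend",
#       endpoint="localhost:29400",
#       run_id="job-42",
#       min_nodes=1,
#       max_nodes=2,
#   )
#   handler = rendezvous_handler_registry.create_handler(params)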
+rendezvous_handler_registry = RendezvousHandlerRegistry() diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3562444c23e265a814122bc20b5878d5cce3ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -0,0 +1,273 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +import logging +import os +import tempfile +from base64 import b64decode, b64encode +from datetime import timedelta +from typing import Any, cast, Optional + +from torch.distributed import FileStore, Store, TCPStore +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousConnectionError, + RendezvousError, + RendezvousParameters, + RendezvousStateError, +) +from .dynamic_rendezvous import RendezvousBackend, Token +from .utils import _matches_machine_hostname, parse_rendezvous_endpoint + + +logger = logging.getLogger(__name__) + +# default port for the TCP store +DEFAULT_PORT = 29400 + + +class C10dRendezvousBackend(RendezvousBackend): + """Represents a C10d-backed rendezvous backend. + + Args: + store: + The :py:class:`torch.distributed.Store` instance to use to + communicate with the C10d store. + run_id: + The run id of the rendezvous. + """ + + # See the explanation in the __init__ method. + _NULL_SENTINEL = "Y2FuaW1hZGFt" + + _store: Store + _key: str + + def __init__(self, store: Store, run_id: str) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._store = store + + self._key = "torch.rendezvous." + run_id + + # The read operation of a store blocks the caller until the specified + # key becomes available. This behavior makes it tricky to use a store + # as a regular key-value dictionary. + # + # As a workaround we initially set a sentinel value as the rendezvous + # state. Whenever this value gets returned we treat it as a None. + self._call_store("compare_set", self._key, "", self._NULL_SENTINEL) + + @property + def name(self) -> str: + """See base class.""" + return "c10d" + + def get_state(self) -> Optional[tuple[bytes, Token]]: + """See base class.""" + base64_state: bytes = self._call_store("get", self._key) + + return self._decode_state(base64_state) + + def set_state( + self, state: bytes, token: Optional[Token] = None + ) -> Optional[tuple[bytes, Token, bool]]: + """See base class.""" + base64_state_str: str = b64encode(state).decode() + + if token: + # Shortcut if we know for sure that the token is not valid. + if not isinstance(token, bytes): + result = self.get_state() + if result is not None: + tmp = *result, False + # Python 3.6 does not support tuple unpacking in return + # statements. 
+ return tmp + return None + + token = token.decode() + else: + token = self._NULL_SENTINEL + + base64_state: bytes = self._call_store( + "compare_set", self._key, token, base64_state_str + ) + + state_token_pair = self._decode_state(base64_state) + if state_token_pair is None: + return None + + new_state, new_token = state_token_pair + + # C10d Store's compare_set method does not offer an easy way to find out + # whether our write attempt was successful. As a brute-force solution we + # perform a bitwise comparison of our local state and the remote state. + return new_state, new_token, new_state == state + + def _call_store(self, store_op: str, *args, **kwargs) -> Any: + try: + return getattr(self._store, store_op)(*args, **kwargs) + except (ValueError, RuntimeError, TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + def _decode_state(self, base64_state: bytes) -> Optional[tuple[bytes, Token]]: + if base64_state == self._NULL_SENTINEL.encode(): + return None + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + return state, base64_state + + +def _create_tcp_store(params: RendezvousParameters) -> TCPStore: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) + + cfg_is_host = params.get_as_bool("is_host") + # If the user has explicitly specified whether our process should host the + # the store, respect it. + if cfg_is_host is not None: + is_host = cfg_is_host + # Otherwise try to determine whether we are the host based on our hostname + # and IP address. + else: + is_host = _matches_machine_hostname(host) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # In specific cases we attempt to instantiate the store twice. For details + # see the explanation in the except clause below. + for is_server in [is_host, False]: + try: + store = TCPStore( + host, + port, + is_master=is_server, + multi_tenant=True, + timeout=timedelta(seconds=read_timeout), + ) + + if is_server: + msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend." + construct_and_record_rdzv_event( + run_id=params.run_id, message=msg, node_state=NodeState.INIT + ) + logger.info(msg) + + break + except (ValueError, RuntimeError, TimeoutError) as exc: + # If we heuristically inferred the value of is_host as True and our + # first attempt to instantiate the TCP store has failed, try it one + # more time with is_host set to False. As an edge case there can be + # more than one process that is part of the same rendezvous on this + # machine and only one of them will eventually host the store. + + if not is_server or cfg_is_host is not None: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store # type: ignore[possibly-undefined] + + +def _create_file_store(params: RendezvousParameters) -> FileStore: + # If a user specifies an endpoint, we treat it as a path to a file. + if params.endpoint: + path = params.endpoint + else: + try: + # The temporary file is readable and writable only by the user of + # this process. 
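            # tempfile.mkstemp() returns an (fd, path) pair; only the path is
            # kept here and handed to FileStore below.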
+ _, path = tempfile.mkstemp() + except OSError as exc: + raise RendezvousError( + "The file creation for C10d store has failed. See inner exception for details." + ) from exc + + try: + store = FileStore(path) + except (ValueError, RuntimeError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store + + +def create_backend(params: RendezvousParameters) -> tuple[C10dRendezvousBackend, Store]: + """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | store_type | The type of the C10d store. The currently supported types | + | | are "tcp" and "file" which correspond to | + | | :py:class:`torch.distributed.TCPStore` and | + | | :py:class:`torch.distributed.FileStore`, respectively. | + | | Defaults to "tcp". | + +--------------+-----------------------------------------------------------+ + | read_timeout | The read timeout, in seconds, for store operations. | + | | Defaults to 60 seconds. | + | | | + | | Note this only applies to | + | | :py:class:`torch.distributed.TCPStore`. It is not relevant| + | | to :py:class:`torch.distributed.FileStore` which does not | + | | take in timeout as a parameter. | + +--------------+-----------------------------------------------------------+ + | is_host | A boolean value indicating whether this backend instance | + | | will host the C10d store. If not specified it will be | + | | inferred heuristically by matching the hostname or the IP | + | | address of this machine against the specified rendezvous | + | | endpoint. Defaults to ``None``. | + | | | + | | Note that this configuration option only applies to | + | | :py:class:`torch.distributed.TCPStore`. In normal | + | | circumstances you can safely skip it; the only time when | + | | it is needed is if its value cannot be correctly | + | | determined (e.g. the rendezvous endpoint has a CNAME as | + | | the hostname or does not match the FQDN of the machine). | + +--------------+-----------------------------------------------------------+ + """ + # As of today we only support TCPStore and FileStore. Other store types do + # not have the required functionality (e.g. compare_set) yet. + store_type = params.get("store_type", "tcp").strip().lower() + store: Store + + try: + if store_type == "file": + store = _create_file_store(params) + elif store_type == "tcp": + store = _create_tcp_store(params) + else: + raise ValueError( + "Invalid store type given. Currently only supports file and tcp." + ) + + backend = C10dRendezvousBackend(store, params.run_id) + + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise + + return backend, store diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..a38af7be2e5cf9ae8fbaf48cedeccaebf17f1e2f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -0,0 +1,1455 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import pickle +import socket +import threading +import time +import weakref +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any, Callable, Optional + +import torch.distributed as dist +from torch.distributed import Store +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousClosedError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .utils import _delay, _PeriodicTimer + + +__all__ = [ + "RendezvousBackend", + "RendezvousTimeout", + "RendezvousSettings", + "DynamicRendezvousHandler", + "create_handler", +] + +logger = logging.getLogger(__name__) + + +def get_method_name(depth=2): + if len(inspect.stack()) > depth: + return inspect.stack()[depth].function + return "no_method_name" + + +Token = Any +"""Represent an opaque fencing token used by the rendezvous backend.""" + + +class RendezvousBackend(ABC): + """Represent a backend that holds the rendezvous state.""" + + @property + @abstractmethod + def name(self) -> str: + """Get the name of the backend.""" + + @abstractmethod + def get_state(self) -> Optional[tuple[bytes, Token]]: + """Get the rendezvous state. + + Returns: + A tuple of the encoded rendezvous state and its fencing token or + ``None`` if no state is found in the backend. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + @abstractmethod + def set_state( + self, state: bytes, token: Optional[Token] = None + ) -> Optional[tuple[bytes, Token, bool]]: + """Set the rendezvous state. + + The new rendezvous state is set conditionally: + + - If the specified ``token`` matches the fencing token stored in the + backend, the state will be updated. The new state will be returned + to the caller along with its fencing token. + - If the specified ``token`` does not match the fencing token stored + in the backend, the state won't be updated; instead the existing + state along with its fencing token will be returned to the caller. + - If the specified ``token`` is ``None``, the new state will be set + only if there is no existing state in the backend. Either the new + state or the existing state along with its fencing token will be + returned to the caller. + + Args: + state: + The encoded rendezvous state. + token: + An optional fencing token that was retrieved by a previous call + to :py:meth:`get_state` or ``set_state()``. + + Returns: + A tuple of the serialized rendezvous state, its fencing token, and + a boolean value indicating whether our set attempt succeeded. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + +class RendezvousTimeout: + """Hold the timeout configuration of a rendezvous. + + Args: + join: + The time within which the rendezvous is expected to complete. + last_call: + An additional wait amount before completing the rendezvous once the + rendezvous has the minimum number of required participants. 
+ close: + The time within which the rendezvous is expected to close after a + call to :py:meth:`RendezvousHandler.set_closed` or + :py:meth:`RendezvousHandler.shutdown`. + heartbeat: + The time within which a keep-alive heartbeat is expected to + complete. + """ + + _ZERO = timedelta(0) + + _DEFAULT_TIMEOUTS = { + "join": timedelta(seconds=600), + "last_call": timedelta(seconds=30), + "close": timedelta(seconds=30), + "heartbeat": timedelta(seconds=5), + } + + _join: timedelta + _last_call: timedelta + _close: timedelta + _heartbeat: timedelta + + def __init__( + self, + join: Optional[timedelta] = None, + last_call: Optional[timedelta] = None, + close: Optional[timedelta] = None, + heartbeat: Optional[timedelta] = None, + ) -> None: + self._set_timeouts( + join=join, last_call=last_call, close=close, heartbeat=heartbeat + ) + + @property + def join(self) -> timedelta: + """Get the join timeout.""" + return self._join + + @property + def last_call(self) -> timedelta: + """Get the last call timeout.""" + return self._last_call + + @property + def close(self) -> timedelta: + """Get the close timeout.""" + return self._close + + @property + def heartbeat(self) -> timedelta: + """Get the keep-alive heartbeat timeout.""" + return self._heartbeat + + def _set_timeouts(self, **timeouts: Optional[timedelta]): + for name, timeout in timeouts.items(): + if timeout is None: + timeout = self._DEFAULT_TIMEOUTS[name] + if timeout <= self._ZERO: + raise ValueError(f"The {name} timeout ({timeout}) must be positive.") + setattr(self, "_" + name, timeout) + + +@dataclass(repr=False, eq=False, frozen=True) +class RendezvousSettings: + """Hold the settings of the rendezvous. + + Attributes: + run_id: + The run id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + + run_id: str + min_nodes: int + max_nodes: int + timeout: RendezvousTimeout + keep_alive_interval: timedelta + keep_alive_max_attempt: int + + +@dataclass(eq=True, order=True, frozen=True) +class _NodeDesc: + """Describe a node in the rendezvous. + + Attributes: + addr: + The FQDN of the node or user specified local node address. + pid: + The id of the process in which the rendezvous handler runs. + local_id: + A process-wide unique id. + """ + + addr: str + pid: int + local_id: int + + def __repr__(self) -> str: + return f"{self.addr}_{self.pid}_{self.local_id}" + + +class _NodeDescGenerator: + """Generate node descriptors. + + A node descriptor is a combination of an FQDN, a process id, and an auto- + incremented integer that uniquely identifies a node in the rendezvous. + """ + + _lock: threading.Lock + _local_id: int + + def __init__(self) -> None: + self._lock = threading.Lock() + + # An integer that is incremented with each call to generate(). + self._local_id = 0 + + def generate(self, local_addr: Optional[str] = None) -> _NodeDesc: + # This method can be called by multiple threads concurrently; therefore, + # we must increment the integer atomically. 
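        # The resulting descriptor renders as "<addr>_<pid>_<local_id>"
        # (see _NodeDesc.__repr__ above), e.g. "node1.example.com_4321_0".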
+ with self._lock: + local_id = self._local_id + + self._local_id += 1 + + return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id) + + +class _RendezvousState: + """Hold the state of a rendezvous. + + Attributes: + round: + The current round of the rendezvous. + complete: + A boolean value indicating whether the current round of the + rendezvous is complete. + deadline: + The time at which the current round of the rendezvous will be + considered complete if it is still waiting for nodes to join. + closed: + A boolean value indicating whether the rendezvous is closed. + participants: + A dictionary of the participants and their corresponding ranks. + wait_list: + A set of nodes that are waiting to participate in the next round of + the rendezvous. + redundancy_list: + A set of nodes that are redundant in the current round and can join + the next rendezvous without triggering re-rendezvous. + last_heartbeats: + A dictionary containing each node's last heartbeat time. + """ + + round: int + complete: bool + deadline: Optional[datetime] + closed: bool + participants: dict[_NodeDesc, int] + wait_list: set[_NodeDesc] + redundancy_list: set[_NodeDesc] + last_heartbeats: dict[_NodeDesc, datetime] + + def __init__(self) -> None: + self.round = 0 + self.complete = False + self.deadline = None + self.closed = False + self.participants = {} + self.wait_list = set() + self.redundancy_list = set() + self.last_heartbeats = {} + + +def _remove_participant_epilogue( + state: _RendezvousState, settings: RendezvousSettings +) -> None: + if state.complete: + # If we do not have any participants left, move to the next round. + if not state.participants: + msg = "No participants left in the rendezvous, marking rendezvous as incomplete" + logger.debug(msg) + state.complete = False + + state.round += 1 + else: + if len(state.participants) < settings.min_nodes: + msg = ( + f"Number of participants {len(state.participants)}) less than" + f"min_nodes {settings.min_nodes}, clearning deadline in state" + ) + logger.debug(msg) + state.deadline = None + + +class _RendezvousStateHolder(ABC): + """Hold the shared rendezvous state synced with other nodes.""" + + @property + @abstractmethod + def state(self) -> _RendezvousState: + """Get the local state.""" + + @abstractmethod + def sync(self) -> Optional[bool]: + """Read or writes the latest state. + + Returns: + A boolean value indicating whether the local state, in case marked + as dirty, was successfully synced with other nodes. + """ + + @abstractmethod + def mark_dirty(self) -> None: + """Mark the local state as dirty.""" + + +class _BackendRendezvousStateHolder(_RendezvousStateHolder): + """Hold the rendezvous state synced with other nodes via a backend. + + Args: + backend: + The rendezvous backend to use. + settings: + The rendezvous settings. + cache_duration: + The amount of time, in seconds, to cache the last rendezvous state + before requesting it from the backend again. 
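    With the default ``cache_duration`` of 1 second, read-only ``sync()`` calls
    issued within that window are served from the locally cached state instead
    of querying the backend again.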
+ """ + + _backend: RendezvousBackend + _state: _RendezvousState + _settings: RendezvousSettings + _cache_duration: int + _token: Token + _dirty: bool + _last_sync_time: float + _dead_nodes: list[_NodeDesc] + + def __init__( + self, + backend: RendezvousBackend, + settings: RendezvousSettings, + cache_duration: int = 1, + ) -> None: + self._backend = backend + self._state = _RendezvousState() + self._settings = settings + self._cache_duration = cache_duration + self._token = None + self._dirty = False + self._last_sync_time = -1 + self._dead_nodes = [] + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING): + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + ) + + @property + def state(self) -> _RendezvousState: + """See base class.""" + return self._state + + def sync(self) -> Optional[bool]: + """See base class.""" + state_bits: Optional[bytes] = None + + token = None + + has_set: Optional[bool] + + if self._dirty: + has_set = False + + state_bits = pickle.dumps(self._state) + + set_response = self._backend.set_state(state_bits, self._token) + if set_response is not None: + state_bits, token, has_set = set_response + else: + has_set = None + + if self._cache_duration > 0: + # Avoid overloading the backend if we are asked to retrieve the + # state repeatedly. Try to serve the cached state. + if self._last_sync_time >= max( + time.monotonic() - self._cache_duration, 0 + ): + return None + + get_response = self._backend.get_state() + if get_response is not None: + state_bits, token = get_response + + if state_bits is not None: + try: + self._state = pickle.loads(state_bits) + except pickle.PickleError as exc: + raise RendezvousStateError( + "The rendezvous state is corrupt. See inner exception for details." + ) from exc + else: + self._state = _RendezvousState() + + if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG): + node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes) + + msg = ( + f"As part of the sync operation the node(s) {node_list} have been removed from the " + f"rendezvous '{self._settings.run_id}' since they had no heartbeat." + ) + self._record(message=msg) + logger.debug(msg) + + self._token = token + + self._dirty = False + + self._last_sync_time = time.monotonic() + + self._sanitize() + + return has_set + + def _sanitize(self) -> None: + state = self._state + + expire_time = datetime.now(timezone.utc) - ( + self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt + ) + + # Filter out the dead nodes. + self._dead_nodes = [ + node + for node, last_heartbeat in state.last_heartbeats.items() + if last_heartbeat < expire_time + ] + + participant_removed = False + + for dead_node in self._dead_nodes: + msg = f"Detected dead node '{dead_node}', removing it from the rendezvous" + logger.debug(msg) + del state.last_heartbeats[dead_node] + + try: + del state.participants[dead_node] + + participant_removed = True + except KeyError: + pass + + try: + state.wait_list.remove(dead_node) + except KeyError: + pass + + try: + state.redundancy_list.remove(dead_node) + except KeyError: + pass + + if participant_removed: + # Common epilogue shared with the _remove_from_participants() + # function of _DistributedRendezvousOpExecutor. + _remove_participant_epilogue(state, self._settings) + + def mark_dirty(self) -> None: + """See base class. 
+ + If the local rendezvous state is dirty, the next sync call will try to + write the changes back to the backend. However this attempt might fail + if another node, which had the same state, also made changes and wrote + them before us. + """ + self._dirty = True + + +class _Action(Enum): + """Specifies the possible actions based on the state of the rendezvous.""" + + KEEP_ALIVE = 1 + ADD_TO_PARTICIPANTS = 2 + ADD_TO_WAIT_LIST = 3 + ADD_TO_REDUNDANCY_LIST = 4 + REMOVE_FROM_PARTICIPANTS = 5 + REMOVE_FROM_WAIT_LIST = 6 + REMOVE_FROM_REDUNDANCY_LIST = 7 + MARK_RENDEZVOUS_COMPLETE = 8 + MARK_RENDEZVOUS_CLOSED = 9 + SYNC = 10 + ERROR_CLOSED = 11 + ERROR_TIMEOUT = 12 + FINISH = 13 + + +class _RendezvousContext: + """Holds the context of the rendezvous. + + Attributes: + node: + The node descriptor associated with the current rendezvous handler + instance. + state: + The current state of the rendezvous. + settings: + The rendezvous settings. + """ + + node: _NodeDesc + state: _RendezvousState + settings: RendezvousSettings + + def __init__( + self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings + ) -> None: + self.node = node + self.state = state + self.settings = settings + + +class _RendezvousOpExecutor(ABC): + """Execute rendezvous operations.""" + + @abstractmethod + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Optional[Callable[[timedelta], float]] = None, + ) -> None: + """Execute a rendezvous operation. + + An operation is run inside a state machine and is expected to transition + the rendezvous from one state to another. + + Args: + state_handler: + A callable that is expected to return the next state transition + action based on the current state of the rendezvous. + deadline: + The time, in seconds, at which the operation will be considered + timed-out. + update_deadline: + Function to generate a new operation deadline if the current + node may participate in the next rendezvous. + """ + + +class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor): + """Execute rendezvous operations using a shared state. + + Args: + node: + The node descriptor associated with the current rendezvous handler + instance. + state_holder: + The ``RendezvousStateHolder`` to use to sync the rendezvous state + with other nodes. + settings: + The rendezvous settings. + """ + + _node: _NodeDesc + _state: _RendezvousState + _state_holder: _RendezvousStateHolder + _settings: RendezvousSettings + + def __init__( + self, + node: _NodeDesc, + state_holder: _RendezvousStateHolder, + settings: RendezvousSettings, + ) -> None: + self._node = node + self._state_holder = state_holder + self._settings = settings + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._node.addr, + pid=self._node.pid, + local_id=self._node.local_id, + ) + + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Optional[Callable[[timedelta], float]] = None, + ) -> None: + """See base class.""" + action = None + while action != _Action.FINISH: + # Reads or writes the latest rendezvous state shared by all nodes in + # the rendezvous. Note that our local changes might get overridden + # by another node if that node synced its changes before us. 
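+            # Loop outline (summary comment, not part of the upstream source):
+            # 1) sync the shared state, 2) ask state_handler for the next
+            # _Action, 3) apply it locally, 4) mark the state dirty so the
+            # next iteration writes it back. The loop exits on _Action.FINISH
+            # or by raising a closed/timeout error.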
+ has_set = self._state_holder.sync() + if has_set is not None: + if has_set: + msg = ( + f"The node '{self._node}' has successfully synced its local changes with " + f"other nodes in the rendezvous '{self._settings.run_id}'." + ) + else: + msg = ( + f"The node '{self._node}' has a stale state and failed to sync its local " + f"changes with other nodes in the rendezvous '{self._settings.run_id}'." + ) + + self._record(message=msg) + logger.debug(msg) + + self._state = self._state_holder.state + + ctx = _RendezvousContext(self._node, self._state, self._settings) + + # Determine the next action to take based on the current state of + # the rendezvous. + action = state_handler(ctx, deadline) + + if action == _Action.FINISH: + continue + + if action == _Action.ERROR_CLOSED: + raise RendezvousClosedError + + if action == _Action.ERROR_TIMEOUT: + raise RendezvousTimeoutError + + if action == _Action.SYNC: + # Delay the execution by one second to avoid overloading the + # backend if we are asked to poll for state changes. + _delay(seconds=1) + else: + if action == _Action.KEEP_ALIVE: + self._keep_alive() + elif action == _Action.ADD_TO_PARTICIPANTS: + self._add_to_participants() + elif action == _Action.ADD_TO_WAIT_LIST: + self._add_to_wait_list() + elif action == _Action.ADD_TO_REDUNDANCY_LIST: + self._add_to_redundancy_list() + elif action == _Action.REMOVE_FROM_PARTICIPANTS: + self._remove_from_participants() + elif action == _Action.REMOVE_FROM_WAIT_LIST: + self._remove_from_wait_list() + elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST: + self._remove_from_redundancy_list() + # update deadline since the node may participate in rendezvous process + if update_deadline: + deadline = update_deadline(self._settings.timeout.join) + elif action == _Action.MARK_RENDEZVOUS_COMPLETE: + self._mark_rendezvous_complete() + elif action == _Action.MARK_RENDEZVOUS_CLOSED: + self._mark_rendezvous_closed() + + # Attempt to sync our changes back to other nodes. + self._state_holder.mark_dirty() + + def _keep_alive(self) -> None: + msg = ( + f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous " + f"'{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.last_heartbeats[self._node] = datetime.now(timezone.utc) + + def _add_to_participants(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + try: + state.wait_list.remove(self._node) + except KeyError: + pass + + # The ranks of the participants will be set once the rendezvous is + # complete. + state.participants[self._node] = 0 + + self._keep_alive() + + if len(state.participants) == self._settings.min_nodes: + state.deadline = ( + datetime.now(timezone.utc) + self._settings.timeout.last_call + ) + + if len(state.participants) == self._settings.max_nodes: + self._mark_rendezvous_complete() + + def _add_to_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." 
+ ) + self._record(message=msg) + logger.debug(msg) + + if self._node in self._state.redundancy_list: + self._state.redundancy_list.remove(self._node) + self._state.wait_list.add(self._node) + + self._keep_alive() + + def _add_to_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the redundancy list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.add(self._node) + + self._keep_alive() + + def _remove_from_participants(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + del state.participants[self._node] + + del state.last_heartbeats[self._node] + + # Common epilogue shared with the sanitizer() function of + # _BackendRendezvousStateHolder. + _remove_participant_epilogue(state, self._settings) + + def _remove_from_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.wait_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _remove_from_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the redundant list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _mark_rendezvous_complete(self) -> None: + msg = ( + f"The node '{self._node}' marked round {self._state.round} of the rendezvous " + f"'{self._settings.run_id}' as complete. Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + state = self._state + + state.complete = True + state.deadline = None + + # Assign the ranks. + for rank, node in enumerate(sorted(state.participants)): + state.participants[node] = rank + + def _mark_rendezvous_closed(self) -> None: + msg = ( + f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. " + "Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + self._state.closed = True + + +def _should_keep_alive(ctx: _RendezvousContext) -> bool: + """Determine whether a keep-alive heartbeat should be sent.""" + try: + last_heartbeat = ctx.state.last_heartbeats[ctx.node] + except KeyError: + return False + + return ( + last_heartbeat <= datetime.now(timezone.utc) - ctx.settings.keep_alive_interval + ) + + +class _RendezvousExitOp: + """Represent a rendezvous exit operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.node in ctx.state.participants: + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.REMOVE_FROM_PARTICIPANTS + return _Action.FINISH + + +class _RendezvousJoinOp: + """Represent a rendezvous join operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + state = ctx.state + + # A closed rendezvous means that it no longer accepts new nodes. 
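+        # Decision order below (summary comment, not part of the upstream
+        # source): closed -> error or graceful exit; redundant node ->
+        # keep-alive or re-join; already a participant of a complete round ->
+        # finish; past the deadline -> rollback or timeout; otherwise join,
+        # wait, or mark the round complete.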
+ if state.closed: + if ctx.node in state.redundancy_list: + msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous." + raise RendezvousGracefulExitError(msg) + return _Action.ERROR_CLOSED + + if ctx.node in state.redundancy_list: + msg = f"The node {ctx.node} is in redundancy list" + logger.debug(msg) + # don't apply the timeout logic here, since we want to allow the node to rejoin + if len(state.participants) == ctx.settings.max_nodes: + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + else: + return _Action.SYNC + else: + # transition to waiting state that will respect timeouts. + msg = f"The node {ctx.node} is removed from redundancy list" + logger.debug(msg) + return _Action.REMOVE_FROM_REDUNDANCY_LIST + + is_participant = ctx.node in state.participants + + # If we are part of the rendezvous and it is already complete there is + # no further action to take. + if state.complete and is_participant: + return _Action.FINISH + + now = time.monotonic() + if now > deadline: + rollback_period = 5 # 5 seconds + + # If we still have time to rollback (a short period on top of the + # operation deadline), try to remove ourself from the rendezvous. + # It is okay if we can't though as our keep-alive will eventually + # expire. + if now <= deadline + rollback_period: + # If we are part of the rendezvous, it means we couldn't find + # enough participants to complete it on time. + if is_participant: + return _Action.REMOVE_FROM_PARTICIPANTS + # If we are in the wait list, it means we couldn't wait till the + # next round of the rendezvous. + if ctx.node in state.wait_list: + return _Action.REMOVE_FROM_WAIT_LIST + return _Action.ERROR_TIMEOUT + + if state.complete: + # If we are here, it means we are not part of the rendezvous. In + # case the rendezvous has capacity for additional participants add + # ourself to the wait list for the next round. + if len(state.participants) < ctx.settings.max_nodes: + if ctx.node not in state.wait_list: + return _Action.ADD_TO_WAIT_LIST + elif len(state.participants) >= ctx.settings.max_nodes: + if ( + ctx.node not in state.redundancy_list + and ctx.node not in state.wait_list + ): + return _Action.ADD_TO_REDUNDANCY_LIST + elif is_participant: + # If the rendezvous has enough number of participants including us, + # check whether we have passed the rendezvous deadline. If yes, + # complete it. + if ( + len(state.participants) >= ctx.settings.min_nodes + and len(state.participants) <= ctx.settings.max_nodes + and state.deadline is not None + ): + if state.deadline < datetime.now(timezone.utc): + msg = ( + f"The node '{ctx.node}' marking the rendezvous complete, " + f"quorum established within deadline" + ) + logger.debug(msg) + return _Action.MARK_RENDEZVOUS_COMPLETE + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached" + logger.debug(msg) + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants" + logger.debug(msg) + else: + # The rendezvous is not complete yet and we are not part of it. Try + # to join. + return _Action.ADD_TO_PARTICIPANTS + + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + + # At this point either the rendezvous is not complete, but we are part + # of it, which means we have to wait for other participants to join; or + # the rendezvous is complete, but we are not part of it, which means we + # have to wait for the next round. 
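+        # (The op executor handles _Action.SYNC by sleeping about one second
+        # before the next sync, so returning SYNC here effectively polls the
+        # backend for state changes.)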
+ return _Action.SYNC + + +class _RendezvousCloseOp: + """Represent a rendezvous close operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.state.closed: + return _Action.FINISH + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.MARK_RENDEZVOUS_CLOSED + + +class _RendezvousKeepAliveOp: + """Represent a rendezvous keep-alive update operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if _should_keep_alive(ctx): + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.KEEP_ALIVE + return _Action.FINISH + + +class DynamicRendezvousHandler(RendezvousHandler): + """Represent a handler that sets up a rendezvous among a set of nodes.""" + + # Static + _node_desc_generator = _NodeDescGenerator() + + _this_node: _NodeDesc + _settings: RendezvousSettings + _backend_name: str + _store: Store + _state_holder: _RendezvousStateHolder + _op_executor: _RendezvousOpExecutor + _heartbeat_lock: threading.Lock + _keep_alive_timer: Optional[_PeriodicTimer] + + @classmethod + def from_backend( + cls, + run_id: str, + store: Store, + backend: RendezvousBackend, + min_nodes: int, + max_nodes: int, + local_addr: Optional[str] = None, + timeout: Optional[RendezvousTimeout] = None, + keep_alive_interval: int = 5, + keep_alive_max_attempt: int = 3, + ): + """Create a new :py:class:`DynamicRendezvousHandler`. + + Args: + run_id: + The run id of the rendezvous. + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The local node address. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + # We associate each handler instance with a unique node descriptor. + node = cls._node_desc_generator.generate(local_addr) + + settings = RendezvousSettings( + run_id, + min_nodes, + max_nodes, + timeout or RendezvousTimeout(), + keep_alive_interval=timedelta(seconds=keep_alive_interval), + keep_alive_max_attempt=keep_alive_max_attempt, + ) + + state_holder = _BackendRendezvousStateHolder(backend, settings) + + return cls(node, settings, backend.name, store, state_holder) + + def __init__( + self, + node: _NodeDesc, + settings: RendezvousSettings, + backend_name: str, + store: Store, + state_holder: _RendezvousStateHolder, + ) -> None: + if not settings.run_id: + raise ValueError("The run id must be a non-empty string.") + + if settings.min_nodes < 1: + raise ValueError( + f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero." + ) + + if settings.max_nodes < settings.min_nodes: + raise ValueError( + f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal " + f"to the minimum number of nodes ({settings.min_nodes})." 
+ ) + + self._this_node = node + + self._settings = settings + + self._backend_name = backend_name + + self._store = store + + self._state_holder = state_holder + + self._op_executor = _DistributedRendezvousOpExecutor( + self._this_node, self._state_holder, self._settings + ) + + self._heartbeat_lock = threading.Lock() + + self._keep_alive_timer = None + + # Cached shared store server reference + self._shared_tcp_store_server: Optional[dist.Store] = None + + self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None + + def _record( + self, + message: str, + node_state: NodeState = NodeState.RUNNING, + rank: Optional[int] = None, + ) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._this_node.addr, + pid=self._this_node.pid, + local_id=self._this_node.local_id, + rank=rank, + ) + + def _create_tcp_store_server(self, master_addr, master_port) -> dist.TCPStore: + return dist.TCPStore( + host_name=master_addr, + port=master_port, + is_master=True, + multi_tenant=True, + ) + + @property + def settings(self) -> RendezvousSettings: + """Get the settings of the rendezvous.""" + return self._settings + + def get_backend(self) -> str: + """See base class.""" + return self._backend_name + + @property + def use_agent_store(self) -> bool: + """See base class.""" + return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1" + + def next_rendezvous(self) -> RendezvousInfo: + """See base class.""" + msg = ( + f"The node '{self._this_node}' attempts to join the next round of the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.info(msg) + + try: + self._stop_heartbeats() + + # Delay the execution for a small random amount of time if this is our + # first run. This will slightly skew the rendezvous attempts across the + # nodes and reduce the load on the backend. + if self._state_holder.state.round == 0: + _delay(seconds=(0, 0.3)) + + exit_op = _RendezvousExitOp() + join_op = _RendezvousJoinOp() + + deadline = self._get_deadline(self._settings.timeout.join) + self._op_executor.run(exit_op, deadline) + self._op_executor.run(join_op, deadline, self._get_deadline) + + self._start_heartbeats() + + rank, world_size = self._get_world() + store = self._get_store() + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + msg = ( + f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of " + f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size " + f"{world_size}." + ) + self._record(message=msg, rank=rank) + logger.info(msg) + + # opt-out option of TCPStore sharing + if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) + return RendezvousInfo( + store, + rank, + world_size, + bootstrap_store_info, + ) + + # This will only be hit when TCPStore sharing is enabled. + if self._bootstrap_store_info is None: + # To avoid race in get_free_port because we release the port after the call, + # we want to create a TCPStore server soon afterwards. 
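+            # Summary (comment added for clarity; relies on the behavior of
+            # RendezvousStoreInfo.build): rank 0 binds a multi-tenant TCPStore
+            # on a free port and shares its address and port through the
+            # rendezvous store, while the other ranks only read that
+            # information and never bind a port themselves.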
+ server_port = 0 + if rank == 0: + self._shared_tcp_store_server = self._create_tcp_store_server( + self._this_node.addr, server_port + ) + server_port = self._shared_tcp_store_server.port + self._bootstrap_store_info = RendezvousStoreInfo.build( + rank, + store, + local_addr=self._this_node.addr, + server_port=server_port, # For non-0 rank, this is a no-op + ) + + assert self._bootstrap_store_info is not None + if rank == 0: + assert self._shared_tcp_store_server is not None + + return RendezvousInfo( + store, + rank, + world_size, + self._bootstrap_store_info, # type: ignore[assignment] + ) + + def is_closed(self) -> bool: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return self._state_holder.state.closed + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def set_closed(self) -> None: + """See base class.""" + try: + with self._heartbeat_lock: + self._close() + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def num_nodes_waiting(self) -> int: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return len(self._state_holder.state.wait_list) + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def get_run_id(self) -> str: + """See base class.""" + return self._settings.run_id + + def shutdown(self) -> bool: + """See base class.""" + self._stop_heartbeats() + + try: + self._close() + + return True + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to shutdown the rendezvous " + f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + + return False + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def _close(self) -> None: + op = _RendezvousCloseOp() + + deadline = self._get_deadline(self._settings.timeout.close) + + self._op_executor.run(op, deadline) + + msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'." + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.info(msg) + + @staticmethod + def _keep_alive_weak(weak_self) -> None: + self = weak_self() + if self is not None: + self._keep_alive() + + def _keep_alive(self) -> None: + self._heartbeat_lock.acquire() + + op = _RendezvousKeepAliveOp() + + deadline = self._get_deadline(self._settings.timeout.heartbeat) + + try: + self._op_executor.run(op, deadline) + + msg = ( + f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.debug(msg) + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " + f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." 
+ ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + finally: + self._heartbeat_lock.release() + + def _start_heartbeats(self) -> None: + self._keep_alive_timer = _PeriodicTimer( + self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) + ) + + self._keep_alive_timer.set_name( + f"RendezvousKeepAliveTimer_{self._this_node.local_id}" + ) + + self._keep_alive_timer.start() + + def _stop_heartbeats(self) -> None: + if self._keep_alive_timer is None: + return + + self._keep_alive_timer.cancel() + + def _get_world(self) -> tuple[int, int]: + state = self._state_holder.state + + return state.participants[self._this_node], len(state.participants) + + def _wrap_store(self, store: Store) -> Store: + key_prefix = ( + f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + ) + + return dist.PrefixStore(key_prefix, store) + + def _get_store(self) -> Store: + return self._wrap_store(self._store) + + def _get_deadline(self, timeout: timedelta) -> float: + return time.monotonic() + timeout.total_seconds() + + +def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]: + timeout = params.get_as_int(key + "_timeout") + if timeout is None: + return None + return timedelta(seconds=timeout) + + +def create_handler( + store: Store, backend: RendezvousBackend, params: RendezvousParameters +) -> DynamicRendezvousHandler: + """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters. + + Args: + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + + +-------------------+------------------------------------------------------+ + | Parameter | Description | + +===================+======================================================+ + | join_timeout | The total time, in seconds, within which the | + | | rendezvous is expected to complete. Defaults to 600 | + | | seconds. | + +-------------------+------------------------------------------------------+ + | last_call_timeout | An additional wait amount, in seconds, before | + | | completing the rendezvous once the minimum number of | + | | nodes has been reached. Defaults to 30 seconds. | + +-------------------+------------------------------------------------------+ + | close_timeout | The time, in seconds, within which the rendezvous is | + | | expected to close after a call to | + | | :py:meth:`RendezvousHandler.set_closed` or | + | | :py:meth:`RendezvousHandler.shutdown`. Defaults to | + | | 30 seconds. 
| + +-------------------+------------------------------------------------------+ + | heartbeat | The time, in seconds, within which a keep-alive | + | | heartbeat is expected to complete | + +-------------------+------------------------------------------------------+ + """ + try: + timeout = RendezvousTimeout( + _get_timeout(params, "join"), + _get_timeout(params, "last_call"), + _get_timeout(params, "close"), + _get_timeout(params, "heartbeat"), + ) + keep_alive_interval = params.get_as_int("keep_alive_interval", 5) + if keep_alive_interval is None: + raise TypeError( + "You passed 'keep_alive_interval=None' as a rendezvous configuration option" + ) + keep_alive_max_attempt = params.get_as_int("keep_alive_max_attempt", 3) + if keep_alive_max_attempt is None: + raise TypeError( + "You passed 'keep_alive_max_attempt=None' as a rendezvous configuration option" + ) + + return DynamicRendezvousHandler.from_backend( + params.run_id, + store, + backend, + params.min_nodes, + params.max_nodes, + params.local_addr, + timeout, + keep_alive_interval=keep_alive_interval, + keep_alive_max_attempt=keep_alive_max_attempt, + ) + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4a68ba40dc41710a22768db8db8f2d0d575ebc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -0,0 +1,1081 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import sys +import threading +import time +from typing import Optional + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + +from torch.distributed.elastic.rendezvous import ( + RendezvousClosedError, + RendezvousError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, + RendezvousTimeoutError, +) + +from .etcd_store import cas_delay, EtcdStore +from .utils import parse_rendezvous_endpoint + + +__all__ = [ + "EtcdRendezvousRetryableFailure", + "EtcdRendezvousRetryImmediately", + "EtcdRendezvousHandler", + "EtcdRendezvous", + "create_rdzv_handler", +] + +_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") +_log_handler = logging.StreamHandler(sys.stderr) +_log_handler.setFormatter(_log_fmt) + +logger = logging.getLogger(__name__) +logger.propagate = False +logger.setLevel(logging.INFO) +logger.addHandler(_log_handler) + + +# Retryable failure exception means the we were too late to make +# a desired state transition (e.g. because of a race condition), +# and should now restart from the beginning. +# A small delay is recommended to avoid spamming Etcd. +class EtcdRendezvousRetryableFailure(Exception): + pass + + +# Similar to retryable failure, but the new state we observed suggests we +# can re-try immediately, i.e. without a need for "safety delay". +class EtcdRendezvousRetryImmediately(Exception): + pass + + +# Default timeout for the rendezvous. 
+_DEFAULT_TIMEOUT: int = 600 # 10 minutes + +# Additional waiting time after reaching the minimum number of nodes +# in case the rendezvous is elastic (min != max). +_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds + +# Various constants used internally in EtcdRendezvous +CONST_ETCD_SETUP_TTL = 5 +CONST_ETCD_FROZEN_TTL = 10 +CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10 + +# Ephemeral node TTL for worker's keep-alive key: +CONST_WORKER_KEEPALIVE_TTL = 10 + +# TTL for the ephemeral run_id-specific directory. All rendezvous state data +# for a specific run_id (job instance) is contained within directory. +# Its only role is to clean-up rendezvous data from old runs (for the case when +# etcd server is persistent), and has no affect on correctness, but should be +# larger than any timeouts that a worker process is expected to survive: +CONST_RUNID_SUBROOT_TTL = 7200 # 2 hours + + +class EtcdRendezvousHandler(RendezvousHandler): + """ + Implements a + :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` interface + backed by + :py:class:`torch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`. + ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to + use and to pass implementation specific configurations to the rendezvous + module. The basic etcd rendezvous configuration URL looks like the following + :: + + etcd://:/?min_workers=&max_workers= # noqa: W605 + + -- example -- + + etcd://localhost:2379/1234?min_workers=1&max_workers=3 + + The URL above is interpreted as follows: + + 1. Use the rendezvous handler that is registered with the ``etcd`` + scheme + 2. The ``etcd`` endpoint to use is ``localhost:2379`` + 3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to + share a common etcd server for multiple jobs so long as the + ``job_ids`` are guaranteed to be unique). Note that the job id can be + any string (e.g. does not need to be a number) as long as it is + unique. + 4. ``min_workers=1`` and ``max_workers=3`` specifies a range for + membership size - Torch Distributed Elastic starts running the job as + long as the cluster size is greater than or equal to ``min_workers`` + and admits up to ``max_workers`` into the cluster. 
+ + Below are a full list of the parameters that can be passed to etcd + rendezvous: + + +--------------------------------------------+--------------------------+ + | Parameter | Description | + +============================================+==========================+ + | min_workers | minimum number of | + | | workers for the | + | | rendezvous to be valid | + +--------------------------------------------+--------------------------+ + | max_workers | maximum number of | + | | workers to admit | + +--------------------------------------------+--------------------------+ + | timeout | total timeout within | + | | which next_rendezvous is | + | | expected to succeed | + | | (default 600s) | + +--------------------------------------------+--------------------------+ + | last_call_timeout | additional wait amount | + | | ("last call") after min | + | | number of workers has | + | | been reached (defaults | + | | to 30s) | + +--------------------------------------------+--------------------------+ + | etcd_prefix | path prefix (from etcd | + | | root), inside which all | + | | etcd nodes will be | + | | created (defaults to | + | | ``/torchelastic/p2p``) | + +--------------------------------------------+--------------------------+ + """ + + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): + """ + Args: + rdzv_impl: the implementation of the rendezvous + local_addr: the local address of the current node + """ + + self._rdzv_impl = rdzv_impl + self._local_addr = local_addr + + def __del__(self): + # TODO: look into using weakref here instead. + del self._rdzv_impl + + def get_backend(self) -> str: + return "etcd" + + def next_rendezvous(self): + rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier() + + logger.info("Creating EtcdStore as the c10d::Store implementation") + store = self._rdzv_impl.setup_kv_store(rdzv_version) + + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._local_addr + ) + return RendezvousInfo(store, rank, world_size, bootstrap_store_info) + + def is_closed(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + return state["status"] == "closed" + except etcd.EtcdKeyNotFound: + # No rendezvous state, so it cannot be closed. + return False + + def set_closed(self): + self._rdzv_impl.set_closed() + + def num_nodes_waiting(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + if state["status"] == "final": + return state["num_workers_waiting"] + except etcd.EtcdKeyNotFound: + pass + return 0 + + def get_run_id(self) -> str: + return self._rdzv_impl._run_id + + def shutdown(self) -> bool: + try: + self.set_closed() + return True + except BaseException as e: + logger.warning("Shutdown failed. Error occurred: %s", str(e)) + return False + + +# TODO: we should probably handle a few additional errors, +# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are +# only relevant for multi-node Etcd ensemble. A simple retry would work, +# but is verbose to add everywhere. Consider wrapping the client calls +# into auto-retry for these errors? 
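+# Rough lifecycle of the etcd-backed rendezvous (summary comment, not part of
+# the upstream source): the active_version record moves through
+# "setup" -> "joinable" -> "frozen" -> "final", or to "closed" once the job is
+# done. Workers join while the state is "joinable", confirm membership while
+# it is "frozen", and announce themselves as waiting when they find an
+# existing "final" rendezvous.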
+# +class EtcdRendezvous: + """A rendezvous implementation that uses `etcd `__ as the backend store.""" + + def __init__( + self, + client, + prefix, + run_id, + num_min_workers, + num_max_workers, + timeout, + last_call_timeout, + ): + self.client = client + logger.info("Etcd machines: %s", self.client.machines) + + self._prefix = prefix + self._run_id = run_id + self._num_min_workers = num_min_workers + self._num_max_workers = num_max_workers + self._timeout = timeout + self._last_call_timeout = last_call_timeout + + # For cleaning up TTL refresher threads (for ephemeral keys) + self._lease_run_id_stop = None + self._lease_this_rank_stop = None + + if not self._prefix.endswith("/"): + self._prefix += "/" + + # Setup a permanent prefix dir, if didn't exist + if self._prefix != "/": + self.create_path_if_not_exists(self._prefix) + + # Lease a "sub-root" node specific to this job instance (run_id) + self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL) + self._lease_run_id_stop = self.setup_lease_renewal( + self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL + ) + + # Subdir for all rendezvous work + self.create_path_if_not_exists(self.get_path("/rdzv")) + + # Create a rendezvous version counter, if doesn't exist + try: + self.client.write( + key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False + ) + except etcd.EtcdAlreadyExist: + pass + + def __del__(self): + # TODO: look into using weakref here instead. + if self._lease_run_id_stop is not None: + self._lease_run_id_stop.set() + + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + def rendezvous_barrier(self): + """ + Main entry point for next rendezvous. + + This method is blocking until rendezvous succeeds or a timeout occurs. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousTimeoutError - timeout waiting for rendezvous + RendezvousClosedError - rendezvous is or was closed while waiting + RendezvousError - other persistent errors that + render the rendezvous non-retryable + """ + self._rendezvous_deadline = time.time() + self._timeout + while True: + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + logger.info("Attempting to join next rendezvous") + try: + # Dis-own our lease in the previous rendezvous, if exists + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + return self.init_phase() + + except EtcdRendezvousRetryImmediately: + # The type of failure suggests we can retry without delay + pass + + except EtcdRendezvousRetryableFailure: + # In case of retryable failure, wait a small delay + # to avoid spamming etcd + time.sleep(1) + + except RendezvousTimeoutError: + logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler") + raise + + except RendezvousClosedError: + logger.info( + "Rendezvous for run_id=%s was observed to be closed", self._run_id + ) + raise + + except RendezvousError: + raise + + except Exception as e: + # In case of a general exception, wait a small delay + # to avoid spamming etcd + # FIXME: there are a few things that fall under this like + # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly. + logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) + time.sleep(1) + + def init_phase(self): + """ + Initially, the rendezvous state is expected to be one of: + + 1. empty (non-existent) - in this case we try to create a new one. + 2. joinable - we try to join it. + 3. 
final - we announce ourselves as waiting, and go into monitoring mode + + Any other state is considered transitional, and will be retried after + a short delay. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousClosedError - current rendezvous was/is closed + EtcdRendezvousRetryableFailure - observed some intermediate + state, which is best handled by retrying later + """ + try: + active_version = self.try_create_rendezvous() + state = json.loads(active_version.value) + logger.info("New rendezvous state created: %s", state) + except etcd.EtcdAlreadyExist: + active_version, state = self.get_rdzv_state() + # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound), + # but this is ok for us - just means we'll restart from beginning. + logger.info("Observed existing rendezvous state: %s", state) + + if state["status"] == "closed": + raise RendezvousClosedError + + if state["status"] == "joinable": + return self.join_phase(state["version"]) + + if state["status"] == "final": + self.handle_existing_rendezvous(state["version"]) + raise EtcdRendezvousRetryImmediately + + self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1) + raise EtcdRendezvousRetryableFailure + + def join_phase(self, expected_version): + """ + We observed a rendezvous state in 'joinable' state, and attempt to join this + particular version, and then wait for all other peers to join. + """ + # Failure to join will propagate an exception, causing a re-entry. + active_version, this_rank = self.join_rendezvous(expected_version) + state = json.loads(active_version.value) + logger.info( + "Joined rendezvous version %s as rank %s. Full state: %s", + state["version"], + this_rank, + state, + ) + + # If this worker was first to reach num_min_workers requirement, + # and rendezvous is still joinable (therefore it is elastic), + # then this worker will be responsible for waiting out the "last call" + # timeout and closing (i.e. transitioning to 'frozen') the rendezvous + # afterwards. + # As a safety against a potential failure of this worker (during the + # last call timeout), the rendezvous state is made ephemeral + # when min_num_workers is reached. + + if this_rank == self._num_min_workers - 1 and state["status"] == "joinable": + logger.info("Rank %s is responsible for join last call.", this_rank) + last_call_deadline = time.time() + self._last_call_timeout + self.handle_join_last_call(expected_version, last_call_deadline) + logger.info("Rank %s finished join last call.", this_rank) + + # Wait for rendezvous state to be frozen, which means a fixed set of peers + logger.info("Waiting for remaining peers.") + active_version = self.wait_for_peers(expected_version) + state = json.loads(active_version.value) + + assert state["version"] == expected_version, ( + "Logic error: failed to observe version mismatch" + ) + + return self.confirm_phase(expected_version, this_rank) + + def confirm_phase(self, expected_version, this_rank): + """ + Once the rendezvous state transitions from 'joinable' to 'frozen', + we have every participant confirm their membership and setup per-member + keep-alive TTL keys, and then wait for all other participants to confirm, + which would then successfully conclude this rendezvous. + """ + logger.info("All peers arrived. 
Confirming membership.") + self.confirm_membership(expected_version, this_rank) + + logger.info("Waiting for confirmations from all peers.") + active_version = self.wait_for_final(expected_version) + state = json.loads(active_version.value) + + logger.info( + "Rendezvous version %s is complete. Final state: %s", + state["version"], + state, + ) + + # Rendezvous version number; our rank in it; world size + return state["version"], this_rank, len(state["participants"]) + + def handle_existing_rendezvous(self, expected_version): + """ + Handle the case when there's an existing (state 'final) rendezvous already + in place, and we have to announce ourselves waiting, and wait until + the next rendezvous opportunity. + """ + # If state is 'final' -> increment num_workers_waiting + # Then, observe state changes: + # 1. if it's no longer final -> bail out and re-try + # 2. if keep alives are missing, destroy it and bail out. + active_state = self.announce_self_waiting(expected_version) + logger.info( + "Added self to waiting list. Rendezvous full state: %s", active_state.value + ) + + self.wait_for_rendezvous_to_free(expected_version) + logger.info( + "Previously existing rendezvous state changed. Will re-try joining." + ) + + def try_create_rendezvous(self): + """ + Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists). + + Raises: + RendezvousError - on unexpected state + """ + # Initially active_version is ephemeral - this is to handle the + # possibility that might fail to complete the setup transaction, + # i.e. the transition "setup" -> "joinable". + active_version = self.client.write( + key=self.get_path("/rdzv/active_version"), + value=json.dumps({"status": "setup"}), + prevExist=False, + ttl=CONST_ETCD_SETUP_TTL, + ) + + try: + version_counter = self.client.get(self.get_path("/rdzv/version_counter")) + version_counter.value = str(int(version_counter.value) + 1) + self.client.update(version_counter) + except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e: + raise RendezvousError( + "Unexpected state of EtcdRendezvousHandler, worker needs to die." + ) from e + + # Any failure below results in declaring a retryable rendezvous failure. + # The ephemeral /rdzv/active_version will expire and someone can then + # re-try the setup process. + + # Create directory node for participant data + self.client.write( + key=self.get_path(f"/rdzv/v_{version_counter.value}"), + value=None, + dir=True, + prevExist=False, + ) + + # Publish rendezvous version and signal it is ready-to-be-joined. + # If rendezvous was set closed just before this, a retry will happen, + # where the closed condition will be handled. + return self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps( + { + "status": "joinable", + "version": version_counter.value, + "participants": [], + } + ), + prev_value=active_version.value, + ) + + def join_rendezvous(self, expected_version): + """Helper method for the join phase.""" + # Use compare-and-swap to add self to rendezvous state: + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "joinable": + raise EtcdRendezvousRetryableFailure( + "Rendezvous state became non-joinable before we could join. " + "Must join next one." + ) + + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." 
+ ) + + assert len(state["participants"]) < self._num_max_workers, ( + "Logic error: joinable rendezvous should always have space left" + ) + + this_rank = len(state["participants"]) + state["participants"].append(this_rank) + + # When reaching min workers, or changing state to frozen, we'll set + # the active_version node to be ephemeral. + set_ttl: Optional[int] = None + if len(state["participants"]) == self._num_max_workers: + state["status"] = "frozen" + state["keep_alives"] = [] + set_ttl = CONST_ETCD_FROZEN_TTL + elif len(state["participants"]) >= self._num_min_workers: + set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL + + try: + # Compare-and-swap. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=set_ttl, + ) + # We succeeded joining. + return active_version, this_rank + + except etcd.EtcdCompareFailed: + logger.info("Join rendezvous CAS unsuccessful, retrying") + + def wait_for_peers(self, expected_version): + """Helper method for the join phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Success, all peers arrived. + return active_version + + elif state["status"] == "joinable" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def confirm_membership(self, expected_version, this_rank): + """Helper method for the confirm phase.""" + # Compare-and-swap loop + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "frozen": + raise EtcdRendezvousRetryImmediately( + "Rendezvous no longer frozen, before we confirmed. " + "Must join next one" + ) + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + this_lease_key = self.get_path( + f"/rdzv/v_{expected_version}/rank_{this_rank}" + ) + self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL) + + state["keep_alives"].append(this_lease_key) + if len(state["keep_alives"]) == len(state["participants"]): + # Everyone confirmed (this rank is last to do so) + state["status"] = "final" + state["num_workers_waiting"] = 0 + finalize = True + else: + finalize = False + + try: + # Compare-and-swap. If new state is still frozen, keep it ephemeral. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=None if finalize else CONST_ETCD_FROZEN_TTL, + ) + + self._lease_this_rank_stop = self.setup_lease_renewal( + this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Confirm membership CAS unsuccessful, retrying") + + def wait_for_final(self, expected_version): + """Helper method for the confirm phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "final" and state["version"] == expected_version: + # Success. This rendezvous is final, and we accept it. 
+ return active_version + + elif state["status"] == "frozen" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def announce_self_waiting(self, expected_version): + """ + Announce this worker is waiting (via num_workers_waiting counter) to join next + rendezvous, but only if state and version match. + """ + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "final" or state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately + + # Increment counter to signal an additional waiting worker. + state["num_workers_waiting"] += 1 + + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Announce self as waiting CAS unsuccessful, retrying") + + def wait_for_rendezvous_to_free(self, expected_version): + """ + When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join. + + Such opportunity may come from: + + 1. rendezvous state changed by someone else, in which case we unblock and retry. + 2. rendezvous becomes invalid because at least one member failed to renew their + leased keep_alive node. We detect this, and destroy the rendezvous. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] != "final" or state["version"] != expected_version: + return + + # Check if current rendezvous state is valid, in the sense that all + # its members are alive (renewing their lease). + # If not, try destroy this rendezvous, so a new one can be created. + alive_members = self.client.get( + self.get_path(f"/rdzv/v_{expected_version}") + ) + keep_alive_keys = [ch.key for ch in alive_members.children] + + for key in state["keep_alives"]: + if key not in keep_alive_keys: + # This participant didn't renew their lease. We'll declare this + # rendezvous version as dead (but only if it hadn't changed) + logger.info("Keep-alive key %s is not renewed.", key) + logger.info( + "Rendezvous version %s is incomplete. ", expected_version + ) + logger.info("Attempting to destroy it.") + + # Compare-and-delete operation. Throws if compare failed, + # which means rendezvous was already destroyed/re-created/closed, + # and we can try to re-enter the barrier. + self.client.delete( + key=self.get_path("/rdzv/active_version"), + prevValue=active_version.value, + ) + + logger.info( + "Destroyed rendezvous version %s successfully.", + expected_version, + ) + + # We can return (and retry) immediately + return + + # Existing rendezvous seems valid, no reason to destroy it. + # We just have to wait until something changes and re-check. 
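+            # (The watch below blocks until something under /rdzv changes or
+            # the watch timeout, derived from the overall rendezvous deadline,
+            # expires; the deadline itself is re-checked right after.)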
+ try: + overall_timeout = ( + max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + ) + self.client.watch( + key=self.get_path("/rdzv"), + index=active_version.etcd_index + 1, + recursive=True, + timeout=overall_timeout, + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + active_version, state = self.get_rdzv_state() + + def handle_join_last_call(self, expected_version, deadline): + """ + After we reach min number of workers, one particular worker takes on the + responsibility of waiting an additional timeout before closing the join window. + If the worker responsible for this fails, the rendezvous will be destroyed due + to expiring TTL, and the other participants will re-rendezvous. + + Here we expect to see state + Exit gracefully if either: + + 1. state becomes + 2. timeout happens (reaching deadline), in which case + we try the transition to + + Exit with exception otherwise. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Worker set became frozen before last-call timeout. This is possible + # when num_max_workers is reached before the timeout. + return + + if state["status"] != "joinable" or state["version"] != expected_version: + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + # If timeout occurred, attempt a state transition (joinable -> frozen) + if time.time() >= deadline: + state["status"] = "frozen" + state["keep_alives"] = [] + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=CONST_ETCD_FROZEN_TTL, + ) + # We successfully made this rendezvous frozen. + return + except etcd.EtcdCompareFailed: + logger.info( + "Join last-call transition CAS unsuccessful. Will retry" + ) + cas_delay() + active_version, state = self.get_rdzv_state() + continue + + # Timeout did not occur, so we must refresh TTL, and wait for + # further changes. Note: we only want TTL to be refreshed if + # state is still joinable, hence we use CAS for that here, + # even though we don't change any of the data. + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=active_version.value, + prev_value=active_version.value, + ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL, + ) + + # Minimize "oversleeping": + timeout = min( + CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2, + deadline - time.time() + 1.0, # Oversleeping by 1s is ok. + ) + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1, timeout=timeout + ) + except etcd.EtcdCompareFailed: + logger.info("Join last-call TTL refresh CAS unsuccessful, will retry") + cas_delay() + active_version, state = self.get_rdzv_state() + + def set_closed(self): + """ + Mark rendezvous 'closed' for current run_id, which is used to signal other + participants to not attempt to perform (re-)rendezvous. This is useful + when one of the workers decides the job is complete. + """ + while True: + active_version, state = self.get_rdzv_state() + + if state["status"] == "closed": + # Already closed by someone else. 
+ return + + state["status"] = "closed" + try: + self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return + + except etcd.EtcdCompareFailed: + logger.info("Set closed CAS unsuccessful, retrying") + cas_delay() + + def get_rdzv_state(self): + active_version = self.client.get(key=self.get_path("/rdzv/active_version")) + return active_version, json.loads(active_version.value) + + def try_wait_for_state_change(self, etcd_index, timeout=None): + # Don't sleep past the overall deadline (at least more than by 1s) + overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + timeout = overall_timeout if timeout is None else min(timeout, overall_timeout) + + try: + self.client.watch( + self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + # Unfortunately, we have to do another fetch in order to get last etcd_index. + return self.get_rdzv_state() + + def get_path(self, path): + if not path.startswith("/"): + path = "/" + path + + return f"{self._prefix}run_{self._run_id}{path}" + + def create_path_if_not_exists(self, full_path, ttl=None): + try: + self.client.write( + key=full_path, value=None, dir=True, prevExist=False, ttl=ttl + ) + except etcd.EtcdAlreadyExist: + pass + + def setup_lease_renewal(self, full_path, ttl): + # NOTE: For ephemeral key TTL renewal (~lease) to work correctly, + # make sure you don't call any long-blocking methods that do not + # release the Python's GIL! An example of this is calling a pybind11 + # extension function that is blocking / long-running, but is not + # doing a scoped release of the GIL. + def lease_worker(client, path, ttl, stop_event): + while True: + try: + client.refresh(path, ttl=ttl) + except etcd.EtcdKeyNotFound: + break + except ConnectionRefusedError: + # This error usually occurs during test when the server already got terminated but the + # python garbage collector have not yet invoked the __del__ method. + break + + if stop_event.wait(timeout=ttl / 2): + break + + lease_stop_event = threading.Event() + lease_thread = threading.Thread( + target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event) + ) + + lease_thread.daemon = True + lease_thread.start() + + return lease_stop_event + + def store_extra_data(self, rdzv_version, key, value): + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + try: + # If first time we are storing anything: + extra_data = self.client.write( + key=node, value=json.dumps({key: value}), prevExist=False + ) + return + except etcd.EtcdAlreadyExist: + pass + + # CAS loop, to make sure we don't lose concurrent stores. + while True: + # We never delete extra_data. Failure here should be fatal, no special handling. 
+ extra_data = self.client.get(node) + + new_extra_data_value = json.loads(extra_data.value) + new_extra_data_value[key] = value + + try: + extra_data = self.client.test_and_set( + key=node, + value=json.dumps(new_extra_data_value), + prev_value=extra_data.value, + ) + return + except etcd.EtcdCompareFailed: + logger.info("Store extra_data CAS unsuccessful, retrying") + time.sleep(0.1) + + def load_extra_data(self, rdzv_version, key, timeout=None): + # 'extra_data' node itself, and the directory it is located in: + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + node_dir = self.get_path(f"/rdzv/v_{rdzv_version}") + + # TODO: implement timeout + # https://github.com/pytorch/elastic/issues/12 + while True: + # Combined wait for the node itself, and the key inside it. + root = self.client.get(node_dir) + + # Find the extra_data node, if it exists + extra_data = [n for n in root.children if n.key == node] + assert len(extra_data) <= 1 + + # Node for extra_data exists, check the desired key inside it. + if len(extra_data) == 1: + extra_data_dict = json.loads(extra_data[0].value) + if key in extra_data_dict: + return extra_data_dict[key] + + # The 'extra_data' node doesn't exist, or they key isn't published yet. + # Wait for interesting events on the extra_data node and retry. + try: + self.client.watch(node, index=root.etcd_index + 1) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + def setup_kv_store(self, rdzv_version): + store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv") + self.create_path_if_not_exists(store_path) + return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path) + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.""" + hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379) + + # The communication protocol + protocol = params.config.get("protocol") + if protocol is None: + protocol = "http" + else: + if protocol != "http" and protocol != "https": + raise ValueError("The etcd protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.config.get("cert") + if ssl_cert is not None: + cert_key = params.config.get("key") + if cert_key is not None: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. 
+ ssl_cert = (ssl_cert, cert_key) + + # The root certificate + ca_cert = params.config.get("cacert") + + return etcd.Client( + hostname, + port, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + + +# Handler for torch.distributed "static" registration +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Usage: + + :: + + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8, + timeout=300, + last_call_timeout=30, + etcd_prefix="custom_prefix", + protocol="https", + cacert="/etc/kubernetes/certs/ca.crt", + cert="/etc/kubernetes/certs/client.crt", + key="/etc/kubernetes/certs/client.key") + # -- or -- + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8) + + etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params) + + + Where: + run_id - unique id for this training job instance, + min_nodes - min number of workers expected to join the rendezvous, + max_nodes - max number of workers allowed to join the rendezvous, + defaults to min_workers is not specified. + timeout - total timeout within which next_rendezvous is expected to + succeed; a RendezvousTimeoutError is raised otherwise; + Defaults is 600 (10 minutes). + last_call_timeout - additional wait amount ("last call") after + min number of workers has been reached. + Defaults to 30 seconds. + etcd_prefix - path prefix (from etcd root), inside which all + etcd nodes will be created. + Default is "/torchelastic/p2p". + protocol - http (default) or https to access etcd. + cacert - CA cert to access etcd, only makes sense with https. + cert - client cert to access etcd, only makes sense with https. + key - client key to access etcd, only makes sense with https. + """ + client = _create_etcd_client(params) + + etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p") + + rdzv = EtcdRendezvous( + client=client, + prefix=etcd_prefix, + run_id=params.run_id, + num_min_workers=params.min_nodes, + num_max_workers=params.max_nodes, + timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), + last_call_timeout=params.get_as_int( + "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT + ), + ) + return EtcdRendezvousHandler( + rdzv_impl=rdzv, + local_addr=params.local_addr, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..c98504279f193bde4c92b9121b9d775cc4f099f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -0,0 +1,215 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +from base64 import b64decode, b64encode +from typing import cast, Optional + +import urllib3.exceptions # type: ignore[import] + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . 
import _etcd_stub as etcd + +from torch.distributed import Store + +from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError +from .dynamic_rendezvous import RendezvousBackend, Token +from .etcd_store import EtcdStore +from .utils import parse_rendezvous_endpoint + + +class EtcdRendezvousBackend(RendezvousBackend): + """Represents an etcd-based rendezvous backend. + + Args: + client: + The ``etcd.Client`` instance to use to communicate with etcd. + run_id: + The run id of the rendezvous. + key_prefix: + The path under which to store the rendezvous state in etcd. + ttl: + The TTL of the rendezvous state. If not specified, defaults to two hours. + """ + + _DEFAULT_TTL = 7200 # 2 hours + + _client: etcd.Client + _key: str + _ttl: int + + def __init__( + self, + client: etcd.Client, + run_id: str, + key_prefix: Optional[str] = None, + ttl: Optional[int] = None, + ) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._client = client + + if key_prefix: + self._key = key_prefix + "/" + run_id + else: + self._key = run_id + + if ttl and ttl > 0: + self._ttl = ttl + else: + self._ttl = self._DEFAULT_TTL + + @property + def name(self) -> str: + """See base class.""" + return "etcd-v2" + + def get_state(self) -> Optional[tuple[bytes, Token]]: + """See base class.""" + try: + result = self._client.read(self._key) + except etcd.EtcdKeyNotFound: + return None + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + return self._decode_state(result) + + def set_state( + self, state: bytes, token: Optional[Token] = None + ) -> Optional[tuple[bytes, Token, bool]]: + """See base class.""" + base64_state = b64encode(state).decode() + + kwargs = {} + + def get_state(): + result = self.get_state() + if result is not None: + tmp = *result, False + # Python 3.6 does not support tuple unpacking in return + # statements. + return tmp + return None + + if token: + try: + token = int(token) + except ValueError: + return get_state() + + if token: + kwargs["prevIndex"] = token + else: + kwargs["prevExist"] = False + + try: + result = self._client.write(self._key, base64_state, self._ttl, **kwargs) + except (etcd.EtcdAlreadyExist, etcd.EtcdCompareFailed): + result = None + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + if result is None: + return get_state() + + tmp = *self._decode_state(result), True + return tmp + + def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]: + base64_state = result.value.encode() + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." 
+ ) from exc + + return state, result.modifiedIndex + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # The communication protocol + protocol = params.get("protocol", "http").strip().lower() + if protocol != "http" and protocol != "https": + raise ValueError("The protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.get("ssl_cert") + if ssl_cert: + ssl_cert_key = params.get("ssl_cert_key") + if ssl_cert_key: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, ssl_cert_key) + + # The root certificate + ca_cert = params.get("ca_cert") + + try: + return etcd.Client( + host, + port, + read_timeout=read_timeout, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + +def create_backend(params: RendezvousParameters) -> tuple[EtcdRendezvousBackend, Store]: + """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | read_timeout | The read timeout, in seconds, for etcd operations. | + | | Defaults to 60 seconds. | + +--------------+-----------------------------------------------------------+ + | protocol | The protocol to use to communicate with etcd. Valid | + | | values are "http" and "https". Defaults to "http". | + +--------------+-----------------------------------------------------------+ + | ssl_cert | The path to the SSL client certificate to use along with | + | | HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ssl_cert_key | The path to the private key of the SSL client certificate | + | | to use along with HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ca_cert | The path to the rool SSL authority certificate. Defaults | + | | to ``None``. | + +--------------+-----------------------------------------------------------+ + """ + client = _create_etcd_client(params) + + backend = EtcdRendezvousBackend( + client, params.run_id, key_prefix="/torch/elastic/rendezvous" + ) + + store = EtcdStore(client, "/torch/elastic/store") + + return backend, store diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py new file mode 100644 index 0000000000000000000000000000000000000000..99623e0bb7bf17e7c1205b37e0d63820248724d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
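For reference, a hedged sketch of wiring the backend above into a dynamic rendezvous handler; the endpoint, run id, and node counts are illustrative values, not defaults:

::

    from torch.distributed.elastic.rendezvous import RendezvousParameters
    from torch.distributed.elastic.rendezvous.dynamic_rendezvous import create_handler
    from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import create_backend

    params = RendezvousParameters(
        backend="etcd-v2",
        endpoint="localhost:2379",  # assumed etcd endpoint
        run_id="example_job",
        min_nodes=1,
        max_nodes=4,
        read_timeout=60,            # forwarded to the etcd client
    )
    backend, store = create_backend(params)  # (EtcdRendezvousBackend, EtcdStore)
    handler = create_handler(store, backend, params)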
+import atexit +import logging +import os +import shlex +import shutil +import socket +import subprocess +import tempfile +import time +from typing import Optional, TextIO, Union + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + pass + + +logger = logging.getLogger(__name__) + + +def find_free_port(): + """ + Find a free port and binds a temporary socket to it so that the port can be "reserved" until used. + + .. note:: the returned socket must be closed before using the port, + otherwise a ``address already in use`` error will happen. + The socket should be held and closed as close to the + consumer of the port as possible since otherwise, there + is a greater chance of race-condition where a different + process may see the port as being free and take it. + + Returns: a socket binded to the reserved free port + + Usage:: + + sock = find_free_port() + port = sock.getsockname()[1] + sock.close() + use_port(port) + """ + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + + for addr in addrs: + family, type, proto, _, _ = addr + try: + s = socket.socket(family, type, proto) + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() # type: ignore[possibly-undefined] + print(f"Socket creation attempt failed: {e}") + raise RuntimeError("Failed to create a socket") + + +def stop_etcd(subprocess, data_dir: Optional[str] = None): + if subprocess and subprocess.poll() is None: + logger.info("stopping etcd server") + subprocess.terminate() + subprocess.wait() + + if data_dir: + logger.info("deleting etcd data dir: %s", data_dir) + shutil.rmtree(data_dir, ignore_errors=True) + + +class EtcdServer: + """ + .. note:: tested on etcd server v3.4.3. + + Starts and stops a local standalone etcd server on a random free + port. Useful for single node, multi-worker launches or testing, + where a sidecar etcd server is more convenient than having to + separately setup an etcd server. + + This class registers a termination handler to shutdown the etcd + subprocess on exit. This termination handler is NOT a substitute for + calling the ``stop()`` method. + + The following fallback mechanism is used to find the etcd binary: + + 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH + 2. Uses ``/bin/etcd`` if one exists + 3. Uses ``etcd`` from ``PATH`` + + Usage + :: + + server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd") + server.start() + client = server.get_client() + # use client + server.stop() + + Args: + etcd_binary_path: path of etcd server binary (see above for fallback path) + """ + + def __init__(self, data_dir: Optional[str] = None): + self._port = -1 + self._host = "localhost" + + root = os.path.dirname(__file__) + default_etcd_bin = os.path.join(root, "bin/etcd") + self._etcd_binary_path = os.environ.get( + "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin + ) + if not os.path.isfile(self._etcd_binary_path): + self._etcd_binary_path = "etcd" + + self._base_data_dir = ( + data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data") + ) + self._etcd_cmd = None + self._etcd_proc: Optional[subprocess.Popen] = None + + def _get_etcd_server_process(self) -> subprocess.Popen: + if not self._etcd_proc: + raise RuntimeError( + "No etcd server process started. 
Call etcd_server.start() first" + ) + else: + return self._etcd_proc + + def get_port(self) -> int: + """Return the port the server is running on.""" + return self._port + + def get_host(self) -> str: + """Return the host the server is running on.""" + return self._host + + def get_endpoint(self) -> str: + """Return the etcd server endpoint (host:port).""" + return f"{self._host}:{self._port}" + + def start( + self, + timeout: int = 60, + num_retries: int = 3, + stderr: Union[int, TextIO, None] = None, + ) -> None: + """ + Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests. + + Args: + timeout: time (in seconds) to wait for the server to be ready + before giving up. + num_retries: number of retries to start the server. Each retry + will wait for max ``timeout`` before considering it as failed. + stderr: the standard error file handle. Valid values are + `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file + descriptor (a positive integer), an existing file object, and + `None`. + + Raises: + TimeoutError: if the server is not ready within the specified timeout + """ + curr_retries = 0 + while True: + try: + data_dir = os.path.join(self._base_data_dir, str(curr_retries)) + os.makedirs(data_dir, exist_ok=True) + return self._start(data_dir, timeout, stderr) + except Exception as e: + curr_retries += 1 + stop_etcd(self._etcd_proc) + logger.warning( + "Failed to start etcd server, got error: %s, retrying", str(e) + ) + if curr_retries >= num_retries: + shutil.rmtree(self._base_data_dir, ignore_errors=True) + raise + atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir) + + def _start( + self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None + ) -> None: + sock = find_free_port() + sock_peer = find_free_port() + self._port = sock.getsockname()[1] + peer_port = sock_peer.getsockname()[1] + + etcd_cmd = shlex.split( + " ".join( + [ + self._etcd_binary_path, + "--enable-v2", + "--data-dir", + data_dir, + "--listen-client-urls", + f"http://{self._host}:{self._port}", + "--advertise-client-urls", + f"http://{self._host}:{self._port}", + "--listen-peer-urls", + f"http://{self._host}:{peer_port}", + ] + ) + ) + + logger.info("Starting etcd server: [%s]", etcd_cmd) + + sock.close() + sock_peer.close() + self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr) + self._wait_for_ready(timeout) + + def get_client(self): + """Return an etcd client object that can be used to make requests to this server.""" + return etcd.Client( + host=self._host, port=self._port, version_prefix="/v2", read_timeout=10 + ) + + def _wait_for_ready(self, timeout: int = 60) -> None: + client = etcd.Client( + host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5 + ) + max_time = time.time() + timeout + + while time.time() < max_time: + if self._get_etcd_server_process().poll() is not None: + # etcd server process finished + exitcode = self._get_etcd_server_process().returncode + raise RuntimeError( + f"Etcd server process exited with the code: {exitcode}" + ) + try: + logger.info("etcd server ready. version: %s", client.version) + return + except Exception: + time.sleep(1) + raise TimeoutError("Timed out waiting for etcd server to be ready!") + + def stop(self) -> None: + """Stop the server and cleans up auto generated resources (e.g. 
data dir).""" + logger.info("EtcdServer stop method called") + stop_etcd(self._etcd_proc, self._base_data_dir) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py new file mode 100644 index 0000000000000000000000000000000000000000..611e284627b254a79dc01cd259bcd316b72442e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py @@ -0,0 +1,216 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import random +import time +from base64 import b64decode, b64encode +from typing import Optional + +# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`. +from torch.distributed import Store + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + + +# Delay (sleep) for a small random amount to reduce CAS failures. +# This does not affect correctness, but will reduce requests to etcd server. +def cas_delay(): + time.sleep(random.uniform(0, 0.1)) + + +# pyre-fixme[11]: Annotation `Store` is not defined as a type. +class EtcdStore(Store): + """ + Implement a c10 Store interface by piggybacking on the rendezvous etcd instance. + + This is the store object returned by ``EtcdRendezvous``. + """ + + def __init__( + self, + etcd_client, + etcd_store_prefix, + # Default timeout same as in c10d/Store.hpp + timeout: Optional[datetime.timedelta] = None, + ): + super().__init__() # required for pybind trampoline. + + self.client = etcd_client + self.prefix = etcd_store_prefix + + if timeout is not None: + self.set_timeout(timeout) + + if not self.prefix.endswith("/"): + self.prefix += "/" + + def set(self, key, value): + """ + Write a key/value pair into ``EtcdStore``. + + Both key and value may be either Python ``str`` or ``bytes``. + """ + self.client.set(key=self.prefix + self._encode(key), value=self._encode(value)) + + def get(self, key) -> bytes: + """ + Get a value by key, possibly doing a blocking wait. + + If key is not immediately present, will do a blocking wait + for at most ``timeout`` duration or until the key is published. + + + Returns: + value ``(bytes)`` + + Raises: + LookupError - If key still not published after timeout + """ + b64_key = self.prefix + self._encode(key) + kvs = self._try_wait_get([b64_key]) + + if kvs is None: + raise LookupError(f"Key {key} not found in EtcdStore") + + return self._decode(kvs[b64_key]) + + def add(self, key, num: int) -> int: + """ + Atomically increment a value by an integer amount. + + The integer is represented as a string using base 10. If key is not present, + a default value of ``0`` will be assumed. + + Returns: + the new (incremented) value + + + """ + b64_key = self._encode(key) + # c10d Store assumes value is an integer represented as a decimal string + try: + # Assume default value "0", if this key didn't yet: + node = self.client.write( + key=self.prefix + b64_key, + value=self._encode(str(num)), # i.e. 0 + num + prevExist=False, + ) + return int(self._decode(node.value)) + except etcd.EtcdAlreadyExist: + pass + + while True: + # Note: c10d Store does not have a method to delete keys, so we + # can be sure it's still there. 
+ node = self.client.get(key=self.prefix + b64_key) + new_value = self._encode(str(int(self._decode(node.value)) + num)) + try: + node = self.client.test_and_set( + key=node.key, value=new_value, prev_value=node.value + ) + return int(self._decode(node.value)) + except etcd.EtcdCompareFailed: + cas_delay() + + def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None): + """ + Wait until all of the keys are published, or until timeout. + + Raises: + LookupError - if timeout occurs + """ + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get(b64_keys, override_timeout) + if kvs is None: + raise LookupError("Timeout while waiting for keys in EtcdStore") + # No return value on success + + def check(self, keys) -> bool: + """Check if all of the keys are immediately present (without waiting).""" + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get( + b64_keys, + override_timeout=datetime.timedelta(microseconds=1), # as if no wait + ) + return kvs is not None + + # + # Encode key/value data in base64, so we can store arbitrary binary data + # in EtcdStore. Input can be `str` or `bytes`. + # In case of `str`, utf-8 encoding is assumed. + # + def _encode(self, value) -> str: + if type(value) == bytes: + return b64encode(value).decode() + elif type(value) == str: + return b64encode(value.encode()).decode() + raise ValueError("Value must be of type str or bytes") + + # + # Decode a base64 string (of type `str` or `bytes`). + # Return type is `bytes`, which is more convenient with the Store interface. + # + def _decode(self, value) -> bytes: + if type(value) == bytes: + return b64decode(value) + elif type(value) == str: + return b64decode(value.encode()) + raise ValueError("Value must be of type str or bytes") + + # + # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys + # are published or timeout occurs. + # This is a helper method for the public interface methods. + # + # On success, a dictionary of {etcd key -> etcd value} is returned. + # On timeout, None is returned. + # + def _try_wait_get(self, b64_keys, override_timeout=None): + timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined] + deadline = time.time() + timeout.total_seconds() + + while True: + # Read whole directory (of keys), filter only the ones waited for + all_nodes = None + try: + all_nodes = self.client.get(key=self.prefix) + req_nodes = { + node.key: node.value + for node in all_nodes.children + if node.key in b64_keys + } + + if len(req_nodes) == len(b64_keys): + # All keys are available + return req_nodes + except etcd.EtcdKeyNotFound: + pass + + watch_timeout = deadline - time.time() + if watch_timeout <= 0: + return None + + try: + index = all_nodes.etcd_index + 1 if all_nodes else 0 + self.client.watch( + key=self.prefix, + recursive=True, + timeout=watch_timeout, + index=index, + ) + except etcd.EtcdWatchTimedOut: + if time.time() >= deadline: + return None + else: + continue + except etcd.EtcdEventIndexCleared: + continue diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/registry.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..476141b720a392fc2cf9f5c308b08f283440454a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/registry.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import sys + +from .api import ( + rendezvous_handler_registry as handler_registry, + RendezvousHandler, + RendezvousParameters, +) +from .dynamic_rendezvous import create_handler + + +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +log = logging.getLogger(__name__) + +__all__ = ["get_rendezvous_handler"] + + +def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import static_tcp_rendezvous + + return static_tcp_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import etcd_rendezvous + + return etcd_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler: + from .etcd_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler: + from .c10d_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _register_default_handlers() -> None: + handler_registry.register("etcd", _create_etcd_handler) + handler_registry.register("etcd-v2", _create_etcd_v2_handler) + handler_registry.register("c10d", _create_c10d_handler) + handler_registry.register("static", _create_static_handler) + + +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + +def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Obtain a reference to a :py:class`RendezvousHandler`. + + Custom rendezvous handlers can be registered by + + :: + + from torch.distributed.elastic.rendezvous import rendezvous_handler_registry + from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler + + + def create_my_rdzv(params: RendezvousParameters): + return MyCustomRdzv(params) + + + rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) + + my_rdzv_handler = get_rendezvous_handler( + "my_rdzv_backend_name", RendezvousParameters + ) + """ + return handler_registry.create_handler(params) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..5b87f4e9e73cc325e8e60904edf4c133c88cdf08 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
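A short usage sketch for the registry above; the backend name, endpoint, and node counts are illustrative, and the chosen backend must already be present in the handler registry (the defaults are added by ``_register_default_handlers``):

::

    from torch.distributed.elastic.rendezvous import RendezvousParameters
    from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

    params = RendezvousParameters(
        backend="c10d",              # any registered handler name
        endpoint="localhost:29400",  # assumed rendezvous endpoint
        run_id="example_job",
        min_nodes=1,
        max_nodes=1,
    )
    handler = get_rendezvous_handler(params)  # dispatches on params.backend
    rdzv_info = handler.next_rendezvous()     # blocks until the rendezvous completes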
+ +import datetime +import logging +from typing import cast, Optional + +from torch.distributed import PrefixStore, Store, TCPStore +from torch.distributed.elastic.rendezvous import ( + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, +) +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + + +__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] + +logger = logging.getLogger(__name__) + +_default_timeout_seconds = 600 + + +class StaticTCPRendezvous(RendezvousHandler): + """ + Static rendezvous that is a wrapper around the TCPStore. + + Creates TCPStore based on the input parameters with the + listener on the agent with group_rank=0 + """ + + def __init__( + self, + master_addr: str, + master_port: int, + rank: int, + world_size: int, + run_id: str, + timeout: int, + ): + self.master_addr = master_addr + self.master_port = master_port + self.rank = rank + self.world_size = world_size + self.run_id = run_id + self.timeout = datetime.timedelta(seconds=timeout) + self._store: Optional[Store] = None + + def get_backend(self) -> str: + return "static" + + @property + def use_agent_store(self) -> bool: + return True + + def next_rendezvous(self) -> RendezvousInfo: + logger.info("Creating TCPStore as the c10d::Store implementation") + is_master = self.rank == 0 + if not self._store: + self._store = TCPStore( # type: ignore[call-arg] + self.master_addr, + self.master_port, + self.world_size, + is_master, + self.timeout, + multi_tenant=True, + ) + store = PrefixStore(self.run_id, self._store) + # TCPStore server instance is used by trainer code + bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port) + return RendezvousInfo( + store, + self.rank, + self.world_size, + bootstrap_store_info, + ) + + def is_closed(self): + return False + + def set_closed(self): + pass + + def num_nodes_waiting(self): + return 0 + + def get_run_id(self) -> str: + return self.run_id + + def shutdown(self) -> bool: + return True + + +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + if "rank" not in params.config: + raise ValueError( + "rank is absent in RendezvousParameters." + "Try add --node-rank to the cmd request" + ) + endpoint = params.endpoint.strip() + if not endpoint: + raise ValueError( + "endpoint is absent in RendezvousParameters" + "Try add --master-port and --master-addr to the cmd request" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1) + if master_port == -1: + raise ValueError( + f"Port is absent in endpoint: {endpoint}. Try launching with --master-port" + ) + world_size = params.max_nodes + rank = cast(int, params.config.get("rank")) + run_id = params.run_id + if "timeout" in params.config: + timeout = int(params.config["timeout"]) + else: + timeout = _default_timeout_seconds + + return StaticTCPRendezvous( + master_addr, master_port, rank, world_size, run_id, timeout + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5036457ba81361d377c76a66a5c74ffb710fd725 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py @@ -0,0 +1,284 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ipaddress +import random +import re +import socket +import time +import weakref +from datetime import timedelta +from threading import Event, Thread +from typing import Any, Callable, Optional, Union + + +__all__ = ["parse_rendezvous_endpoint"] + + +def _parse_rendezvous_config(config_str: str) -> dict[str, str]: + """Extract key-value pairs from a rendezvous configuration string. + + Args: + config_str: + A string in format =,...,=. + """ + config: dict[str, str] = {} + + config_str = config_str.strip() + if not config_str: + return config + + key_values = config_str.split(",") + for kv in key_values: + key, *values = kv.split("=", 1) + + key = key.strip() + if not key: + raise ValueError( + "The rendezvous configuration string must be in format " + "=,...,=." + ) + + value: Optional[str] + if values: + value = values[0].strip() + else: + value = None + if not value: + raise ValueError( + f"The rendezvous configuration option '{key}' must have a value specified." + ) + + config[key] = value + return config + + +def _try_parse_port(port_str: str) -> Optional[int]: + """Try to extract the port number from ``port_str``.""" + if port_str and re.match(r"^[0-9]{1,5}$", port_str): + return int(port_str) + return None + + +def parse_rendezvous_endpoint( + endpoint: Optional[str], default_port: int +) -> tuple[str, int]: + """Extract the hostname and the port number from a rendezvous endpoint. + + Args: + endpoint: + A string in format [:]. + default_port: + The port number to use if the endpoint does not include one. + + Returns: + A tuple of hostname and port number. + """ + if endpoint is not None: + endpoint = endpoint.strip() + + if not endpoint: + return ("localhost", default_port) + + # An endpoint that starts and ends with brackets represents an IPv6 address. + if endpoint[0] == "[" and endpoint[-1] == "]": + host, *rest = endpoint, *[] + else: + host, *rest = endpoint.rsplit(":", 1) + + # Sanitize the IPv6 address. + if len(host) > 1 and host[0] == "[" and host[-1] == "]": + host = host[1:-1] + + if len(rest) == 1: + port = _try_parse_port(rest[0]) + if port is None or port >= 2**16: + raise ValueError( + f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " + "between 0 and 65536." + ) + else: + port = default_port + + if not re.match(r"^[\w\.:-]+$", host): + raise ValueError( + f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of " + "labels, an IPv4 address, or an IPv6 address." + ) + + return host, port + + +def _matches_machine_hostname(host: str) -> bool: + """Indicate whether ``host`` matches the hostname of this machine. + + This function compares ``host`` to the hostname as well as to the IP + addresses of this machine. Note that it may return a false negative if this + machine has CNAME records beyond its FQDN or IP addresses assigned to + secondary NICs. 
+ """ + if host == "localhost": + return True + + try: + addr = ipaddress.ip_address(host) + except ValueError: + addr = None + + if addr and addr.is_loopback: + return True + + try: + host_addr_list = socket.getaddrinfo( + host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + except (ValueError, socket.gaierror) as _: + host_addr_list = [] + + host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] + + this_host = socket.gethostname() + if host == this_host: + return True + + addr_list = socket.getaddrinfo( + this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + for addr_info in addr_list: + # If we have an FQDN in the addr_info, compare it to `host`. + if addr_info[3] and addr_info[3] == host: + return True + + # Otherwise if `host` represents an IP address, compare it to our IP + # address. + if addr and addr_info[4][0] == str(addr): + return True + + # If the IP address matches one of the provided host's IP addresses + if addr_info[4][0] in host_ip_list: + return True + + return False + + +def _delay(seconds: Union[float, tuple[float, float]]) -> None: + """Suspend the current thread for ``seconds``. + + Args: + seconds: + Either the delay, in seconds, or a tuple of a lower and an upper + bound within which a random delay will be picked. + """ + if isinstance(seconds, tuple): + seconds = random.uniform(*seconds) + # Ignore delay requests that are less than 10 milliseconds. + if seconds >= 0.01: + time.sleep(seconds) + + +class _PeriodicTimer: + """Represent a timer that periodically runs a specified function. + + Args: + interval: + The interval, in seconds, between each run. + function: + The function to run. + """ + + # The state of the timer is hold in a separate context object to avoid a + # reference cycle between the timer and the background thread. + class _Context: + interval: float + function: Callable[..., None] + args: tuple[Any, ...] + kwargs: dict[str, Any] + stop_event: Event + + _name: Optional[str] + _thread: Optional[Thread] + _finalizer: Optional[weakref.finalize] + + # The context that is shared between the timer and the background thread. + _ctx: _Context + + def __init__( + self, + interval: timedelta, + function: Callable[..., None], + *args: Any, + **kwargs: Any, + ) -> None: + self._name = None + + self._ctx = self._Context() + self._ctx.interval = interval.total_seconds() + self._ctx.function = function # type: ignore[assignment] + self._ctx.args = args or () + self._ctx.kwargs = kwargs or {} + self._ctx.stop_event = Event() + + self._thread = None + self._finalizer = None + + @property + def name(self) -> Optional[str]: + """Get the name of the timer.""" + return self._name + + def set_name(self, name: str) -> None: + """Set the name of the timer. + + The specified name will be assigned to the background thread and serves + for debugging and troubleshooting purposes. + """ + if self._thread: + raise RuntimeError("The timer has already started.") + + self._name = name + + def start(self) -> None: + """Start the timer.""" + if self._thread: + raise RuntimeError("The timer has already started.") + + self._thread = Thread( + target=self._run, + name=self._name or "PeriodicTimer", + args=(self._ctx,), + daemon=True, + ) + + # We avoid using a regular finalizer (a.k.a. __del__) for stopping the + # timer as joining a daemon thread during the interpreter shutdown can + # cause deadlocks. The weakref.finalize is a superior alternative that + # provides a consistent behavior regardless of the GC implementation. 
+ self._finalizer = weakref.finalize( + self, self._stop_thread, self._thread, self._ctx.stop_event + ) + + # We do not attempt to stop our background thread during the interpreter + # shutdown. At that point we do not even know whether it still exists. + self._finalizer.atexit = False + + self._thread.start() + + def cancel(self) -> None: + """Stop the timer at the next opportunity.""" + if self._finalizer: + self._finalizer() + + @staticmethod + def _run(ctx) -> None: + while not ctx.stop_event.wait(ctx.interval): + ctx.function(*ctx.args, **ctx.kwargs) + + @staticmethod + def _stop_thread(thread, stop_event): + stop_event.set() + + thread.join() diff --git a/phivenv/Lib/site-packages/torch/distributed/launch.py b/phivenv/Lib/site-packages/torch/distributed/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..ca27435c9e8e3a9a40e496cac6c8350bb4137835 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/launch.py @@ -0,0 +1,207 @@ +# mypy: allow-untyped-defs +r""" +Module ``torch.distributed.launch``. + +``torch.distributed.launch`` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +.. warning:: + + This module is going to be deprecated in favor of :ref:`torchrun `. + +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be beneficial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. + +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc-per-node``). If used for GPU training, this number needs to be less +or equal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. + +**How to use this module:** + +1. Single-Node multi-process distributed training + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +2. Multi-Node multi-process distributed training: (e.g. two nodes) + + +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +Node 2: + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +3. To look up what optional arguments this module offers: + +:: + + python -m torch.distributed.launch --help + + +**Important Notices:** + +1. 
This utility and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. + +2. In your training program, you must parse the command-line argument: +``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: + +Parsing the local_rank argument + +:: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +Set your device to local rank using either + +:: + + >>> torch.cuda.set_device(args.local_rank) # before your code runs + +or + +:: + + >>> with torch.cuda.device(args.local_rank): + >>> # your code to run + >>> ... + +.. versionchanged:: 2.0.0 + + The launcher will passes the ``--local-rank=`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. + + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, the launcher will trigger an error: "error: unrecognized arguments: + --local-rank=". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. + +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. It is strongly recommended +that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work, +but ``env://`` is the one that is officially supported by this module. + +:: + + >>> torch.distributed.init_process_group(backend='YOUR BACKEND', + >>> init_method='env://') + +4. In your training program, you can either use regular distributed functions +or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`torch.nn.parallel.DistributedDataParallel` module, +here is how to configure it. + +:: + + >>> model = torch.nn.parallel.DistributedDataParallel(model, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility + +5. Another way to pass ``local_rank`` to the subprocesses via environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use-env=True``. You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local-rank`` when you specify this flag. + +.. warning:: + + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly. 
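Putting notices 2, 3, and 5 together, a minimal training-script prologue when launching with ``--use-env`` could look like the following sketch (the backend choice is illustrative):

::

    >>> # xdoctest: +SKIP
    >>> import os
    >>> import torch
    >>> local_rank = int(os.environ["LOCAL_RANK"])  # set by the launcher with --use-env
    >>> torch.cuda.set_device(local_rank)
    >>> torch.distributed.init_process_group(backend="nccl", init_method="env://")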
+ + + +""" + +from typing_extensions import deprecated as _deprecated + +from torch.distributed.run import get_args_parser, run + + +def parse_args(args): + parser = get_args_parser() + parser.add_argument( + "--use-env", + "--use_env", + default=False, + action="store_true", + help="Use environment variable to pass " + "'local rank'. For legacy reasons, the default value is False. " + "If set to True, the script will not pass " + "--local-rank as argument, and will instead set LOCAL_RANK.", + ) + return parser.parse_args(args) + + +def launch(args): + if args.no_python and not args.use_env: + raise ValueError( + "When using the '--no-python' flag, you must also set the '--use-env' flag." + ) + run(args) + + +@_deprecated( + "The module torch.distributed.launch is deprecated\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use-env is set by default in torchrun.\n" + "If your script expects `--local-rank` argument to be set, please\n" + "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" + "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" + "further instructions\n", + category=FutureWarning, +) +def main(args=None): + args = parse_args(args) + launch(args) + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/distributed/logging_handlers.py b/phivenv/Lib/site-packages/torch/distributed/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..8d070b48bcf0c49a95d7e7d83ae54d693b744b5a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/logging_handlers.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +__all__: list[str] = [] + +_log_handlers: dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/phivenv/Lib/site-packages/torch/distributed/remote_device.py b/phivenv/Lib/site-packages/torch/distributed/remote_device.py new file mode 100644 index 0000000000000000000000000000000000000000..f19acdd1335318aa2043e43a490f300a068dc1e8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/remote_device.py @@ -0,0 +1,120 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch + + +class _remote_device: + """ + Represents a device on a remote worker. + + Args: + remote_device (str or torch.device): Represents a device on a remote worker. + The string format should be one of the following: + + 1. "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + 2. "rank:/", where is the rank of the + process and device can be parsed as torch.device type. + E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0" + 3. and are optional and formats like "cpu" + and "cuda:1", just represent local devices. + """ + + def __init__(self, remote_device: Union[str, torch.device]): + PARSE_ERROR = ( + f"Could not parse remote_device: {remote_device}. 
The valid format is " + "'/' or 'rank:/' or ''" + ) + self._worker_name = None + self._rank = None + self._device: Optional[Union[str, int, torch.device]] = None + + if isinstance(remote_device, torch.device): + self._device = remote_device + elif isinstance(remote_device, str): + fields = remote_device.split("/") + if len(fields) == 2: + self._worker_name, self._device = fields + elif len(fields) == 1: + # Check if this is a valid device. + if _remote_device._is_valid_local_device(fields[0]): + self._device = fields[0] + else: + self._worker_name = fields[0] + self._device = "cpu" + else: + raise ValueError(PARSE_ERROR) + else: + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") + + # Do some basic sanity check (no empty string) + if self._worker_name is not None and not self._worker_name: + raise ValueError(PARSE_ERROR) + + # Validate the device. + self._device = torch.device(self._device) + + # Check for rank based format. + if self._worker_name is not None: + fields = self._worker_name.split(":") + if len(fields) == 2: + # rank:/device format, extract rank + if fields[0] == "rank" and fields[1].isdigit(): + self._rank = int(fields[1]) # type: ignore[assignment] + self._worker_name = None + else: + raise ValueError(PARSE_ERROR) + elif len(fields) > 2: + raise ValueError(PARSE_ERROR) + + @staticmethod + def _is_valid_local_device(device): + # Check for torch.device + try: + torch.device(device) + return True + except Exception: + return False + + def worker_name(self) -> Optional[str]: + """Return the name of remote worker representing the remote device and ``None`` if no worker name is available.""" + return self._worker_name + + def rank(self) -> Optional[int]: + """ + Returns the rank of remote worker representing the remote device. + Returns ``None`` if no rank is available. + """ + return self._rank + + def device(self) -> torch.device: + """Return the local device on the remote worker.""" + return self._device # type: ignore[return-value] + + def __repr__(self): + if self._device is not None: + if self._worker_name is not None: + return f"{self._worker_name}/{self._device}" + elif self._rank is not None: + return f"rank:{self._rank}/{self._device}" + else: + return str(self._device) + else: + if self._worker_name is not None: + return f"{self._worker_name}" + elif self._rank is not None: + return f"{self._rank}" + else: + raise RuntimeError("Invalid state!") + + def __eq__(self, other): + return isinstance(other, _remote_device) and ( + self._worker_name == other._worker_name + and self._device == other._device + and self._rank == other._rank + ) + + def __hash__(self): + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/phivenv/Lib/site-packages/torch/distributed/rendezvous.py b/phivenv/Lib/site-packages/torch/distributed/rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..0a38a5bc42f56d8f0fba3cc8f5558f5b807f8096 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rendezvous.py @@ -0,0 +1,290 @@ +# mypy: allow-untyped-defs +try: + from urllib.parse import urlparse, urlunparse +except ImportError as e: + raise ImportError( + "urllib cannot be found, urlparse from python2 is no longer supported." 
+ ) from e + +import numbers +import os +import sys +from collections.abc import Iterator +from datetime import timedelta +from typing import Callable, Optional + +from torch.distributed import FileStore, Store, TCPStore + +from .constants import default_pg_timeout + + +_rendezvous_handlers: dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {} + +__all__ = ["register_rendezvous_handler", "rendezvous"] + + +def register_rendezvous_handler(scheme, handler): + """ + Register a new rendezvous handler. + + Before we can run collective algorithms, participating processes + need to find each other and exchange information to be able to + communicate. We call this process rendezvous. + + The outcome of the rendezvous process is a triplet containing a + shared key/value store, the rank of the process, and the total + number of participating processes. + + If none of the bundled rendezvous methods apply to your execution + environment you can opt to register your own rendezvous handler. + Pick a unique name and use the URL scheme to identify it when + calling the `rendezvous()` function. + + Args: + scheme (str): URL scheme to identify your rendezvous handler. + handler (function): Handler that is invoked when the + `rendezvous()` function is called with a URL that uses + the corresponding scheme. It must be a generator function + that yields the triplet. + """ + global _rendezvous_handlers + if scheme in _rendezvous_handlers: + raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered") + _rendezvous_handlers[scheme] = handler + + +# Query will have format "rank=0&world_size=1" and is +# converted into {"rank": 0, "world_size": 1} +def _query_to_dict(query: str) -> dict[str, str]: + return { + pair[0]: pair[1] + for pair in (pair.split("=") for pair in filter(None, query.split("&"))) + } + + +def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool: + # libuv is the default backend for TCPStore. To enable the non-libuv backend, + # user can explicitly specify ``use_libuv=0`` in the URL parameter. + if sys.platform == "win32": + # PyTorch is built without libuv support on windows, so default to 0 + return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "0")) == "1" + return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" + + +def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): + result = urlparse(url) + if world_size_opt is None: + world_size = -1 + if result.scheme == "env": + rank = int(os.environ.get("RANK", rank)) + # If the world_size env variable is not present then it is a dynamic group + world_size = int(os.environ.get("WORLD_SIZE", world_size)) + else: + world_size = world_size_opt + if rank != -1 or world_size != -1 or world_size_opt is None: + query_dict = _query_to_dict(result.query) + assert "rank" not in query_dict and "world_size" not in query_dict, ( + f"The url: {url} has node-specific arguments(rank, world_size) already." 
+ ) + if rank != -1: + query_dict["rank"] = str(rank) + if world_size != -1 or world_size_opt is None: + query_dict["world_size"] = str(world_size) + result = result._replace( + query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}" + ) + url = urlunparse(result) + + if result.scheme not in _rendezvous_handlers: + raise RuntimeError(f"No rendezvous handler for {result.scheme}://") + return _rendezvous_handlers[result.scheme](url, **kwargs) + + +def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): + if not isinstance(url, (str, bytes)): + raise RuntimeError(f"`url` must be a string. {type(url)}: {url}") + + if not isinstance(rank, numbers.Integral): + raise RuntimeError(f"`rank` must be an integer. {rank}") + + if not isinstance(world_size, numbers.Integral): + raise RuntimeError(f"`world_size` must be an integer. {world_size}") + + return _rendezvous_helper(url, rank, world_size, **kwargs) + + +def _create_store_from_options(backend_options, rank): + store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None)) + return store + + +def _rendezvous_error(msg): + return ValueError("Error initializing torch.distributed using " + msg) + + +def _file_rendezvous_handler(url: str, **kwargs): + def _error(msg): + return _rendezvous_error("file:// rendezvous: " + msg) + + result = urlparse(url) + path = result.path + if sys.platform == "win32": + import urllib.request + + full_path = result.netloc + result.path + path = urllib.request.url2pathname(full_path) + if path: + # Normalizing an empty string produces ".", which is not expected. + path = os.path.normpath(path) + + if not path: + raise _error("path missing") + query_dict = _query_to_dict(result.query) + if "rank" not in query_dict: + raise _error("rank parameter missing") + if "world_size" not in query_dict: + raise _error("world size parameter missing") + + rank = int(query_dict["rank"]) + world_size = int(query_dict["world_size"]) + store = FileStore(path, world_size) + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform rerendezvous using file:// method") + + +def _torchelastic_use_agent_store() -> bool: + return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) + + +def _create_c10d_store( + hostname, port, rank, world_size, timeout, use_libuv=True +) -> Store: + """ + Smartly creates a c10d Store object on ``rank`` based on whether we need to reuse agent store. + + The TCPStore server is assumed to be hosted + on ``hostname:port``. + + By default, the TCPStore server uses the asynchronous implementation + ``LibUVStoreDaemon`` which utilizes libuv. + + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that + the agent leader (node rank 0) hosts the TCPStore server (for which the + endpoint is specified by the given ``hostname:port``). Hence + ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``). + + If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host + the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname + and port are correctly passed via ``hostname`` and ``port``. All + non-zero ranks will create and return a TCPStore client. 
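+
+    A minimal illustrative sketch (an assumption for exposition, not part of the
+    upstream docstring)::
+
+        from datetime import timedelta
+        # rank 0 hosts the TCPStore server; rank 1 connects as a client
+        store0 = _create_c10d_store("node0", 29500, rank=0, world_size=2,
+                                    timeout=timedelta(seconds=30))
+        store1 = _create_c10d_store("node0", 29500, rank=1, world_size=2,
+                                    timeout=timedelta(seconds=30))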
+ """ + # check if port is uint16_t + if not 0 <= port < 2**16: + raise ValueError(f"port must have value from 0 to 65535 but was {port}.") + + if _torchelastic_use_agent_store(): + # We create a new TCPStore for every retry so no need to add prefix for each attempt. + return TCPStore( + host_name=hostname, + port=port, + world_size=world_size, + is_master=False, + timeout=timeout, + ) + else: + start_daemon = rank == 0 + return TCPStore( + host_name=hostname, + port=port, + world_size=world_size, + is_master=start_daemon, + timeout=timeout, + multi_tenant=True, + use_libuv=use_libuv, + ) + + +def _tcp_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): + def _error(msg): + return _rendezvous_error("tcp:// rendezvous: " + msg) + + result = urlparse(url) + if result.port is None: + raise _error("port number missing") + query_dict = _query_to_dict(result.query) + if "rank" not in query_dict: + raise _error("rank parameter missing") + if "world_size" not in query_dict: + raise _error("world size parameter missing") + + rank = int(query_dict["rank"]) + world_size = int(query_dict["world_size"]) + use_libuv = _get_use_libuv_from_query_dict(query_dict) + + assert result.hostname is not None + + store = _create_c10d_store( + result.hostname, result.port, rank, world_size, timeout, use_libuv + ) + + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform re-rendezvous using tcp:// method") + + +def _env_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): + def _error(msg): + return _rendezvous_error("env:// rendezvous: " + msg) + + def _env_error(var): + return _error(f"environment variable {var} expected, but not set") + + def _get_env_or_raise(env_var: str) -> str: + env_val = os.environ.get(env_var, None) + if not env_val: + raise _env_error(env_var) + else: + return env_val + + result = urlparse(url) + query_dict = _query_to_dict(result.query) + + rank: int + world_size: int + master_port: int + master_addr: str + + if "rank" in query_dict: + rank = int(query_dict["rank"]) + else: + rank = int(_get_env_or_raise("RANK")) + + if "world_size" in query_dict: + world_size = int(query_dict["world_size"]) + else: + world_size = int(_get_env_or_raise("WORLD_SIZE")) + + master_addr = _get_env_or_raise("MASTER_ADDR") + master_port = int(_get_env_or_raise("MASTER_PORT")) + use_libuv = _get_use_libuv_from_query_dict(query_dict) + + store = _create_c10d_store( + master_addr, master_port, rank, world_size, timeout, use_libuv + ) + + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform re-rendezvous using env:// method") + + +register_rendezvous_handler("tcp", _tcp_rendezvous_handler) +register_rendezvous_handler("env", _env_rendezvous_handler) +register_rendezvous_handler("file", _file_rendezvous_handler) diff --git a/phivenv/Lib/site-packages/torch/distributed/run.py b/phivenv/Lib/site-packages/torch/distributed/run.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c9f9599d853d979e0aaf2c4e96c279863d3782 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/run.py @@ -0,0 +1,905 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Module ``torch.distributed.run``. + +``torch.distributed.run`` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +``torchrun`` is a python +`console script `_ +to the main module +`torch.distributed.run `_ +declared in the ``entry_points`` configuration in +`setup.py `_. +It is equivalent to invoking ``python -m torch.distributed.run``. + +``torchrun`` can be used for single-node distributed training, in which one or +more processes per node will be spawned. It can be used for either +CPU training or GPU training. If it is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. ``torchrun`` can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be beneficial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. + +In both cases of single-node distributed training or multi-node distributed +training, ``torchrun`` will launch the given number of processes per node +(``--nproc-per-node``). If used for GPU training, this number needs to be less +or equal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. + +.. versionchanged:: 2.0.0 + + ``torchrun`` will pass the ``--local-rank=`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. + + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, ``torchrun`` will trigger an error: "error: unrecognized arguments: + --local-rank=". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. + + :: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +Usage +----- + +Single-node multi-worker +++++++++++++++++++++++++ + +:: + + torchrun + --standalone + --nnodes=1 + --nproc-per-node=$NUM_TRAINERS + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +.. note:: ``--nproc-per-node`` may be + ``"gpu"`` (spawn one process per GPU), + ``"cpu"`` (spawn one process per CPU), + ``"auto"`` (equivalent to ``"gpu"`` if CUDA is available, + else equivalent to ``"cpu"``), + or an integer specifying the number of processes. + See `torch.distributed.run.determine_local_world_size + `_ + for more details. + +Stacked single-node multi-worker +++++++++++++++++++++++++++++++++ + +To run multiple instances (separate jobs) of single-node, multi-worker on the +same host, we need to make sure that each instance (job) is +setup on different ports to avoid port conflicts (or worse, two jobs being merged +as a single job). To do this you have to run with ``--rdzv-backend=c10d`` +and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``. 
+For ``--nodes=1``, its often convenient to let ``torchrun`` pick a free random +port automatically instead of manually assigning different ports for each run. + +:: + + torchrun + --rdzv-backend=c10d + --rdzv-endpoint=localhost:0 + --nnodes=1 + --nproc-per-node=$NUM_TRAINERS + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + + +Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures) ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +:: + + torchrun + --nnodes=$NUM_NODES + --nproc-per-node=$NUM_TRAINERS + --max-restarts=3 + --rdzv-id=$JOB_ID + --rdzv-backend=c10d + --rdzv-endpoint=$HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and +the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any +node in your training cluster, but ideally you should pick a node that has a high bandwidth. + +.. note:: + If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. + +Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures) +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +:: + + torchrun + --nnodes=1:4 + --nproc-per-node=$NUM_TRAINERS + --max-restarts=3 + --rdzv-id=$JOB_ID + --rdzv-backend=c10d + --rdzv-endpoint=$HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and +the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any +node in your training cluster, but ideally you should pick a node that has a high bandwidth. + +.. note:: + If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. + +Note on rendezvous backend +-------------------------- + +For multi-node training you need to specify: + +1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job) +2. ``--rdzv-backend``: An implementation of + :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` +3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form + ``host:port``. + +Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are +supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api +enabled (e.g. ``--enable-v2``). + +.. warning:: + ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd + server. Our tests use etcd v3.4.3. + +.. warning:: + For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally + equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be + removed in a future version. + +Definitions +----------- + +1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with. + +2. ``Worker`` - A worker in the context of distributed training. + +3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers). + +4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node. + +5. ``RANK`` - The rank of the worker within a worker group. + +6. ``WORLD_SIZE`` - The total number of workers in a worker group. + +7. ``LOCAL_RANK`` - The rank of the worker within a local worker group. + +8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group. + +9. 
``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is + used by each node to join as a member of a particular worker group. + +9. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly + consistent key-value store. + +10. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``:``. + +A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of +all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``. + +Environment Variables +--------------------- + +The following environment variables are made available to you in your script: + +1. ``LOCAL_RANK`` - The local rank. + +2. ``RANK`` - The global rank. + +3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When + running a single worker group per node, this is the rank of the node. + +4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role + of the worker is specified in the ``WorkerSpec``. + +5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to + ``--nproc-per-node`` specified on ``torchrun``. + +6. ``WORLD_SIZE`` - The world size (total number of workers in the job). + +7. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified + in ``WorkerSpec``. + +8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize + the Torch Distributed backend. + +9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store. + +10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far. + +11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts. + +12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id). + +13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will + use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default. + +Deployment +---------- + +1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be + passed as ``--rdzv-endpoint`` to ``torchrun``) + +2. Single-node multi-worker: Start ``torchrun`` on the host to start the agent process which + creates and monitors a local worker group. + +3. Multi-node multi-worker: Start ``torchrun`` with the same arguments on all the nodes + participating in training. + +When using a job/cluster manager, the entry point command to the multi-node job should be ``torchrun``. + +Failure Modes +------------- + +1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers + are stopped and restarted up to ``max_restarts``. + +2. Agent failure: An agent failure results in a local worker group failure. It is up to the job + manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors + are supported by the agent. + +3. Node failure: Same as agent failure. + +Membership Changes +------------------ + +1. Node departure (scale-down): The agent is notified of the departure, all existing workers are + stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and + ``WORLD_SIZE``. + +2. 
Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, + a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and + ``WORLD_SIZE``. + +Important Notices +----------------- + +1. This utility and multi-process distributed (single-node or + multi-node) GPU training currently only achieves the best performance using + the NCCL distributed backend. Thus NCCL backend is the recommended backend to + use for GPU training. + +2. The environment variables necessary to initialize a Torch process group are provided to you by + this module, no need for you to pass ``RANK`` manually. To initialize a process group in your + training script, simply run: + +:: + + >>> # xdoctest: +SKIP("stub") + >>> import torch.distributed as dist + >>> dist.init_process_group(backend="gloo|nccl") + +3. In your training program, you can either use regular distributed functions + or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your + training program uses GPUs for training and you would like to use + :func:`torch.nn.parallel.DistributedDataParallel` module, + here is how to configure it. + +:: + + local_rank = int(os.environ["LOCAL_RANK"]) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank + ) + +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[int(os.environ("LOCAL_RANK"))]``, +and ``output_device`` needs to be ``int(os.environ("LOCAL_RANK"))`` in order to use this +utility + + +4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to + checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance + for lost work. + +5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all + nodes run the same number of local workers (per role). + +6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a + different range of ranks than before. NEVER hard code any assumptions about the stable-ness of + ranks or some correlation between ``RANK`` and ``LOCAL_RANK``. + +7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about + ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join. + +8. It is recommended for your script to have the following structure: + +:: + + def main(): + load_checkpoint(checkpoint_path) + initialize() + train() + + + def train(): + for batch in iter(dataset): + train_step(batch) + + if should_checkpoint: + save_checkpoint(checkpoint_path) + +9. (Recommended) On worker errors, this tool will summarize the details of the error + (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) + is heuristically reported as the "Root Cause" error. To get tracebacks as part of this + error summary print out, you must decorate your main entrypoint function in your + training script as shown in the example below. If not decorated, then the summary + will not include the traceback of the exception and will only contain the exitcode. 
+ For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html + +:: + + from torch.distributed.elastic.multiprocessing.errors import record + + + @record + def main(): + # do train + pass + + + if __name__ == "__main__": + main() +""" # noqa: E501 + +import os +import sys +import uuid +from argparse import ArgumentParser, REMAINDER +from importlib import metadata +from typing import Callable, Optional, Union + +import torch +from torch.distributed.argparse_util import check_env, env +from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config +from torch.distributed.elastic.utils import macros +from torch.distributed.elastic.utils.logging import get_logger +from torch.distributed.launcher.api import elastic_launch, LaunchConfig +from torch.utils.backend_registration import _get_custom_mod_func + + +logger = get_logger(__name__) + + +def get_args_parser() -> ArgumentParser: + """Parse the command line options.""" + parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher") + + # + # Worker/node size related arguments. + # + + parser.add_argument( + "--nnodes", + action=env, + type=str, + default="1:1", + help="Number of nodes, or the range of nodes in form :.", + ) + parser.add_argument( + "--nproc-per-node", + "--nproc_per_node", + action=env, + type=str, + default="1", + help="Number of workers per node; supported values: [auto, cpu, gpu, int].", + ) + + # + # Rendezvous related arguments + # + + parser.add_argument( + "--rdzv-backend", + "--rdzv_backend", + action=env, + type=str, + default="static", + help="Rendezvous backend.", + ) + parser.add_argument( + "--rdzv-endpoint", + "--rdzv_endpoint", + action=env, + type=str, + default="", + help="Rendezvous backend endpoint; usually in form :.", + ) + parser.add_argument( + "--rdzv-id", + "--rdzv_id", + action=env, + type=str, + default="none", + help="User-defined group id.", + ) + parser.add_argument( + "--rdzv-conf", + "--rdzv_conf", + action=env, + type=str, + default="", + help="Additional rendezvous configuration (=,=,...).", + ) + parser.add_argument( + "--standalone", + action=check_env, + help="Start a local standalone rendezvous backend that is represented by a C10d TCP store " + "on a free port. Useful when launching single-node, multi-worker job. If specified " + "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values " + "are ignored.", + ) + + # + # User-code launch related arguments. 
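+    # Hedged illustration (not part of the upstream file): a two-node launch exercising the
+    # rendezvous arguments defined above might look like
+    #
+    #     torchrun --nnodes=2 --nproc-per-node=8 --rdzv-backend=c10d \
+    #         --rdzv-endpoint=node0:29400 --rdzv-id=job42 train.py
+    #
+    # where node0, job42 and train.py are placeholders.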
+ # + + parser.add_argument( + "--max-restarts", + "--max_restarts", + action=env, + type=int, + default=0, + help="Maximum number of worker group restarts before failing.", + ) + parser.add_argument( + "--monitor-interval", + "--monitor_interval", + action=env, + type=float, + default=0.1, + help="Interval, in seconds, to monitor the state of workers.", + ) + parser.add_argument( + "--start-method", + "--start_method", + action=env, + type=str, + default="spawn", + choices=["spawn", "fork", "forkserver"], + help="Multiprocessing start method to use when creating workers.", + ) + parser.add_argument( + "--event-log-handler", + "--event_log_handler", + action=env, + type=str, + default="null", + help="name of a registered event logging handler (see: https://docs.pytorch.org/docs/stable/elastic/events.html)", + ) + parser.add_argument( + "--role", + action=env, + type=str, + default="default", + help="User-defined role for the workers.", + ) + parser.add_argument( + "-m", + "--module", + action=check_env, + help="Change each process to interpret the launch script as a Python module, executing " + "with the same behavior as 'python -m'.", + ) + parser.add_argument( + "--no-python", + "--no_python", + action=check_env, + help="Skip prepending the training script with 'python' - just execute it directly. Useful " + "when the script is not a Python script.", + ) + + parser.add_argument( + "--run-path", + "--run_path", + action=check_env, + help="Run the training script with runpy.run_path in the same interpreter." + " Script must be provided as an abs path (e.g. /abs/path/script.py)." + " Takes precedence over --no-python.", + ) + parser.add_argument( + "--log-dir", + "--log_dir", + action=env, + type=str, + default=None, + help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same " + "directory is reused for multiple runs (a unique job-level sub-directory is created with " + "rdzv_id as the prefix).", + ) + parser.add_argument( + "-r", + "--redirects", + action=env, + type=str, + default="0", + help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects " + "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and " + "stderr for local rank 1).", + ) + parser.add_argument( + "-t", + "--tee", + action=env, + type=str, + default="0", + help="Tee std streams into a log file and also to console (see --redirects for format).", + ) + + parser.add_argument( + "--local-ranks-filter", + "--local_ranks_filter", + action=env, + type=str, + default="", + help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will " + "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to" + "log files saved via --redirect or --tee", + ) + + # + # Backwards compatible parameters with caffe2.distributed.launch. + # + + parser.add_argument( + "--node-rank", + "--node_rank", + type=int, + action=env, + default=0, + help="Rank of the node for multi-node distributed training.", + ) + parser.add_argument( + "--master-addr", + "--master_addr", + default="127.0.0.1", + type=str, + action=env, + help="Address of the master node (rank 0) that only used for static rendezvous. It should " + "be either the IP address or the hostname of rank 0. 
For single node multi-proc training " + "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern " + "`[0:0:0:0:0:0:0:1]`.", + ) + parser.add_argument( + "--master-port", + "--master_port", + default=29500, + type=int, + action=env, + help="Port on the master node (rank 0) to be used for communication during distributed " + "training. It is only used for static rendezvous.", + ) + parser.add_argument( + "--local-addr", + "--local_addr", + default=None, + type=str, + action=env, + help="Address of the local node. If specified, will use the given address for connection. " + "Else, will look up the local node address instead. Else, it will be default to local " + "machine's FQDN.", + ) + + parser.add_argument( + "--logs-specs", + "--logs_specs", + default=None, + type=str, + help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " + "Can be used to override custom logging behavior.", + ) + + # + # Positional arguments. + # + + parser.add_argument( + "training_script", + type=str, + help="Full path to the (single GPU) training program/script to be launched in parallel, " + "followed by all the arguments for the training script.", + ) + + # Rest from the training program. + parser.add_argument("training_script_args", nargs=REMAINDER) + + return parser + + +def parse_args(args): + parser = get_args_parser() + return parser.parse_args(args) + + +def parse_min_max_nnodes(nnodes: str): + arr = nnodes.split(":") + + if len(arr) == 1: + min_nodes = max_nodes = int(arr[0]) + elif len(arr) == 2: + min_nodes = int(arr[0]) + max_nodes = int(arr[1]) + else: + raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231 + + return min_nodes, max_nodes + + +def determine_local_world_size(nproc_per_node: str): + try: + logger.info("Using nproc_per_node=%s.", nproc_per_node) + return int(nproc_per_node) + except ValueError as e: + if nproc_per_node == "cpu": + num_proc = os.cpu_count() + device_type = "cpu" + elif nproc_per_node == "gpu": + if not torch.cuda.is_available(): + raise ValueError("Cuda is not available.") from e + device_type = "gpu" + num_proc = torch.cuda.device_count() + elif nproc_per_node == torch._C._get_privateuse1_backend_name(): + if not _get_custom_mod_func("is_available")(): + raise ValueError(f"{nproc_per_node} is not available.") from e + device_type = nproc_per_node + num_proc = _get_custom_mod_func("device_count")() + elif nproc_per_node == "auto": + if torch.cuda.is_available(): + num_proc = torch.cuda.device_count() + device_type = "gpu" + elif ( + hasattr(torch, torch._C._get_privateuse1_backend_name()) + and _get_custom_mod_func("is_available")() + ): + num_proc = _get_custom_mod_func("device_count")() + device_type = torch._C._get_privateuse1_backend_name() + else: + num_proc = os.cpu_count() + device_type = "cpu" + else: + raise ValueError( + f"Unsupported nproc_per_node value: {nproc_per_node}" + ) from e + + logger.info( + "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s", + nproc_per_node, + num_proc, + num_proc, + device_type, + ) + return num_proc + + +def get_rdzv_endpoint(args): + if args.rdzv_backend == "static" and not args.rdzv_endpoint: + return f"{args.master_addr}:{args.master_port}" # noqa: E231 + return args.rdzv_endpoint + + +def get_use_env(args) -> bool: + """ + Retrieve ``use_env`` from the args. + + ``use_env`` is a legacy argument, if ``use_env`` is False, the + ``--node-rank`` argument will be transferred to all worker processes. 
+ ``use_env`` is only used by the ``torch.distributed.launch`` and will + be deprecated in future releases. + """ + if not hasattr(args, "use_env"): + return True + return args.use_env + + +def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: + """ + Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. + Provides plugin mechanism to provide custom implementation of LogsSpecs. + + Returns `DefaultLogsSpecs` when logs_spec_name is None. + Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints. + """ + logs_specs_cls = None + if logs_specs_name is not None: + eps = metadata.entry_points() + if hasattr(eps, "select"): # >= 3.10 + group = eps.select(group="torchrun.logs_specs") + if group.select(name=logs_specs_name): + logs_specs_cls = group[logs_specs_name].load() + + elif specs := eps.get("torchrun.logs_specs"): # < 3.10 + if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]: + logs_specs_cls = entrypoint_list[0].load() + + if logs_specs_cls is None: + raise ValueError( + f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" + ) + + logger.info( + "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + ) + else: + logs_specs_cls = DefaultLogsSpecs + + return logs_specs_cls + + +def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]: + # If ``args`` not passed, defaults to ``sys.argv[:1]`` + min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) + assert 0 < min_nodes <= max_nodes + assert args.max_restarts >= 0 + + if ( + hasattr(args, "master_addr") + and args.rdzv_backend != "static" + and not args.rdzv_endpoint + ): + logger.warning( + "master_addr is only used for static rdzv_backend and when rdzv_endpoint " + "is not specified." + ) + + nproc_per_node = determine_local_world_size(args.nproc_per_node) + if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: + omp_num_threads = 1 + logger.warning( + "\n*****************************************\n" + "Setting OMP_NUM_THREADS environment variable for each process to be " + "%s in default, to avoid your system being overloaded, " + "please further tune the variable for optimal performance in " + "your application as needed. \n" + "*****************************************", + omp_num_threads, + ) + # This env variable will be passed down to the subprocesses + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + + log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE") + + rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) + + if args.rdzv_backend == "static": + rdzv_configs["rank"] = args.node_rank + + rdzv_endpoint = get_rdzv_endpoint(args) + + ranks: Optional[set[int]] = None + if args.local_ranks_filter: + try: + ranks = set(map(int, args.local_ranks_filter.split(","))) + assert ranks + except Exception as e: + raise ValueError( + "--local_ranks_filter must be a comma-separated list of integers e.g. 
--local_ranks_filter=0,1,2" + ) from e + + logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) + logs_specs = logs_specs_cls( + log_dir=args.log_dir, + redirects=Std.from_str(args.redirects), + tee=Std.from_str(args.tee), + local_ranks_filter=ranks, + ) + + config = LaunchConfig( + min_nodes=min_nodes, + max_nodes=max_nodes, + nproc_per_node=nproc_per_node, + run_id=args.rdzv_id, + role=args.role, + rdzv_endpoint=rdzv_endpoint, + rdzv_backend=args.rdzv_backend, + rdzv_configs=rdzv_configs, + max_restarts=args.max_restarts, + monitor_interval=args.monitor_interval, + start_method=args.start_method, + log_line_prefix_template=log_line_prefix_template, + local_addr=args.local_addr, + logs_specs=logs_specs, + event_log_handler=args.event_log_handler, + ) + + with_python = not args.no_python + cmd: Union[Callable, str] + cmd_args = [] + use_env = get_use_env(args) + if args.run_path: + cmd = run_script_path + cmd_args.append(args.training_script) + else: + if with_python: + cmd = os.getenv("PYTHON_EXEC", sys.executable) + cmd_args.append("-u") + if args.module: + cmd_args.append("-m") + cmd_args.append(args.training_script) + else: + if args.module: + raise ValueError( + "Don't use both the '--no-python' flag" + " and the '--module' flag at the same time." + ) + cmd = args.training_script + if not use_env: + cmd_args.append(f"--local-rank={macros.local_rank}") + cmd_args.extend(args.training_script_args) + + return config, cmd, cmd_args + + +def run_script_path(training_script: str, *training_script_args: str): + """ + Run the provided `training_script` from within this interpreter. + + Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")` + """ + import runpy + import sys + + sys.argv = [training_script] + [*training_script_args] + runpy.run_path(sys.argv[0], run_name="__main__") + + +def run(args): + torch.multiprocessing._set_thread_name("pt_elastic") + + if args.standalone: + args.rdzv_backend = "c10d" + args.rdzv_endpoint = "localhost:0" + args.rdzv_id = str(uuid.uuid4()) + logger.info( + "\n**************************************\n" + "Rendezvous info:\n" + "--rdzv-backend=%s " + "--rdzv-endpoint=%s " + "--rdzv-id=%s\n" + "**************************************\n", + args.rdzv_backend, + args.rdzv_endpoint, + args.rdzv_id, + ) + + config, cmd, cmd_args = config_from_args(args) + elastic_launch( + config=config, + entrypoint=cmd, + )(*cmd_args) + + +@record +def main(args=None): + args = parse_args(args) + run(args) + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/distributed/utils.py b/phivenv/Lib/site-packages/torch/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cade28f32691ca4d6d8c774c15eec38d0c273823 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/utils.py @@ -0,0 +1,376 @@ +# mypy: allow-untyped-defs +import dataclasses +import traceback +from collections import OrderedDict +from collections.abc import Container +from typing import Any, Callable, Optional, overload, TypeVar + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn.utils.rnn import PackedSequence + + +__all__ = [] # type: ignore[var-annotated] + + +def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str, ...]]: + """ + Turn argument list into separate key list and value list (unpack_kwargs does the opposite). 
+ + Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70 + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + assert kwarg_keys == ("a", "b") + assert flat_args == (1, 2, 3, 4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == (1, 2) + assert kwargs == {"a": 3, "b": 4} + Returns: + Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives + gives both positional args and kwarg values, where the positional args + proceed kwarg values and kwarg values are ordered consistently with the + kwarg keys. The second tuple element gives the kwarg keys. + The second tuple element's length is at most the first tuple element's length. + """ + kwarg_keys: list[str] = [] + flat_args: list[Any] = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + +def _cast_forward_inputs( + dtype: Optional[torch.dtype], + *args: Any, + **kwargs: Any, +) -> tuple[Any, Any]: + """ + Cast floating point tensors in ``args`` and ``kwargs`` to ``input_dtype``. + + This respects the existing ``requires_grad`` on the tensors. + """ + if dtype is None: + return args, kwargs + + def cast_fn(x: torch.Tensor) -> torch.Tensor: + if not torch.is_floating_point(x) or x.dtype == dtype: + return x + return x.to(dtype) + + return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs)) + + +def _unpack_kwargs( + flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...] +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """See _pack_kwargs.""" + assert len(kwarg_keys) <= len(flat_args), ( + f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + ) + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :])) + return args, kwargs + + +S = TypeVar("S", dict, list, tuple) +T = TypeVar("T", torch.Tensor, PackedSequence) + + +@overload +def _recursive_to( + inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> list[S]: ... + + +@overload +def _recursive_to( + inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> tuple[T]: ... + + +def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): + r"""Recursively moves input to the target_device.""" + + def to_map(obj): + if isinstance(obj, (torch.Tensor, PackedSequence)): + device = obj.data.device if isinstance(obj, PackedSequence) else obj.device + if device == target_device: + return (obj,) + if not use_side_stream_for_tensor_copies: + return (obj.to(target_device),) + else: + # If the custom module is not registered to torch, stream is not used for acceleration + if device.type == "cpu": + return (obj.to(target_device),) + + from torch.nn.parallel._functions import _get_stream + + # Perform CPU -> target_device copies in a background stream. 
This code is + # motivated from similar logic in torch/nn/parallel/_functions.py + stream = _get_stream(target_device) + with stream: + output = obj.to(target_device) + # synchronize with the copy stream + with torch.accelerator.device_index(target_device.index): + current_stream = torch.accelerator.current_stream() + # Sync the current stream with the copy stream + current_stream.wait_stream(stream) + # Ensure tensor memory is not reused until work on + # main stream is complete + if isinstance(obj, PackedSequence): + output.data.record_stream(current_stream) # type: ignore[arg-type] + else: + assert isinstance(output, torch.Tensor) + output.record_stream(current_stream) # type: ignore[arg-type] + return (output,) + + from torch.nn.parallel.scatter_gather import _is_namedtuple + + if _is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(to_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(to_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return [list(i) for i in zip(*map(to_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] + return [obj] + + # Avoid reference cycle + try: + res = to_map(inputs) + finally: + to_map = None # type: ignore[assignment] + return res + + +def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: + """Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed.""" + if not cond: + print(s) + traceback.print_stack() + if raise_assertion_error: + raise AssertionError(s) + + +def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None: + """ + Allocate storage for ``tensor`` with the given size. + + Returns: + bool: ``True`` if this method allocated storage and ``False`` if the + storage was already allocated. + """ + with torch.no_grad(): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_allocated = tensor._typed_storage()._size() == size.numel() + if not already_allocated: + tensor_storage_size = tensor._typed_storage()._size() + _p_assert( + tensor_storage_size == 0, + "Tensor storage should have been resized to be 0 but got PLACEHOLDEr", + ) + tensor._typed_storage()._resize_(size.numel()) + + +def _free_storage(tensor: torch.Tensor): + """ + Frees the underlying storage of ``tensor``. + + Returns: + bool: ``True`` if the method freed the storage and ``False`` if the + storage was already freed. + """ + with torch.no_grad(): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_freed = tensor._typed_storage()._size() == 0 + if not already_freed: + _p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" + f"storage offset: {tensor.storage_offset()}\n" + f"storage size: {tensor._typed_storage()._size()}\n" + f"tensor shape: {tensor.shape}", + ) + tensor._typed_storage()._resize_(0) + + +Q = TypeVar("Q") +R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) + + +@overload +def _apply_to_tensors( + fn: Callable[[torch.Tensor], Q], container: torch.Tensor +) -> Q: ... + + +@overload +def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: ... 
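+# A hedged usage sketch (illustrative, not part of the upstream file): the implementation
+# below maps ``fn`` over every tensor found in nested containers while leaving other
+# leaves untouched, e.g.
+#
+#     nested = {"a": [torch.ones(2), (torch.zeros(3),)], "b": "meta"}
+#     halved = _apply_to_tensors(lambda t: t * 0.5, nested)
+#     # tensors inside the dict/list/tuple are transformed; the string "meta" passes through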
+ + +def _apply_to_tensors(fn, container): + """Recursively apply to all tensor in different kinds of container types.""" + + def apply(x): + from torch.nn.parallel.scatter_gather import _is_namedtuple + + if isinstance(x, torch.Tensor): + return fn(x) + elif hasattr(x, "__dataclass_fields__"): + dc = dataclasses.replace(x) + changes = { + f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) + } + return dataclasses.replace(dc, **changes) + elif isinstance(x, OrderedDict): + od = x.__class__() + for key, value in x.items(): + od[key] = apply(value) + return od + elif isinstance(x, PackedSequence): + apply(x.data) + return x + elif isinstance(x, dict): + return {key: apply(value) for key, value in x.items()} + elif _is_namedtuple(x): + res = (apply(el) for el in x) + return type(x)(*res) + elif isinstance(x, (list, tuple, set)): + return type(x)(apply(el) for el in x) + else: + return x + + return apply(container) + + +def _to_kwargs( + inputs: tuple[Any, ...], + kwargs: Optional[dict[str, Any]], + target_device: torch.device, + use_side_stream_for_tensor_copies: bool, +) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: + moved_inputs = ( + _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies) + if inputs + else [] + ) + moved_kwargs = ( + _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies) + if kwargs + else [] + ) + if len(moved_inputs) < len(moved_kwargs): + moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(inputs))]) + elif len(moved_kwargs) < len(moved_inputs): + moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))]) + return tuple(moved_inputs), tuple(moved_kwargs) + + +def _verify_param_shape_across_processes( + process_group: dist.ProcessGroup, + tensors: list[torch.Tensor], + logger: Optional["dist.Logger"] = None, +): + return dist._verify_params_across_processes(process_group, tensors, logger) + + +def _sync_module_states( + module: nn.Module, + process_group: dist.ProcessGroup, + broadcast_bucket_size: int, + src: int, + params_and_buffers_to_ignore: Container[str], + broadcast_buffers: bool = True, +) -> None: + """ + Sync ``module``'s parameters and buffers state. + + Syncs ``module``'s parameters and buffers state so that all ranks contain + the same module state across all ranks. Note that this API assumes that all + parameter shapes are consistent before running the synchronization. This can + be checked with ``_verify_param_shape_across_processes``. + """ + module_states: list[torch.Tensor] = [] + for name, param in module.named_parameters(): + if name not in params_and_buffers_to_ignore: + module_states.append(param.detach()) + + if broadcast_buffers: + for name, buffer in module.named_buffers(): + if name not in params_and_buffers_to_ignore: + module_states.append(buffer.detach()) + + _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src) + + +def _sync_params_and_buffers( + process_group: dist.ProcessGroup, + module_states: list[torch.Tensor], + broadcast_bucket_size: int, + src: int, +) -> None: + """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0.""" + if len(module_states) > 0: + dist._broadcast_coalesced( + process_group, module_states, broadcast_bucket_size, src + ) + + +def _replace_by_prefix( + state_dict: dict[str, Any], + old_prefix: str, + new_prefix: str, +) -> None: + """ + Replace all keys that match a given old_prefix with a new_prefix (in-place). 
+ + Usage:: + + state_dict = {"layer.xyz": torch.tensor(1)} + replace_by_prefix_(state_dict, "layer.", "module.layer.") + assert state_dict == {"module.layer.xyz": torch.tensor(1)} + """ + if old_prefix == new_prefix: + raise ValueError("old_prefix and new_prefix must be distinct") + for key in list(state_dict.keys()): + if not key.startswith(old_prefix): + continue + new_key = new_prefix + key[len(old_prefix) :] + state_dict[new_key] = state_dict[key] + del state_dict[key] + + +def _data_ptr_allocated(tensor: torch.Tensor) -> bool: + return tensor.untyped_storage().data_ptr() > 0 + + +def _get_root_modules(modules: list[nn.Module]) -> list[nn.Module]: + """ + Returns the modules in ``modules`` that are root modules (i.e. + parent-less) with respect to the set ``modules``. In other words, these + are the modules in ``modules`` that are the not child of any other + module in ``modules``. + """ + root_modules: list[nn.Module] = [] + module_to_modules: dict[nn.Module, set[nn.Module]] = { + module: set(module.modules()) for module in modules + } + for candidate_module in modules: + is_root_module = True + for module, _modules in module_to_modules.items(): + is_child_module = ( + candidate_module is not module and candidate_module in _modules + ) + if is_child_module: + is_root_module = False + break + if is_root_module: + root_modules.append(candidate_module) + return root_modules diff --git a/phivenv/Lib/site-packages/torch/functional.py b/phivenv/Lib/site-packages/torch/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..371f036bce06b82c7ef4df5f03247ecd47029819 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/functional.py @@ -0,0 +1,2237 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from collections.abc import Sequence +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +import torch.nn.functional as F +from torch import _VF, Tensor +from torch._C import _add_docstr +from torch._jit_internal import _overload as overload, boolean_dispatch +from torch._lowrank import pca_lowrank, svd_lowrank +from torch.overrides import ( + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) + + +__all__ = [ + "atleast_1d", + "atleast_2d", + "atleast_3d", + "align_tensors", + "broadcast_shapes", + "broadcast_tensors", + "cartesian_prod", + "block_diag", + "cdist", + "chain_matmul", + "einsum", + "istft", + "lu", + "norm", + "meshgrid", + "pca_lowrank", + "split", + "stft", + "svd_lowrank", + "tensordot", + "unique", + "unique_consecutive", + "unravel_index", +] + + +def broadcast_tensors(*tensors): + r"""broadcast_tensors(*tensors) -> List of Tensors + + Broadcasts the given tensors according to :ref:`broadcasting-semantics`. + + Args: + *tensors: any number of tensors of the same type + + .. warning:: + + More than one element of a broadcasted tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensors, please clone them first. + + Example:: + + >>> x = torch.arange(3).view(1, 3) + >>> y = torch.arange(2).view(2, 1) + >>> a, b = torch.broadcast_tensors(x, y) + >>> a.size() + torch.Size([2, 3]) + >>> a + tensor([[0, 1, 2], + [0, 1, 2]]) + """ + # This wrapper exists to support variadic args. 
+ if has_torch_function(tensors): + return handle_torch_function(broadcast_tensors, tensors, *tensors) + return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] + + +def broadcast_shapes(*shapes): + r"""broadcast_shapes(*shapes) -> Size + + Similar to :func:`broadcast_tensors` but for shapes. + + This is equivalent to + ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape`` + but avoids the need create to intermediate tensors. This is useful for + broadcasting tensors of common batch shape but different rightmost shape, + e.g. to broadcast mean vectors with covariance matrices. + + Example:: + + >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)) + torch.Size([1, 3, 2]) + + Args: + \*shapes (torch.Size): Shapes of tensors. + + Returns: + shape (torch.Size): A shape compatible with all input shapes. + + Raises: + RuntimeError: If shapes are incompatible. + """ + # This wrapper exists to support variadic args. + # TODO Move this to C++ once the jit has better support for torch.Size. + if not torch.jit.is_tracing(): + max_len = 0 + for shape in shapes: + if isinstance(shape, (int, torch.SymInt)): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (tuple, list)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + + from torch.fx.experimental.symbolic_shapes import ( + guard_size_oblivious, + is_nested_int, + ) + + for shape in shapes: + if isinstance(shape, (int, torch.SymInt)): + shape = (shape,) + if isinstance(shape, (tuple, list)): + for i in range(-1, -1 - len(shape), -1): + if shape[i] < 0: + raise RuntimeError( + f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})" + ) + + # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1). + if is_nested_int(shape[i]): + # Broadcasting is allowed for (j0, 1) or (j0, j0); + # not (j0, j1), (j0, 5), etc. + if is_nested_int(result[i]) and guard_size_oblivious( + shape[i] == result[i] + ): + continue + else: + # NB: result is initialized to 1 so this is effectively an + # equals one test + if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious( + shape[i] == result[i] + ): + continue + + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) + return torch.Size(result) + else: + # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail + with torch.no_grad(): + scalar = torch.zeros((), device="cpu") + tensors = [scalar.expand(shape) for shape in shapes] + tensors = broadcast_tensors(*tensors) + return tensors[0].shape + + +def split( + tensor: Tensor, + split_size_or_sections: Union[int, list[int]], + dim: int = 0, +) -> tuple[Tensor, ...]: + r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. + + If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will + be split into equally sized chunks (if possible). Last chunk will be smaller if + the tensor size along the given dimension :attr:`dim` is not divisible by + :attr:`split_size`. + + If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split + into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according + to :attr:`split_size_or_sections`. + + Args: + tensor (Tensor): tensor to split. 
+ split_size_or_sections (int) or (list(int)): size of a single chunk or + list of sizes for each chunk + dim (int): dimension along which to split the tensor. + + Example:: + + >>> a = torch.arange(10).reshape(5, 2) + >>> a + tensor([[0, 1], + [2, 3], + [4, 5], + [6, 7], + [8, 9]]) + >>> torch.split(a, 2) + (tensor([[0, 1], + [2, 3]]), + tensor([[4, 5], + [6, 7]]), + tensor([[8, 9]])) + >>> torch.split(a, [1, 4]) + (tensor([[0, 1]]), + tensor([[2, 3], + [4, 5], + [6, 7], + [8, 9]])) + """ + if has_torch_function_unary(tensor): + return handle_torch_function( + split, (tensor,), tensor, split_size_or_sections, dim=dim + ) + # Overwriting reason: + # This dispatches to two ATen functions depending on the type of + # split_size_or_sections. The branching code is in _tensor.py, which we + # call here. + return tensor.split(split_size_or_sections, dim) + + +def einsum(*args: Any) -> Tensor: + r"""einsum(equation, *operands) -> Tensor + + Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation + based on the Einstein summation convention. + + Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them + in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of + this format are described below, but the general idea is to label every dimension of the input :attr:`operands` + with some subscript and define which subscripts are part of the output. The output is then computed by summing + the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the + output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. + Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). + + Equation: + + The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of + the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a + comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript + must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is + repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand + must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that + appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. + The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based + on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. + + Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation + followed by the subscripts for the output. For instance, the following equation computes the transpose of a + matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and + at most once for the output. + + Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. + Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, + e.g. 
for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth + dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the + 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not + explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), + before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements + batch matrix multiplication `'...ij,...jk'`. + + A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, + arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. + + .. note:: + + ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions + covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. + + .. note:: + + Please install opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) in order to enroll into a more + performant einsum. You can install when installing torch like so: `pip install torch[opt-einsum]` or by itself + with `pip install opt-einsum`. + + If opt-einsum is available, this function will automatically speed up computation and/or consume less memory + by optimizing contraction order through our opt_einsum backend :mod:`torch.backends.opt_einsum` (The _ vs - is + confusing, I know). This optimization occurs when there are at least three inputs, since the order does not matter + otherwise. Note that finding `the` optimal path is an NP-hard problem, thus, opt-einsum relies on different + heuristics to achieve near-optimal results. If opt-einsum is not available, the default order is to contract + from left to right. + + To bypass this default behavior, add the following to disable opt_einsum and skip path calculation: + ``torch.backends.opt_einsum.enabled = False`` + + To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line: + ``torch.backends.opt_einsum.strategy = 'auto'``. The default strategy is 'auto', and we also support 'greedy' and + 'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in + the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html). + + .. note:: + + As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format, + subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists + follow their operands, and an extra sublist can appear at the end of the input to specify the output's + subscripts., e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [subslist_out])`. Python's `Ellipsis` object + may be provided in a sublist to enable broadcasting as described in the Equation section above. + + Args: + equation (str): The subscripts for the Einstein summation. + operands (List[Tensor]): The tensors to compute the Einstein summation of. 
+ + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> # trace + >>> torch.einsum('ii', torch.randn(4, 4)) + tensor(-1.2104) + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> # diagonal + >>> torch.einsum('ii->i', torch.randn(4, 4)) + tensor([-0.1034, 0.7952, -0.2433, 0.4545]) + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> # outer product + >>> x = torch.randn(5) + >>> y = torch.randn(4) + >>> torch.einsum('i,j->ij', x, y) + tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], + [-0.3744, 0.9381, 1.2685, -1.6070], + [ 0.7208, -1.8058, -2.4419, 3.0936], + [ 0.1713, -0.4291, -0.5802, 0.7350], + [ 0.5704, -1.4290, -1.9323, 2.4480]]) + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> # batch matrix multiplication + >>> As = torch.randn(3, 2, 5) + >>> Bs = torch.randn(3, 5, 4) + >>> torch.einsum('bij,bjk->bik', As, Bs) + tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], + [-1.6706, -0.8097, -0.8025, -2.1183]], + + [[ 4.2239, 0.3107, -0.5756, -0.2354], + [-1.4558, -0.3460, 1.5087, -0.8530]], + + [[ 2.8153, 1.8787, -4.3839, -1.2112], + [ 0.3728, -2.1131, 0.0921, 0.8305]]]) + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> # with sublist format and ellipsis + >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) + tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], + [-1.6706, -0.8097, -0.8025, -2.1183]], + + [[ 4.2239, 0.3107, -0.5756, -0.2354], + [-1.4558, -0.3460, 1.5087, -0.8530]], + + [[ 2.8153, 1.8787, -4.3839, -1.2112], + [ 0.3728, -2.1131, 0.0921, 0.8305]]]) + + >>> # batch permute + >>> A = torch.randn(2, 3, 4, 5) + >>> torch.einsum('...ij->...ji', A).shape + torch.Size([2, 3, 5, 4]) + + >>> # equivalent to torch.nn.functional.bilinear + >>> A = torch.randn(3, 5, 4) + >>> l = torch.randn(2, 5) + >>> r = torch.randn(2, 4) + >>> torch.einsum('bn,anm,bm->ba', l, A, r) + tensor([[-0.3430, -5.2405, 0.4494], + [ 0.3311, 5.5201, -3.0356]]) + """ + import torch.backends.opt_einsum as opt_einsum + + # This wrapper exists to support variadic args. + if len(args) < 2: + raise ValueError( + "einsum(): must specify the equation string and at least one operand, " + "or at least one operand and its subscripts list" + ) + + equation = None + operands = None + + if isinstance(args[0], torch.Tensor): + # Convert the subscript list format which is an interleaving of operand and its subscripts + # list with an optional output subscripts list at the end (see documentation for more details on this) + # to the equation string format by creating the equation string from the subscripts list and grouping the + # input operands into a tensorlist (List[Tensor]). + def parse_subscript(n: int) -> str: + if n == Ellipsis: + return "..." 
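+            # Integers 0-25 map to the upper-case letters 'A'-'Z' and 26-51 map
+            # to the lower-case letters 'a'-'z', mirroring the [a-zA-Z] alphabet
+            # accepted in equation strings.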
+ if n >= 0 and n < 26: + return chr(ord("A") + n) + if n >= 26 and n < 52: + return chr(ord("a") + n - 26) + raise ValueError( + "einsum(): subscript in subscript list is not within the valid range [0, 52)" + ) + + # Parse subscripts for input operands + equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2]) + + # Parse optional output subscripts (provided when the number of arguments is odd) + if len(args) % 2 == 1: + equation += "->" + "".join(parse_subscript(s) for s in args[-1]) + operands = args[:-1:2] + else: + operands = args[::2] + else: + equation = args[0] + operands = args[1:] + + if has_torch_function(operands): + return handle_torch_function(einsum, operands, equation, *operands) + + if len(operands) == 1 and isinstance(operands[0], (list, tuple)): + # the old interface of passing the operands as one list argument + _operands = operands[0] + # recurse incase operands contains value that has torch function + # in the original implementation this line is omitted + return einsum(equation, *_operands) + + if len(operands) <= 2 or not opt_einsum.enabled: + # the path for contracting 0 or 1 time(s) is already optimized + # or the user has disabled using opt_einsum + return _VF.einsum(equation, operands) # type: ignore[attr-defined] + + path = None + if opt_einsum.is_available(): + _opt_einsum = opt_einsum.get_opt_einsum() + tupled_path = _opt_einsum.contract_path( + equation, *operands, optimize=opt_einsum.strategy + )[0] + # flatten path for dispatching to C++ + path = [*itertools.chain.from_iterable(tupled_path)] + return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined] + + +# This wrapper exists to support variadic args. +if TYPE_CHECKING: + # The JIT doesn't understand Union, so only add type annotation for mypy + def meshgrid( + *tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None + ) -> tuple[Tensor, ...]: + return _meshgrid(*tensors, indexing=indexing) + +else: + + def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]: + r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. + + This is helpful when you want to visualize data over some + range of inputs. See below for a plotting example. + + Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as + inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`, + this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots + G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where + the output :math:`G_i` is constructed by expanding :math:`T_i` + to the result shape. + + .. note:: + 0D inputs are treated equivalently to 1D inputs of a + single element. + + .. warning:: + `torch.meshgrid(*tensors)` currently has the same behavior + as calling `numpy.meshgrid(*arrays, indexing='ij')`. + + In the future `torch.meshgrid` will transition to + `indexing='xy'` as the default. + + https://github.com/pytorch/pytorch/issues/50276 tracks + this issue with the goal of migrating to NumPy's behavior. + + .. seealso:: + + :func:`torch.cartesian_prod` has the same effect but it + collects the data in a tensor of vectors. + + Args: + tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be + treated as tensors of size :math:`(1,)` automatically + + indexing: (str, optional): the indexing mode, either "xy" + or "ij", defaults to "ij". See warning for future changes. 
+ + If "xy" is selected, the first dimension corresponds + to the cardinality of the second input and the second + dimension corresponds to the cardinality of the first + input. + + If "ij" is selected, the dimensions are in the same + order as the cardinality of the inputs. + + Returns: + seq (sequence of Tensors): If the input has :math:`N` + tensors of size :math:`S_0 \ldots S_{N-1}``, then the + output will also have :math:`N` tensors, where each tensor + is of shape :math:`(S_0, ..., S_{N-1})`. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([4, 5, 6]) + + Observe the element-wise pairings across the grid, (1, 4), + (1, 5), ..., (3, 6). This is the same thing as the + cartesian product. + >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij') + >>> grid_x + tensor([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + >>> grid_y + tensor([[4, 5, 6], + [4, 5, 6], + [4, 5, 6]]) + + This correspondence can be seen when these grids are + stacked properly. + >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))), + ... torch.cartesian_prod(x, y)) + True + + `torch.meshgrid` is commonly used to produce a grid for + plotting. + >>> # xdoctest: +REQUIRES(module:matplotlib) + >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW) + >>> import matplotlib.pyplot as plt + >>> xs = torch.linspace(-5, 5, steps=100) + >>> ys = torch.linspace(-5, 5, steps=100) + >>> x, y = torch.meshgrid(xs, ys, indexing='xy') + >>> z = torch.sin(torch.sqrt(x * x + y * y)) + >>> ax = plt.axes(projection='3d') + >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy()) + >>> plt.show() + + .. image:: ../_static/img/meshgrid.png + :width: 512 + + """ + return _meshgrid(*tensors, indexing=indexing) + + +def _meshgrid(*tensors, indexing: Optional[str]): + if has_torch_function(tensors): + return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing) + if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): + # the old interface of passing the operands as one list argument + tensors = tensors[0] # type: ignore[assignment] + + # Continue allowing call of old method that takes no indexing + # kwarg for forward compatibility reasons. + # + # Remove this two weeks after landing. + kwargs = {} if indexing is None else {"indexing": indexing} + return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined] + + +def stft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, +) -> Tensor: + r"""Short-time Fourier transform (STFT). + + .. warning:: + From version 1.8.0, :attr:`return_complex` must always be given + explicitly for real inputs and `return_complex=False` has been + deprecated. Strongly prefer `return_complex=True` as in a future + pytorch release, this function will only return complex tensors. + + Note that :func:`torch.view_as_real` can be used to recover a real + tensor with an extra last dimension for real and imaginary components. + + .. warning:: + From version 2.1, a warning will be provided if a :attr:`window` is + not specified. In a future release, this attribute will be required. + Not providing a window currently defaults to using a rectangular window, + which may result in undesirable artifacts. Consider using tapered windows, + such as :func:`torch.hann_window`. 
+ + The STFT computes the Fourier transform of short overlapping windows of the + input. This giving frequency components of the signal as they change over + time. The interface of this function is modeled after (but *not* a drop-in + replacement for) librosa_ stft function. + + .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html + + Ignoring the optional batch dimension, this method computes the following + expression: + + .. math:: + X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% + \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % + \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right), + + where :math:`m` is the index of the sliding window, and :math:`\omega` is + the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, + or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. + + * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time + sequences. + + * If :attr:`hop_length` is ``None`` (default), it is treated as equal to + ``floor(n_fft / 4)``. + + * If :attr:`win_length` is ``None`` (default), it is treated as equal to + :attr:`n_fft`. + + * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from + :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is + treated as if having :math:`1` everywhere in the window. If + :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on + both sides to length :attr:`n_fft` before being applied. + + * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on + both sides so that the :math:`t`-th frame is centered at time + :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame + begins at time :math:`t \times \text{hop\_length}`. + + * :attr:`pad_mode` determines the padding method used on :attr:`input` when + :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for + all available options. Default is ``"reflect"``. + + * If :attr:`onesided` is ``True`` (default for real input), only values for + :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor + \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because + the real-to-complex Fourier transform satisfies the conjugate symmetry, + i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. + Note if the input or window tensors are complex, then :attr:`onesided` + output is not possible. + + * If :attr:`normalized` is ``True`` (default is ``False``), the function + returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. + + * If :attr:`return_complex` is ``True`` (default if input is complex), the + return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, + the output is a ``input.dim() + 2`` dimensional real tensor where the last + dimension represents the real and imaginary components. + + Returns either a complex tensor of size :math:`(* \times N \times T)` if + :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N + \times T \times 2)`. Where :math:`*` is the optional batch size of + :attr:`input`, :math:`N` is the number of frequencies where STFT is applied + and :math:`T` is the total number of frames used. + + .. warning:: + This function changed signature at version 0.4.1. Calling with the + previous signature may cause error or return incorrect result. 
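+
+    For example, a minimal sketch of a typical call (the signal length, ``n_fft`` and
+    ``hop_length`` below are arbitrary)::
+
+        >>> x = torch.randn(1000)
+        >>> window = torch.hann_window(400)
+        >>> spec = torch.stft(x, n_fft=400, hop_length=100, window=window,
+        ...                   return_complex=True)
+        >>> spec.shape
+        torch.Size([201, 11])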
+ + Args: + input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional + batch dimension + n_fft (int): size of Fourier transform + hop_length (int, optional): the distance between neighboring sliding window + frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) + win_length (int, optional): the size of window frame and STFT filter. + Default: ``None`` (treated as equal to :attr:`n_fft`) + window (Tensor, optional): the optional window function. + Shape must be 1d and `<= n_fft` + Default: ``None`` (treated as window of all :math:`1` s) + center (bool, optional): whether to pad :attr:`input` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + Default: ``True`` + pad_mode (str, optional): controls the padding method used when + :attr:`center` is ``True``. Default: ``"reflect"`` + normalized (bool, optional): controls whether to return the normalized STFT results + Default: ``False`` + onesided (bool, optional): controls whether to return half of results to + avoid redundancy for real inputs. + Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. + return_complex (bool, optional): whether to return a complex tensor, or + a real tensor with an extra last dimension for the real and + imaginary components. + + .. versionchanged:: 2.0 + ``return_complex`` is now a required argument for real inputs, + as the default is being transitioned to ``True``. + + .. deprecated:: 2.0 + ``return_complex=False`` is deprecated, instead use ``return_complex=True`` + Note that calling :func:`torch.view_as_real` on the output will + recover the deprecated output format. + + Returns: + Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where + - `B?` is an optional batch dimension from the input. + - `N` is the number of frequency samples, `(n_fft // 2) + 1` for + `onesided=True`, or otherwise `n_fft`. + - `T` is the number of frames, `1 + L // hop_length` + for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise. + - `C?` is an optional length-2 dimension of real and imaginary + components, present when `return_complex=False`. + + """ + if has_torch_function_unary(input): + return handle_torch_function( + stft, + (input,), + input, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + return_complex=return_complex, + align_to_window=align_to_window, + ) + if center and align_to_window is not None: + raise RuntimeError( + "stft align_to_window should only be set when center = false" + ) + # NOTE: Do not edit. This code will be removed once the forward-compatibility + # period is over for PR #73432 + if center: + signal_dim = input.dim() + extended_shape = [1] * (3 - signal_dim) + list(input.size()) + pad = int(n_fft // 2) + input = F.pad(input.view(extended_shape), [pad, pad], pad_mode) + input = input.view(input.shape[-signal_dim:]) + return _VF.stft( # type: ignore[attr-defined] + input, + n_fft, + hop_length, + win_length, + window, + normalized, + onesided, + return_complex, + align_to_window, + ) + + +istft = _add_docstr( + torch.istft, + "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " + "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n" + r""" +Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`. + +.. 
warning::
+    From version 2.1, a warning will be provided if a :attr:`window` is
+    not specified. In a future release, this attribute will be required.
+    Please provide the same window used in the stft call.
+
+It has the same parameters (plus an additional optional parameter, :attr:`length`) and it should return the
+least squares estimation of the original signal. The algorithm will check using the NOLA condition
+(nonzero overlap).
+
+An important consideration for the parameters :attr:`window` and :attr:`center` is that the envelope
+created by the summation of all the windows must never be zero at any point in time. Specifically,
+:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \neq 0`.
+
+Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
+``istft`` may return a shorter signal than the original signal (this can occur if :attr:`center` is ``False``,
+since the signal is not padded). If :attr:`length` is given in the arguments and is longer than expected,
+``istft`` will pad zeros to the end of the returned signal.
+
+If :attr:`center` is ``True``, then there will be padding, e.g. ``'constant'``, ``'reflect'``, etc.
+Left padding can be trimmed off exactly because it can be calculated, but right padding cannot be
+calculated without additional information.
+
+Example: Suppose the last window is:
+``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``
+
+The :attr:`n_fft`, :attr:`hop_length`, and :attr:`win_length` are all the same, which prevents the calculation
+of right padding. These additional values could be zeros or a reflection of the signal, so providing
+:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
+(some loss of signal).
+
+[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
+IEEE Trans. ASSP, vol. 32, no. 2, pp. 236-243, Apr. 1984.
+
+Args:
+    input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`
+        output. That is a complex tensor of shape `(B?, N, T)` where
+
+        - `B?` is an optional batch dimension
+        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
+          for onesided input, or otherwise `n_fft`.
+        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
+          or `1 + (length - n_fft) // hop_length` otherwise.
+
+        .. versionchanged:: 2.0
+            Real datatype inputs are no longer supported. Input must now have a
+            complex datatype, as returned by ``stft(..., return_complex=True)``.
+    n_fft (int): Size of Fourier transform
+    hop_length (Optional[int]): The distance between neighboring sliding window frames.
+        (Default: ``n_fft // 4``)
+    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
+    window (Optional[torch.Tensor]): The optional window function.
+        Shape must be 1d and `<= n_fft`
+        (Default: ``torch.ones(win_length)``)
+    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
+        centered at time :math:`t \times \text{hop\_length}`.
+        (Default: ``True``)
+    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
+    onesided (Optional[bool]): Whether the STFT was onesided.
+        (Default: ``True`` if `n_fft != fft_size` in the input size)
+    length (Optional[int]): The amount to trim the signal by (i.e. the
+        original signal length). Defaults to `(T - 1) * hop_length` for
+        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
+        is the number of input frames.
+ return_complex (Optional[bool]): + Whether the output should be complex, or if the input should be + assumed to derive from a real signal and window. + Note that this is incompatible with ``onesided=True``. + (Default: ``False``) + +Returns: + Tensor: Least squares estimation of the original signal of shape `(B?, length)` where + `B?` is an optional batch dimension from the input tensor. +""", +) + + +if TYPE_CHECKING: + # These _impl functions return a variable number of tensors as output with + # __torch_function__; tuple unpacking is done already rather than being + # done by the caller of the _impl function + _unique_impl_out = Any +else: + _unique_impl_out = tuple[Tensor, Tensor, Tensor] + + +def _unique_impl( + input: Tensor, + sorted: bool = True, + return_inverse: bool = False, + return_counts: bool = False, + dim: Optional[int] = None, +) -> _unique_impl_out: + r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor] + + Returns the unique elements of the input tensor. + + .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that + this function also eliminates non-consecutive duplicate values. + + .. note:: Currently in the CUDA implementation and the CPU implementation, + `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument. + Sorting could be slow, so if your input tensor is already sorted, it is recommended to use + :func:`torch.unique_consecutive` which avoids the sorting. + + Args: + input (Tensor): the input tensor + sorted (bool): Whether to sort the unique elements in ascending order + before returning as output. + return_inverse (bool): Whether to also return the indices for where + elements in the original input ended up in the returned unique list. + return_counts (bool): Whether to also return the counts for each unique + element. + dim (int, optional): the dimension to operate upon. If ``None``, the + unique of the flattened input is returned. Otherwise, each of the + tensors indexed by the given dimension is treated as one of the + elements to apply the unique operation upon. See examples for more + details. Default: ``None`` + + Returns: + (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing + + - **output** (*Tensor*): the output list of unique scalar elements. + - **inverse_indices** (*Tensor*): (optional) if + :attr:`return_inverse` is True, there will be an additional + returned tensor (same shape as input) representing the indices + for where elements in the original input map to in the output; + otherwise, this function will only return a single tensor. + - **counts** (*Tensor*): (optional) if + :attr:`return_counts` is True, there will be an additional + returned tensor (same shape as output or output.size(dim), + if dim was specified) representing the number of occurrences + for each unique value or tensor. + + Example:: + + >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long)) + >>> output + tensor([1, 2, 3]) + + >>> output, inverse_indices = torch.unique( + ... torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True) + >>> output + tensor([1, 2, 3]) + >>> inverse_indices + tensor([0, 2, 1, 2]) + + >>> output, inverse_indices = torch.unique( + ... torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True) + >>> output + tensor([1, 2, 3]) + >>> inverse_indices + tensor([[0, 2], + [1, 2]]) + + >>> a = torch.tensor([ + ... 
[ + ... [1, 1, 0, 0], + ... [1, 1, 0, 0], + ... [0, 0, 1, 1], + ... ], + ... [ + ... [0, 0, 1, 1], + ... [0, 0, 1, 1], + ... [1, 1, 1, 1], + ... ], + ... [ + ... [1, 1, 0, 0], + ... [1, 1, 0, 0], + ... [0, 0, 1, 1], + ... ], + ... ]) + + >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]` + >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match + >>> # each other, so one of them will be removed. + >>> (a[0, :, :] == a[2, :, :]).all() + tensor(True) + >>> a_unique_dim0 = torch.unique(a, dim=0) + >>> a_unique_dim0 + tensor([[[0, 0, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1]]]) + + >>> # Notice which sub-tensors from `a` match with the sub-tensors from + >>> # `a_unique_dim0`: + >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all() + tensor(True) + >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all() + tensor(True) + + >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are + >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of + >>> # them will be removed. + >>> (a[:, 0, :] == a[:, 1, :]).all() + tensor(True) + >>> torch.unique(a, dim=1) + tensor([[[0, 0, 1, 1], + [1, 1, 0, 0]], + [[1, 1, 1, 1], + [0, 0, 1, 1]], + [[0, 0, 1, 1], + [1, 1, 0, 0]]]) + + >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared. + >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and + >>> # `a[:, :, 3]` match each other as well. So in this case, two of the + >>> # sub-tensors will be removed. + >>> (a[:, :, 0] == a[:, :, 1]).all() + tensor(True) + >>> (a[:, :, 2] == a[:, :, 3]).all() + tensor(True) + >>> torch.unique(a, dim=2) + tensor([[[0, 1], + [0, 1], + [1, 0]], + [[1, 0], + [1, 0], + [1, 1]], + [[0, 1], + [0, 1], + [1, 0]]]) + """ + if has_torch_function_unary(input): + return handle_torch_function( + unique, + (input,), + input, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + + if dim is not None: + output, inverse_indices, counts = _VF.unique_dim( + input, + dim, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + else: + output, inverse_indices, counts = torch._unique2( + input, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + return output, inverse_indices, counts + + +def _unique_consecutive_impl( + input: Tensor, + return_inverse: bool = False, + return_counts: bool = False, + dim: Optional[int] = None, +) -> _unique_impl_out: + r"""Eliminates all but the first element from every consecutive group of equivalent elements. + + .. note:: This function is different from :func:`torch.unique` in the sense that this function + only eliminates consecutive duplicate values. This semantics is similar to `std::unique` + in C++. + + Args: + input (Tensor): the input tensor + return_inverse (bool): Whether to also return the indices for where + elements in the original input ended up in the returned unique list. + return_counts (bool): Whether to also return the counts for each unique + element. + dim (int): the dimension to apply unique. If ``None``, the unique of the + flattened input is returned. default: ``None`` + + Returns: + (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing + + - **output** (*Tensor*): the output list of unique scalar elements. 
+ - **inverse_indices** (*Tensor*): (optional) if + :attr:`return_inverse` is True, there will be an additional + returned tensor (same shape as input) representing the indices + for where elements in the original input map to in the output; + otherwise, this function will only return a single tensor. + - **counts** (*Tensor*): (optional) if + :attr:`return_counts` is True, there will be an additional + returned tensor (same shape as output or output.size(dim), + if dim was specified) representing the number of occurrences + for each unique value or tensor. + + Example:: + + >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2]) + >>> output = torch.unique_consecutive(x) + >>> output + tensor([1, 2, 3, 1, 2]) + + >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True) + >>> output + tensor([1, 2, 3, 1, 2]) + >>> inverse_indices + tensor([0, 0, 1, 1, 2, 3, 3, 4]) + + >>> output, counts = torch.unique_consecutive(x, return_counts=True) + >>> output + tensor([1, 2, 3, 1, 2]) + >>> counts + tensor([2, 2, 1, 2, 1]) + """ + if has_torch_function_unary(input): + return handle_torch_function( + unique_consecutive, + (input,), + input, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined] + input, return_inverse=return_inverse, return_counts=return_counts, dim=dim + ) + return output, inverse_indices, counts + + +def _return_counts( + input, + sorted=True, + return_inverse=False, + return_counts=False, + dim=None, +): + # type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor] + + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) + + output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) + return output, counts + + +def _return_output( + input, + sorted=True, + return_inverse=False, + return_counts=False, + dim=None, +): + # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor + + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) + + output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) + return output + + +def _return_inverse( + input, + sorted=True, + return_inverse=False, + return_counts=False, + dim=None, +): + # type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor] + + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) + + output, inverse_indices, _ = _unique_impl( + input, sorted, return_inverse, return_counts, dim + ) + return output, inverse_indices + + +_return_inverse_false = boolean_dispatch( + arg_name="return_counts", + arg_index=3, + default=False, + if_true=_return_counts, + if_false=_return_output, + module_name=__name__, + func_name="unique", +) + +_return_inverse_true = boolean_dispatch( + arg_name="return_counts", + arg_index=3, + default=False, + if_true=_unique_impl, + if_false=_return_inverse, + module_name=__name__, + func_name="unique", +) + +# The return type of unique depends on `return_inverse`, and `return_counts` so in order to +# resolve the output type in TorchScript we need to statically know the value of both parameters + +unique = boolean_dispatch( + arg_name="return_inverse", + arg_index=2, + default=False, + if_true=_return_inverse_true, + if_false=_return_inverse_false, + module_name=__name__, + func_name="unique", +) +unique.__doc__ = _unique_impl.__doc__ + + +def 
_consecutive_return_counts( + input, + return_inverse=False, + return_counts=False, + dim=None, +): + # type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor] + + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + + output, _, counts = _unique_consecutive_impl( + input, return_inverse, return_counts, dim + ) + return output, counts + + +def _consecutive_return_output( + input, + return_inverse=False, + return_counts=False, + dim=None, +): + # type: (Tensor, bool, bool, Optional[int]) -> Tensor + + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + + output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) + return output + + +def _consecutive_return_inverse( + input, + return_inverse=False, + return_counts=False, + dim=None, +): + # type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor] + + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + + output, inverse_indices, _ = _unique_consecutive_impl( + input, return_inverse, return_counts, dim + ) + return output, inverse_indices + + +_consecutive_return_inverse_false = boolean_dispatch( + arg_name="return_counts", + arg_index=1, + default=False, + if_true=_consecutive_return_counts, + if_false=_consecutive_return_output, + module_name=__name__, + func_name="unique_consecutive", +) + +_consecutive_return_inverse_true = boolean_dispatch( + arg_name="return_counts", + arg_index=1, + default=False, + if_true=_unique_consecutive_impl, + if_false=_consecutive_return_inverse, + module_name=__name__, + func_name="unique_consecutive", +) + +# The return type of unique depends on `return_inverse`, and `return_counts` so in order to +# resolve the output type in TorchScript we need to statically know the value of both parameters + +unique_consecutive = boolean_dispatch( + arg_name="return_inverse", + arg_index=2, + default=False, + if_true=_consecutive_return_inverse_true, + if_false=_consecutive_return_inverse_false, + module_name=__name__, + func_name="unique_consecutive", +) +unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__ + +if TYPE_CHECKING: + pass + # There's no good way to use this type annotation without breaking JIT + # overloads. So leave untyped for mypy for now. +else: + + @overload + def tensordot( + a, + b, + dims: int = 2, + out: Optional[torch.Tensor] = None, + ): + pass + + @overload + def tensordot( # noqa: F811 + a, + b, + dims: tuple[list[int], list[int]], + out: Optional[torch.Tensor] = None, + ): + pass + + @overload + def tensordot( # noqa: F811 + a, + b, + dims: list[list[int]], + out: Optional[torch.Tensor] = None, + ): + pass + + @overload + def tensordot( # noqa: F811 + a, + b, + dims: torch.Tensor, + out: Optional[torch.Tensor] = None, + ): + pass + + +def tensordot( # noqa: F811 + a, + b, + dims=2, + out: Optional[torch.Tensor] = None, +): + r"""Returns a contraction of a and b over multiple dimensions. + + :attr:`tensordot` implements a generalized matrix product. 
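+
+    For instance, with ``dims=1`` the contraction reduces to an ordinary matrix product
+    (an illustrative sketch; the shapes below are arbitrary)::
+
+        >>> a = torch.arange(12.).reshape(3, 4)
+        >>> b = torch.arange(20.).reshape(4, 5)
+        >>> torch.allclose(torch.tensordot(a, b, dims=1), a @ b)
+        True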
+ + Args: + a (Tensor): Left tensor to contract + b (Tensor): Right tensor to contract + dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to + contract or explicit lists of dimensions for :attr:`a` and + :attr:`b` respectively + + When called with a non-negative integer argument :attr:`dims` = :math:`d`, and + the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, + respectively, :func:`~torch.tensordot` computes + + .. math:: + r_{i_0,...,i_{m-d}, i_d,...,i_n} + = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}. + + When called with :attr:`dims` of the list form, the given dimensions will be contracted + in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes + in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted + dimensions. + + Examples:: + + >>> a = torch.arange(60.).reshape(3, 4, 5) + >>> b = torch.arange(24.).reshape(4, 3, 2) + >>> torch.tensordot(a, b, dims=([1, 0], [0, 1])) + tensor([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> a = torch.randn(3, 4, 5, device='cuda') + >>> b = torch.randn(4, 5, 6, device='cuda') + >>> c = torch.tensordot(a, b, dims=2).cpu() + tensor([[ 8.3504, -2.5436, 6.2922, 2.7556, -1.0732, 3.2741], + [ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744], + [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]]) + + >>> a = torch.randn(3, 5, 4, 6) + >>> b = torch.randn(6, 4, 5, 3) + >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0])) + tensor([[ 7.7193, -2.4867, -10.3204], + [ 1.5513, -14.4737, -6.5113], + [ -0.2850, 4.2573, -3.5997]]) + """ + if has_torch_function_variadic(a, b): + return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out) + + if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)): + raise RuntimeError( + "tensordot expects dims to be int or " + + "tuple[list[int], list[int]] or " + + "list[list[int]] containing two lists, but got " + + f"dims={dims}" + ) + + dims_a: list[int] = [] + dims_b: list[int] = [] + + if isinstance(dims, (tuple, list)): + dims_a, dims_b = dims + + if isinstance(dims, torch.Tensor): + num_elements = dims.numel() + if num_elements > 1: + assert dims.size()[0] == 2 + dims_a = torch.jit.annotate(list[int], dims[0].tolist()) + dims_b = torch.jit.annotate(list[int], dims[1].tolist()) + else: + dims_val = int(dims.item()) + if dims_val < 0: + raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") + dims_a = list(range(-dims_val, 0)) + dims_b = list(range(dims_val)) + + if isinstance(dims, (int, torch.SymInt)): + if dims < 0: + raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") + if dims > min(a.dim(), b.dim()): + raise RuntimeError( + f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}" + ) + dims_a = list(range(-dims, 0)) + dims_b = list(range(dims)) + + if out is None: + return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore[attr-defined] + else: + return _VF.tensordot(a, b, dims_a, dims_b, out=out) # type: ignore[attr-defined] + + +def cartesian_prod(*tensors: Tensor) -> Tensor: + """Do cartesian product of the given sequence of tensors. The behavior is similar to + python's `itertools.product`. + + Args: + *tensors: any number of 1 dimensional tensors. 
+ + Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, + do `itertools.product` on these lists, and finally convert the resulting list + into tensor. + + Example:: + + >>> import itertools + >>> a = [1, 2, 3] + >>> b = [4, 5] + >>> list(itertools.product(a, b)) + [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)] + >>> tensor_a = torch.tensor(a) + >>> tensor_b = torch.tensor(b) + >>> torch.cartesian_prod(tensor_a, tensor_b) + tensor([[1, 4], + [1, 5], + [2, 4], + [2, 5], + [3, 4], + [3, 5]]) + """ + # This wrapper exists to support variadic args. + if has_torch_function(tensors): + return handle_torch_function(cartesian_prod, tensors, *tensors) + return _VF.cartesian_prod(tensors) # type: ignore[attr-defined] + + +def block_diag(*tensors): + """Create a block diagonal matrix from provided tensors. + + Args: + *tensors: One or more tensors with 0, 1, or 2 dimensions. + + Returns: + Tensor: A 2 dimensional tensor with all the input tensors arranged in + order such that their upper left and lower right corners are + diagonally adjacent. All other elements are set to 0. + + Example:: + + >>> import torch + >>> A = torch.tensor([[0, 1], [1, 0]]) + >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]]) + >>> C = torch.tensor(7) + >>> D = torch.tensor([1, 2, 3]) + >>> E = torch.tensor([[4], [5], [6]]) + >>> torch.block_diag(A, B, C, D, E) + tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 3, 4, 5, 0, 0, 0, 0, 0], + [0, 0, 6, 7, 8, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 7, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 2, 3, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 4], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]]) + """ + # This wrapper exists to support variadic args. + if has_torch_function(tensors): + return handle_torch_function(block_diag, tensors, *tensors) + return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined] + + +def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): + # type: (Tensor, Tensor, float, str) -> (Tensor) + r"""Computes batched the p-norm distance between each pair of the two collections of row vectors. + + Args: + x1 (Tensor): input tensor where the last two dimensions represent the points and the feature dimension respectively. + The shape can be :math:`D_1 \times D_2 \times \cdots \times D_n \times P \times M`, + where :math:`P` is the number of points and :math:`M` is the feature dimension. + x2 (Tensor): input tensor where the last two dimensions also represent the points and the feature dimension respectively. + The shape can be :math:`D_1' \times D_2' \times \cdots \times D_m' \times R \times M`, + where :math:`R` is the number of points and :math:`M` is the feature dimension, + which should match the feature dimension of `x1`. + p: p value for the p-norm distance to calculate between each vector pair + :math:`\in [0, \infty]`. + compute_mode: + 'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate + euclidean distance (p = 2) if P > 25 or R > 25 + 'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate + euclidean distance (p = 2) + 'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate + euclidean distance (p = 2) + Default: use_mm_for_euclid_dist_if_necessary. + + If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the + output will have shape :math:`B \times P \times R`. 
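+
+    For example (an illustrative sketch; the batch size and point counts below are arbitrary)::
+
+        >>> x1 = torch.randn(2, 8, 3)
+        >>> x2 = torch.randn(2, 6, 3)
+        >>> torch.cdist(x1, x2, p=2).shape
+        torch.Size([2, 8, 6])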
+ + This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)` + if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to + `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest + scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`. + + Example: + + >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]]) + >>> a + tensor([[ 0.9041, 0.0196], + [-0.3108, -2.4423], + [-0.4821, 1.0590]]) + >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]]) + >>> b + tensor([[-2.1763, -0.4713], + [-0.6986, 1.3702]]) + >>> torch.cdist(a, b, p=2) + tensor([[3.1193, 2.0959], + [2.7138, 3.8322], + [2.2830, 0.3791]]) + """ + if has_torch_function_variadic(x1, x2): + return handle_torch_function( + cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode + ) + if compute_mode == "use_mm_for_euclid_dist_if_necessary": + return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined] + elif compute_mode == "use_mm_for_euclid_dist": + return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined] + elif compute_mode == "donot_use_mm_for_euclid_dist": + return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined] + else: + raise ValueError(f"{compute_mode} is not a valid value for compute_mode") + + +def atleast_1d(*tensors): + r""" + Returns a 1-dimensional view of each input tensor with zero dimensions. + Input tensors with one or more dimensions are returned as-is. + + Args: + input (Tensor or list of Tensors) + + Returns: + output (Tensor or tuple of Tensors) + + Example:: + + >>> x = torch.arange(2) + >>> x + tensor([0, 1]) + >>> torch.atleast_1d(x) + tensor([0, 1]) + >>> x = torch.tensor(1.) + >>> x + tensor(1.) + >>> torch.atleast_1d(x) + tensor([1.]) + >>> x = torch.tensor(0.5) + >>> y = torch.tensor(1.) + >>> torch.atleast_1d((x, y)) + (tensor([0.5000]), tensor([1.])) + """ + # This wrapper exists to support variadic args. + if has_torch_function(tensors): + return handle_torch_function(atleast_1d, tensors, *tensors) + if len(tensors) == 1: + tensors = tensors[0] + return _VF.atleast_1d(tensors) # type: ignore[attr-defined] + + +def atleast_2d(*tensors): + r""" + Returns a 2-dimensional view of each input tensor with zero dimensions. + Input tensors with two or more dimensions are returned as-is. + + Args: + input (Tensor or list of Tensors) + + Returns: + output (Tensor or tuple of Tensors) + + Example:: + + >>> x = torch.tensor(1.) + >>> x + tensor(1.) + >>> torch.atleast_2d(x) + tensor([[1.]]) + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.atleast_2d(x) + tensor([[0, 1], + [2, 3]]) + >>> x = torch.tensor(0.5) + >>> y = torch.tensor(1.) + >>> torch.atleast_2d((x, y)) + (tensor([[0.5000]]), tensor([[1.]])) + """ + # This wrapper exists to support variadic args. + if has_torch_function(tensors): + return handle_torch_function(atleast_2d, tensors, *tensors) + if len(tensors) == 1: + tensors = tensors[0] + return _VF.atleast_2d(tensors) # type: ignore[attr-defined] + + +def atleast_3d(*tensors): + r""" + Returns a 3-dimensional view of each input tensor with zero dimensions. + Input tensors with three or more dimensions are returned as-is. 
+ + Args: + input (Tensor or list of Tensors) + + Returns: + output (Tensor or tuple of Tensors) + + Example: + + >>> x = torch.tensor(0.5) + >>> x + tensor(0.5000) + >>> torch.atleast_3d(x) + tensor([[[0.5000]]]) + >>> y = torch.arange(4).view(2, 2) + >>> y + tensor([[0, 1], + [2, 3]]) + >>> torch.atleast_3d(y) + tensor([[[0], + [1]], + + [[2], + [3]]]) + >>> x = torch.tensor(1).view(1, 1, 1) + >>> x + tensor([[[1]]]) + >>> torch.atleast_3d(x) + tensor([[[1]]]) + >>> x = torch.tensor(0.5) + >>> y = torch.tensor(1.0) + >>> torch.atleast_3d((x, y)) + (tensor([[[0.5000]]]), tensor([[[1.]]])) + """ + # This wrapper exists to support variadic args. + if has_torch_function(tensors): + return handle_torch_function(atleast_3d, tensors, *tensors) + if len(tensors) == 1: + tensors = tensors[0] + return _VF.atleast_3d(tensors) # type: ignore[attr-defined] + + +if TYPE_CHECKING: + pass + # There's no good way to use this type annotation; cannot rename norm() to + # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped + # for mypy for now. + # def norm(input: Tensor, + # p: Optional[Union[str, Number]] = "fro", + # dim: Optional[Union[int, List[int]]] = None, + # keepdim: bool = False, + # out: Optional[Tensor] = None, + # dtype: _dtype = None) -> Tensor: + # return _norm_impl(input, p, dim, keepdim, out, dtype) +else: + # TODO: type dim as BroadcastingList when + # https://github.com/pytorch/pytorch/issues/33782 is fixed + @overload + def norm( + input, + p="fro", + dim=None, + keepdim=False, + out=None, + dtype=None, + ): + # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor + pass + + @overload + def norm( # noqa: F811 + input, + p="fro", + dim=None, + keepdim=False, + out=None, + dtype=None, + ): + # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor + pass + + @overload + def norm( # noqa: F811 + input, + p="fro", + dim=None, + keepdim=False, + out=None, + dtype=None, + ): + # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor + pass + + @overload + def norm( # noqa: F811 + input, + p="fro", + dim=None, + keepdim=False, + out=None, + dtype=None, + ): + # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor + pass + + +def norm( # noqa: F811 + input, + p: Optional[Union[float, str]] = "fro", + dim=None, + keepdim=False, + out=None, + dtype=None, +): + r"""Returns the matrix norm or vector norm of a given tensor. + + .. warning:: + + torch.norm is deprecated and may be removed in a future PyTorch release. + Its documentation and behavior may be incorrect, and it is no longer + actively maintained. + + Use :func:`torch.linalg.vector_norm` when computing vector norms and + :func:`torch.linalg.matrix_norm` when computing matrix norms. + For a function with a similar behavior as this one see :func:`torch.linalg.norm`. + Note, however, the signature for these functions is slightly different than the + signature for ``torch.norm``. + + Args: + input (Tensor): The input tensor. Its data type must be either a floating + point or complex type. For complex inputs, the norm is calculated using the + absolute value of each element. If the input is complex and neither + :attr:`dtype` nor :attr:`out` is specified, the result's data type will + be the corresponding floating point type (e.g. float if :attr:`input` is + complexfloat). + + p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. 
Default: ``'fro'`` + The following norms can be calculated: + + ====== ============== ========================== + ord matrix norm vector norm + ====== ============== ========================== + 'fro' Frobenius norm -- + 'nuc' nuclear norm -- + Number -- sum(abs(x)**ord)**(1./ord) + ====== ============== ========================== + + The vector norm can be calculated across any number of dimensions. + The corresponding dimensions of :attr:`input` are flattened into + one dimension, and the norm is calculated on the flattened + dimension. + + Frobenius norm produces the same result as ``p=2`` in all cases + except when :attr:`dim` is a list of three or more dims, in which + case Frobenius norm throws an error. + + Nuclear norm can only be calculated across exactly two dimensions. + + dim (int, tuple of ints, list of ints, optional): + Specifies which dimension or dimensions of :attr:`input` to + calculate the norm across. If :attr:`dim` is ``None``, the norm will + be calculated across all dimensions of :attr:`input`. If the norm + type indicated by :attr:`p` does not support the specified number of + dimensions, an error will occur. + keepdim (bool, optional): whether the output tensors have :attr:`dim` + retained or not. Ignored if :attr:`dim` = ``None`` and + :attr:`out` = ``None``. Default: ``False`` + out (Tensor, optional): the output tensor. Ignored if + :attr:`dim` = ``None`` and :attr:`out` = ``None``. + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. If specified, the input tensor is casted to + :attr:`dtype` while performing the operation. Default: None. + + .. note:: + Even though ``p='fro'`` supports any number of dimensions, the true + mathematical definition of Frobenius norm only applies to tensors with + exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'`` + aligns with the mathematical definition, since it can only be applied across + exactly two dimensions. + + Example:: + + >>> import torch + >>> a = torch.arange(9, dtype= torch.float) - 4 + >>> b = a.reshape((3, 3)) + >>> torch.norm(a) + tensor(7.7460) + >>> torch.norm(b) + tensor(7.7460) + >>> torch.norm(a, float('inf')) + tensor(4.) + >>> torch.norm(b, float('inf')) + tensor(4.) + >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float) + >>> torch.norm(c, dim=0) + tensor([1.4142, 2.2361, 5.0000]) + >>> torch.norm(c, dim=1) + tensor([3.7417, 4.2426]) + >>> torch.norm(c, p=1, dim=1) + tensor([6., 6.]) + >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> torch.norm(d, dim=(1, 2)) + tensor([ 3.7417, 11.2250]) + >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) + (tensor(3.7417), tensor(11.2250)) + """ + + if has_torch_function_unary(input): + return handle_torch_function( + norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype + ) + + # NB. All the repeated code and weird python is to please TorchScript. 
+ # For a more compact implementation see the relevant function in `_refs/__init__.py` + + # We don't do this for MPS or sparse tensors + if input.layout == torch.strided and input.device.type in ( + "cpu", + "cuda", + "xpu", + "meta", + torch.utils.backend_registration._privateuse1_backend_name, + ): + if dim is not None: + if isinstance(dim, (int, torch.SymInt)): + _dim = [dim] + else: + _dim = dim + else: + _dim = None # type: ignore[assignment] + + if isinstance(p, str): + if p == "fro" and ( + dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2 + ): + if out is None: + return torch.linalg.vector_norm( + input, 2, _dim, keepdim, dtype=dtype + ) + else: + return torch.linalg.vector_norm( + input, 2, _dim, keepdim, dtype=dtype, out=out + ) + + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if _dim is None: + _dim = list(range(input.ndim)) + if out is None: + return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype) + else: + return torch.linalg.matrix_norm( + input, p, _dim, keepdim, dtype=dtype, out=out + ) + else: + # NB. p should be Union[str, number], not Optional! + _p = 2.0 if p is None else p + if out is None: + return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm( + input, _p, _dim, keepdim, dtype=dtype, out=out + ) + + ndim = input.dim() + + # catch default case + if dim is None and out is None and dtype is None and p is not None: + if isinstance(p, str): + if p == "fro": + return _VF.frobenius_norm(input, dim=(), keepdim=keepdim) + if not isinstance(p, str): + _dim = list(range(ndim)) + return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined] + + # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed + # remove the overloads where dim is an int and replace with BraodcastingList1 + # and remove next four lines, replace _dim with dim + if dim is not None: + if isinstance(dim, (int, torch.SymInt)): + _dim = [dim] + else: + _dim = dim + else: + _dim = None # type: ignore[assignment] + + if isinstance(p, str): + if p == "fro": + if dtype is not None: + raise ValueError("dtype argument is not supported in frobenius norm") + + if _dim is None: + _dim = list(range(ndim)) + if out is None: + return _VF.frobenius_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type] + else: + return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type] + elif p == "nuc": + if dtype is not None: + raise ValueError("dtype argument is not supported in nuclear norm") + if _dim is None: + if out is None: + return _VF.nuclear_norm(input, keepdim=keepdim) # type: ignore[arg-type] + else: + return _VF.nuclear_norm(input, keepdim=keepdim, out=out) # type: ignore[arg-type] + else: + if out is None: + return _VF.nuclear_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type] + else: + return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type] + raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}") + else: + if _dim is None: + _dim = list(range(ndim)) + + if out is None: + if dtype is None: + return _VF.norm(input, p, _dim, keepdim=keepdim) # type: ignore[attr-defined] + else: + return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype) # type: ignore[attr-defined] + else: + if dtype is None: + return _VF.norm(input, p, _dim, keepdim=keepdim, out=out) # type: ignore[attr-defined] + else: + return _VF.norm(input, p, _dim, 
keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined] + + +def unravel_index( + indices: Tensor, + shape: Union[int, Sequence[int], torch.Size], +) -> tuple[Tensor, ...]: + r"""Converts a tensor of flat indices into a tuple of coordinate tensors that + index into an arbitrary tensor of the specified shape. + + Args: + indices (Tensor): An integer tensor containing indices into the + flattened version of an arbitrary tensor of shape :attr:`shape`. + All elements must be in the range ``[0, prod(shape) - 1]``. + + shape (int, sequence of ints, or torch.Size): The shape of the arbitrary + tensor. All elements must be non-negative. + + Returns: + tuple of Tensors: Each ``i``-th tensor in the output corresponds with + dimension ``i`` of :attr:`shape`. Each tensor has the same shape as + ``indices`` and contains one index into dimension ``i`` for each of the + flat indices given by ``indices``. + + Example:: + + >>> import torch + >>> torch.unravel_index(torch.tensor(4), (3, 2)) + (tensor(2), + tensor(0)) + + >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2)) + (tensor([2, 0]), + tensor([0, 1])) + + >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2)) + (tensor([0, 0, 1, 1, 2, 2]), + tensor([0, 1, 0, 1, 0, 1])) + + >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10)) + (tensor([1, 5]), + tensor([2, 6]), + tensor([3, 7]), + tensor([4, 8])) + + >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10)) + (tensor([[1], [5]]), + tensor([[2], [6]]), + tensor([[3], [7]]), + tensor([[4], [8]])) + + >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100)) + (tensor([[12], [56]]), + tensor([[34], [78]])) + """ + if has_torch_function_unary(indices): + return handle_torch_function(unravel_index, (indices,), indices, shape=shape) + res_tensor = _unravel_index(indices, shape) + return res_tensor.unbind(-1) + + +def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: + torch._check_type( + not indices.is_complex() + and not indices.is_floating_point() + and not indices.dtype == torch.bool, + lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}", + ) + + torch._check_type( + isinstance(shape, (int, torch.SymInt, Sequence)), + lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}", + ) + + if isinstance(shape, (int, torch.SymInt)): + shape = torch.Size([shape]) + else: + for dim in shape: + torch._check_type( + isinstance(dim, (int, torch.SymInt)), + lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}", + ) + shape = torch.Size(shape) + + torch._check_value( + all(dim >= 0 for dim in shape), + lambda: f"'shape' cannot have negative values, but got {tuple(shape)}", + ) + + coefs = list( + reversed( + list( + itertools.accumulate( + reversed(shape[1:] + torch.Size([1])), func=operator.mul + ) + ) + ) + ) + return indices.unsqueeze(-1).floor_divide( + torch.tensor(coefs, device=indices.device, dtype=torch.int64) + ) % torch.tensor(shape, device=indices.device, dtype=torch.int64) + + +def chain_matmul(*matrices, out=None): + r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed + using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms + of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N` + needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned. 
+ If :math:`N` is 1, then this is a no-op - the original matrix is returned as is. + + .. warning:: + + :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors + rather than multiple arguments. + + Args: + matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined. + out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``. + + Returns: + Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product + would be of dimensions :math:`p_{1} \times p_{N + 1}`. + + Example:: + + >>> # xdoctest: +SKIP + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> a = torch.randn(3, 4) + >>> b = torch.randn(4, 5) + >>> c = torch.randn(5, 6) + >>> d = torch.randn(6, 7) + >>> # will raise a deprecation warning + >>> torch.chain_matmul(a, b, c, d) + tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614], + [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163], + [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]]) + + .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition + """ + # This wrapper exists to support variadic args. + if has_torch_function(matrices): + return handle_torch_function(chain_matmul, matrices, *matrices) + + if out is None: + return _VF.chain_matmul(matrices) # type: ignore[attr-defined] + else: + return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined] + + +def _lu_impl(A, pivot=True, get_infos=False, out=None): + # type: (Tensor, bool, bool, Any) -> tuple[Tensor, Tensor, Tensor] + r"""Computes the LU factorization of a matrix or batches of matrices + :attr:`A`. Returns a tuple containing the LU factorization and + pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to + ``True``. + + .. warning:: + + :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor` + and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a + future PyTorch release. + ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with + + .. code:: python + + LU, pivots = torch.linalg.lu_factor(A, compute_pivots) + + ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with + + .. code:: python + + LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) + + .. note:: + * The returned permutation matrix for every matrix in the batch is + represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``. + ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm, + the ``i``-th row was permuted with the ``j-1``-th row. + * LU factorization with :attr:`pivot` = ``False`` is not available + for CPU, and attempting to do so will throw an error. However, + LU factorization with :attr:`pivot` = ``False`` is available for + CUDA. + * This function does not check if the factorization was successful + or not if :attr:`get_infos` is ``True`` since the status of the + factorization is present in the third element of the return tuple. + * In the case of batches of square matrices with size less or equal + to 32 on a CUDA device, the LU factorization is repeated for + singular matrices due to the bug in the MAGMA library + (see magma issue 13). + * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`. + + .. 
warning:: + The gradients of this function will only be finite when :attr:`A` is full rank. + This is because the LU decomposition is just differentiable at full rank matrices. + Furthermore, if :attr:`A` is close to not being full rank, + the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`. + + Args: + A (Tensor): the tensor to factor of size :math:`(*, m, n)` + pivot (bool, optional): controls whether pivoting is done. Default: ``True`` + get_infos (bool, optional): if set to ``True``, returns an info IntTensor. + Default: ``False`` + out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``, + then the elements in the tuple are Tensor, IntTensor, + and IntTensor. If :attr:`get_infos` is ``False``, then the + elements in the tuple are Tensor, IntTensor. Default: ``None`` + + Returns: + (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing + + - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)` + + - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`. + ``pivots`` stores all the intermediate transpositions of rows. + The final permutation ``perm`` could be reconstructed by + applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``, + where ``perm`` is initially the identity permutation of :math:`m` elements + (essentially this is what :func:`torch.lu_unpack` is doing). + + - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of + size :math:`(*)` where non-zero values indicate whether factorization for the matrix or + each minibatch has succeeded or failed + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> A = torch.randn(2, 3, 3) + >>> A_LU, pivots = torch.lu(A) + >>> A_LU + tensor([[[ 1.3506, 2.5558, -0.0816], + [ 0.1684, 1.1551, 0.1940], + [ 0.1193, 0.6189, -0.5497]], + + [[ 0.4526, 1.2526, -0.3285], + [-0.7988, 0.7175, -0.9701], + [ 0.2634, -0.9255, -0.3459]]]) + >>> pivots + tensor([[ 3, 3, 3], + [ 3, 3, 3]], dtype=torch.int32) + >>> A_LU, pivots, info = torch.lu(A, get_infos=True) + >>> if info.nonzero().size(0) == 0: + ... print('LU factorization succeeded for all samples!') + LU factorization succeeded for all samples! 
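# (Editor's note: migration sketch, not part of the upstream docstring.)
# The deprecated torch.lu calls shown above map onto torch.linalg as follows:
#
#     import torch
#     A = torch.randn(2, 3, 3)
#     LU, pivots = torch.linalg.lu_factor(A)            # replaces torch.lu(A)
#     LU, pivots, info = torch.linalg.lu_factor_ex(A)   # replaces torch.lu(A, get_infos=True)
#     P, L, U = torch.lu_unpack(LU, pivots)             # recover the explicit P, L, U factors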
+ """ + # If get_infos is True, then we don't need to check for errors and vice versa + return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) + + +if TYPE_CHECKING: + _ListOrSeq = Sequence[Tensor] +else: + _ListOrSeq = list[Tensor] + + +def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: + get_infos_int = 1 if get_infos else 0 + if out_len - get_infos_int != 2: + raise TypeError( + f"expected tuple of {2 + int(get_infos)} elements but got {out_len}" + ) + if not isinstance(out, (tuple, list)): + raise TypeError( + f"argument 'out' must be tuple of Tensors, not {type(out).__name__}" + ) + + +def _lu_with_infos(A, pivot=True, get_infos=False, out=None): + # type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor, Tensor]]) -> tuple[Tensor, Tensor, Tensor] + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out + ) + result = _lu_impl(A, pivot, get_infos, out) + if out is not None: + _check_list_size(len(out), get_infos, out) + for i in range(len(out)): + out[i].resize_as_(result[i]).copy_(result[i]) + return out + else: + return result # A_LU, pivots, infos + + +def _lu_no_infos(A, pivot=True, get_infos=False, out=None): + # type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor] + # need to check for torch_function here so that we exit if + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out + ) + result = _lu_impl(A, pivot, get_infos, out) + if out is not None: + _check_list_size(len(out), get_infos, out) + for i in range(len(out)): + out[i].resize_as_(result[i]).copy_(result[i]) + return out + else: + return result[0], result[1] # A_LU, pivots + + +# The return type of lu depends on `get_infos`, so in order to resolve the output type +# of lu in TorchScript we need to statically know the value of `get_infos` +lu = boolean_dispatch( + arg_name="get_infos", + arg_index=2, + default=False, + if_true=_lu_with_infos, + if_false=_lu_no_infos, + module_name=__name__, + func_name="lu", +) +lu.__doc__ = _lu_impl.__doc__ + + +def align_tensors(*tensors): + raise RuntimeError("`align_tensors` not yet implemented.") diff --git a/phivenv/Lib/site-packages/torch/hub.py b/phivenv/Lib/site-packages/torch/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..ffad31b525f30ab8f4e95829df05ca3ce1d2d51c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/hub.py @@ -0,0 +1,875 @@ +# mypy: allow-untyped-defs +import contextlib +import errno +import hashlib +import json +import os +import re +import shutil +import sys +import tempfile +import uuid +import warnings +import zipfile +from pathlib import Path +from typing import Any, Optional, Union +from typing_extensions import deprecated +from urllib.error import HTTPError, URLError +from urllib.parse import urlparse # noqa: F401 +from urllib.request import Request, urlopen + +import torch +from torch.serialization import MAP_LOCATION + + +class _Faketqdm: # type: ignore[no-redef] + def __init__(self, total=None, disable=False, unit=None, *args, **kwargs): + self.total = total + self.disable = disable + self.n = 0 + # Ignore all extra *args and **kwargs lest you want to reinvent tqdm + + def update(self, n): + if self.disable: + return + + self.n += n + if self.total is None: + sys.stderr.write(f"\r{self.n:.1f} bytes") + else: + sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%") + sys.stderr.flush() + + # Don't bother 
implementing; use real tqdm if you want + def set_description(self, *args, **kwargs): + pass + + def write(self, s): + sys.stderr.write(f"{s}\n") + + def close(self): + self.disable = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.disable: + return + + sys.stderr.write("\n") + + +try: + from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper +except ImportError: + tqdm = _Faketqdm + +__all__ = [ + "download_url_to_file", + "get_dir", + "help", + "list", + "load", + "load_state_dict_from_url", + "set_dir", +] + +# matches bfd8deac from resnet18-bfd8deac.pth +HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") + +_TRUSTED_REPO_OWNERS = ( + "facebookresearch", + "facebookincubator", + "pytorch", + "fairinternal", +) +ENV_GITHUB_TOKEN = "GITHUB_TOKEN" +ENV_TORCH_HOME = "TORCH_HOME" +ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" +DEFAULT_CACHE_DIR = "~/.cache" +VAR_DEPENDENCY = "dependencies" +MODULE_HUBCONF = "hubconf.py" +READ_DATA_CHUNK = 128 * 1024 +_hub_dir: Optional[str] = None + + +@contextlib.contextmanager +def _add_to_sys_path(path): + sys.path.insert(0, path) + try: + yield + finally: + sys.path.remove(path) + + +# Copied from tools/shared/module_loader to be included in torch package +def _import_module(name, path): + import importlib.util + from importlib.abc import Loader + + spec = importlib.util.spec_from_file_location(name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert isinstance(spec.loader, Loader) + spec.loader.exec_module(module) + return module + + +def _remove_if_exists(path): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + + +def _git_archive_link(repo_owner, repo_name, ref): + # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip + return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}" + + +def _load_attr_from_module(module, func_name): + # Check if callable is defined in the module + if func_name not in dir(module): + return None + return getattr(module, func_name) + + +def _get_torch_home(): + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"), + ) + ) + return torch_home + + +def _parse_repo_info(github): + if ":" in github: + repo_info, ref = github.split(":") + else: + repo_info, ref = github, None + repo_owner, repo_name = repo_info.split("/") + + if ref is None: + # The ref wasn't specified by the user, so we need to figure out the + # default branch: main or master. Our assumption is that if main exists + # then it's the default branch, otherwise it's master. 
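# (Editor's note: added commentary.) The probe below issues a plain GET to
# https://github.com/<repo_owner>/<repo_name>/tree/main/ : success means a
# "main" branch exists and it is taken as the default, an HTTP 404 falls back
# to "master", and any other HTTPError is re-raised. With no network at all
# (URLError), the locally cached directories <hub_dir>/<owner>_<repo>_main or
# ..._master are tried as a last resort before giving up.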
+ try: + with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"): + ref = "main" + except HTTPError as e: + if e.code == 404: + ref = "master" + else: + raise + except URLError as e: + # No internet connection, need to check for cache as last resort + for possible_ref in ("main", "master"): + if os.path.exists( + f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}" + ): + ref = possible_ref + break + if ref is None: + raise RuntimeError( + "It looks like there is no internet connection and the " + f"repo could not be found in the cache ({get_dir()})" + ) from e + return repo_owner, repo_name, ref + + +def _read_url(url): + with urlopen(url) as r: + return r.read().decode(r.headers.get_content_charset("utf-8")) + + +def _validate_not_a_forked_repo(repo_owner, repo_name, ref): + # Use urlopen to avoid depending on local git. + headers = {"Accept": "application/vnd.github.v3+json"} + token = os.environ.get(ENV_GITHUB_TOKEN) + if token is not None: + headers["Authorization"] = f"token {token}" + for url_prefix in ( + f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches", + f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags", + ): + page = 0 + while True: + page += 1 + url = f"{url_prefix}?per_page=100&page={page}" + response = json.loads(_read_url(Request(url, headers=headers))) + # Empty response means no more data to process + if not response: + break + for br in response: + if br["name"] == ref or br["commit"]["sha"].startswith(ref): + return + + raise ValueError( + f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. " + "If it's a commit from a forked repo, please call hub.load() with forked repo directly." + ) + + +def _get_cache_or_reload( + github, + force_reload, + trust_repo, + calling_fn, + verbose=True, + skip_validation=False, +): + # Setup hub_dir to save downloaded files + hub_dir = get_dir() + os.makedirs(hub_dir, exist_ok=True) + # Parse github repo information + repo_owner, repo_name, ref = _parse_repo_info(github) + # Github allows branch name with slash '/', + # this causes confusion with path on both Linux and Windows. + # Backslash is not allowed in Github branch name so no need to + # to worry about it. + normalized_br = ref.replace("/", "_") + # Github renames folder repo-v1.x.x to repo-1.x.x + # We don't know the repo name before downloading the zip file + # and inspect name from it. + # To check if cached repo exists, we need to normalize folder names. + owner_name_branch = "_".join([repo_owner, repo_name, normalized_br]) + repo_dir = os.path.join(hub_dir, owner_name_branch) + # Check that the repo is in the trusted list + _check_repo_is_trusted( + repo_owner, + repo_name, + owner_name_branch, + trust_repo=trust_repo, + calling_fn=calling_fn, + ) + + use_cache = (not force_reload) and os.path.exists(repo_dir) + + if use_cache: + if verbose: + sys.stderr.write(f"Using cache found in {repo_dir}\n") + else: + # Validate the tag/branch is from the original repo instead of a forked repo + if not skip_validation: + _validate_not_a_forked_repo(repo_owner, repo_name, ref) + + cached_file = os.path.join(hub_dir, normalized_br + ".zip") + _remove_if_exists(cached_file) + + try: + url = _git_archive_link(repo_owner, repo_name, ref) + sys.stdout.write(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, progress=False) + except HTTPError as err: + if err.code == 300: + # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch + # in the repo. 
This can be disambiguated by explicitely using refs/heads/ or refs/tags + # See https://git-scm.com/book/en/v2/Git-Internals-Git-References + # Here, we do the same as git: we throw a warning, and assume the user wanted the branch + warnings.warn( + f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? " + "Torchhub will now assume that it's a branch. " + "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or " + "refs/tags/tag_name as the ref. That might require using skip_validation=True." + ) + disambiguated_branch_ref = f"refs/heads/{ref}" + url = _git_archive_link( + repo_owner, repo_name, ref=disambiguated_branch_ref + ) + download_url_to_file(url, cached_file, progress=False) + else: + raise + + with zipfile.ZipFile(cached_file) as cached_zipfile: + extraced_repo_name = cached_zipfile.infolist()[0].filename + extracted_repo = os.path.join(hub_dir, extraced_repo_name) + _remove_if_exists(extracted_repo) + # Unzip the code and rename the base folder + cached_zipfile.extractall(hub_dir) + + _remove_if_exists(cached_file) + _remove_if_exists(repo_dir) + shutil.move(extracted_repo, repo_dir) # rename the repo + + return repo_dir + + +def _check_repo_is_trusted( + repo_owner, + repo_name, + owner_name_branch, + trust_repo, + calling_fn="load", +): + hub_dir = get_dir() + filepath = os.path.join(hub_dir, "trusted_list") + + if not os.path.exists(filepath): + Path(filepath).touch() + with open(filepath) as file: + trusted_repos = tuple(line.strip() for line in file) + + # To minimize friction of introducing the new trust_repo mechanism, we consider that + # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist) + trusted_repos_legacy = next(os.walk(hub_dir))[1] + + owner_name = "_".join([repo_owner, repo_name]) + is_trusted = ( + owner_name in trusted_repos + or owner_name_branch in trusted_repos_legacy + or repo_owner in _TRUSTED_REPO_OWNERS + ) + + # TODO: Remove `None` option in 2.0 and change the default to "check" + if trust_repo is None: + if not is_trusted: + warnings.warn( + "You are about to download and run code from an untrusted repository. In a future release, this won't " + "be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., " + "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, " + f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with " + f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for " + f"confirmation if the repo is not already trusted. This will eventually be the default behaviour" + ) + return + + if (trust_repo is False) or (trust_repo == "check" and not is_trusted): + response = input( + f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. " + "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?" 
+ ) + if response.lower() in ("y", "yes"): + if is_trusted: + print("The repository is already trusted.") + elif response.lower() in ("n", "no", ""): + raise Exception("Untrusted repository.") # noqa: TRY002 + else: + raise ValueError(f"Unrecognized response {response}.") + + # At this point we're sure that the user trusts the repo (or wants to trust it) + if not is_trusted: + with open(filepath, "a") as file: + file.write(owner_name + "\n") + + +def _check_module_exists(name): + import importlib.util + + return importlib.util.find_spec(name) is not None + + +def _check_dependencies(m): + dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) + + if dependencies is not None: + missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] + if len(missing_deps): + raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}") + + +def _load_entry_from_hubconf(m, model): + if not isinstance(model, str): + raise ValueError("Invalid input: model should be a string of function name") + + # Note that if a missing dependency is imported at top level of hubconf, it will + # throw before this function. It's a chicken and egg situation where we have to + # load hubconf to know what're the dependencies, but to import hubconf it requires + # a missing package. This is fine, Python will throw proper error message for users. + _check_dependencies(m) + + func = _load_attr_from_module(m, model) + + if func is None or not callable(func): + raise RuntimeError(f"Cannot find callable {model} in hubconf") + + return func + + +def get_dir() -> str: + r""" + Get the Torch Hub cache directory used for storing downloaded models & weights. + + If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where + environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. + ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux + filesystem layout, with a default value ``~/.cache`` if the environment + variable is not set. + """ + # Issue warning to move data if old env is set + if os.getenv("TORCH_HUB"): + warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead") + + if _hub_dir is not None: + return _hub_dir + return os.path.join(_get_torch_home(), "hub") + + +def set_dir(d: Union[str, os.PathLike]) -> None: + r""" + Optionally set the Torch Hub directory used to save downloaded models & weights. + + Args: + d (str): path to a local folder to save downloaded models & weights. + """ + global _hub_dir + _hub_dir = os.path.expanduser(d) + + +def list( + github, + force_reload=False, + skip_validation=False, + trust_repo=None, + verbose=True, +): + r""" + List all callable entrypoints available in the repo specified by ``github``. + + Args: + github (str): a string with format "repo_owner/repo_name[:ref]" with an optional + ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if + it exists, and otherwise ``master``. + Example: 'pytorch/vision:0.10' + force_reload (bool, optional): whether to discard the existing cache and force a fresh download. + Default is ``False``. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. 
+ This parameter was introduced in v1.12 and helps ensuring that users + only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v2.0. + + Default is ``None`` and will eventually change to ``"check"`` in v2.0. + verbose (bool, optional): If ``False``, mute messages about hitting + local caches. Note that the message about first download cannot be + muted. Default is ``True``. + + Returns: + list: The available callables entrypoint + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True) + """ + repo_dir = _get_cache_or_reload( + github, + force_reload, + trust_repo, + "list", + verbose=verbose, + skip_validation=skip_validation, + ) + + with _add_to_sys_path(repo_dir): + hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) + hub_module = _import_module(MODULE_HUBCONF, hubconf_path) + + # We take functions starts with '_' as internal helper functions + entrypoints = [ + f + for f in dir(hub_module) + if callable(getattr(hub_module, f)) and not f.startswith("_") + ] + + return entrypoints + + +def help(github, model, force_reload=False, skip_validation=False, trust_repo=None): + r""" + Show the docstring of entrypoint ``model``. + + Args: + github (str): a string with format with an optional + ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed + to be ``main`` if it exists, and otherwise ``master``. + Example: 'pytorch/vision:0.10' + model (str): a string of entrypoint name defined in repo's ``hubconf.py`` + force_reload (bool, optional): whether to discard the existing cache and force a fresh download. + Default is ``False``. + skip_validation (bool, optional): if ``False``, torchhub will check that the ref + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. + This parameter was introduced in v1.12 and helps ensuring that users + only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v2.0. + + Default is ``None`` and will eventually change to ``"check"`` in v2.0. 
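# (Editor's note: illustrative only, not part of the upstream docstring.)
# With trust_repo="check" the confirmation prompt only appears when the repo
# is not already on the local trusted list (owners such as "pytorch" are
# trusted by default), e.g.:
#
#     torch.hub.help("pytorch/vision", "resnet18", trust_repo="check")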
+ Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True)) + """ + repo_dir = _get_cache_or_reload( + github, + force_reload, + trust_repo, + "help", + verbose=True, + skip_validation=skip_validation, + ) + + with _add_to_sys_path(repo_dir): + hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) + hub_module = _import_module(MODULE_HUBCONF, hubconf_path) + + entry = _load_entry_from_hubconf(hub_module, model) + + return entry.__doc__ + + +def load( + repo_or_dir, + model, + *args, + source="github", + trust_repo=None, + force_reload=False, + verbose=True, + skip_validation=False, + **kwargs, +): + r""" + Load a model from a github repo or a local directory. + + Note: Loading a model is the typical use case, but this can also be used to + for loading other objects such as tokenizers, loss functions, etc. + + If ``source`` is 'github', ``repo_or_dir`` is expected to be + of the form ``repo_owner/repo_name[:ref]`` with an optional + ref (a tag or a branch). + + If ``source`` is 'local', ``repo_or_dir`` is expected to be a + path to a local directory. + + Args: + repo_or_dir (str): If ``source`` is 'github', + this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with + an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified, + the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. + If ``source`` is 'local' then it should be a path to a local directory. + model (str): the name of a callable (entrypoint) defined in the + repo/dir's ``hubconf.py``. + *args (optional): the corresponding args for callable ``model``. + source (str, optional): 'github' or 'local'. Specifies how + ``repo_or_dir`` is to be interpreted. Default is 'github'. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. + This parameter was introduced in v1.12 and helps ensuring that users + only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v2.0. + + Default is ``None`` and will eventually change to ``"check"`` in v2.0. + force_reload (bool, optional): whether to force a fresh download of + the github repo unconditionally. Does not have any effect if + ``source = 'local'``. Default is ``False``. + verbose (bool, optional): If ``False``, mute messages about hitting + local caches. Note that the message about first download cannot be + muted. Does not have any effect if ``source = 'local'``. + Default is ``True``. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + **kwargs (optional): the corresponding kwargs for callable ``model``. 
+ + Returns: + The output of the ``model`` callable when called with the given + ``*args`` and ``**kwargs``. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> # from a github repo + >>> repo = "pytorch/vision" + >>> model = torch.hub.load( + ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" + ... ) + >>> # from a local directory + >>> path = "/some/local/path/pytorch/vision" + >>> # xdoctest: +SKIP + >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT") + """ + source = source.lower() + + if source not in ("github", "local"): + raise ValueError( + f'Unknown source: "{source}". Allowed values: "github" | "local".' + ) + + if source == "github": + repo_or_dir = _get_cache_or_reload( + repo_or_dir, + force_reload, + trust_repo, + "load", + verbose=verbose, + skip_validation=skip_validation, + ) + + model = _load_local(repo_or_dir, model, *args, **kwargs) + return model + + +def _load_local(hubconf_dir, model, *args, **kwargs): + r""" + Load a model from a local directory with a ``hubconf.py``. + + Args: + hubconf_dir (str): path to a local directory that contains a + ``hubconf.py``. + model (str): name of an entrypoint defined in the directory's + ``hubconf.py``. + *args (optional): the corresponding args for callable ``model``. + **kwargs (optional): the corresponding kwargs for callable ``model``. + + Returns: + a single model with corresponding pretrained weights. + + Example: + >>> # xdoctest: +SKIP("stub local path") + >>> path = "/some/local/path/pytorch/vision" + >>> model = _load_local( + ... path, + ... "resnet50", + ... weights="ResNet50_Weights.IMAGENET1K_V1", + ... ) + """ + with _add_to_sys_path(hubconf_dir): + hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) + hub_module = _import_module(MODULE_HUBCONF, hubconf_path) + + entry = _load_entry_from_hubconf(hub_module, model) + model = entry(*args, **kwargs) + + return model + + +def download_url_to_file( + url: str, + dst: str, + hash_prefix: Optional[str] = None, + progress: bool = True, +) -> None: + r"""Download object at the given URL to a local path. + + Args: + url (str): URL of the object to download + dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` + hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. + Default: None + progress (bool, optional): whether or not to display a progress bar to stderr + Default: True + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> # xdoctest: +REQUIRES(POSIX) + >>> torch.hub.download_url_to_file( + ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", + ... "/tmp/temporary_file", + ... ) + + """ + file_size = None + req = Request(url, headers={"User-Agent": "torch.hub"}) + u = urlopen(req) + meta = u.info() + if hasattr(meta, "getheaders"): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after + # download is complete. This prevents a local working checkpoint + # being overridden by a broken download. + # We deliberately do not use NamedTemporaryFile to avoid restrictive + # file permissions being applied to the downloaded file. + dst = os.path.expanduser(dst) + for _ in range(tempfile.TMP_MAX): + tmp_dst = dst + "." 
+ uuid.uuid4().hex + ".partial" + try: + f = open(tmp_dst, "w+b") + except FileExistsError: + continue + break + else: + raise FileExistsError(errno.EEXIST, "No usable temporary file name found") + + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with tqdm( + total=file_size, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + while True: + buffer = u.read(READ_DATA_CHUNK) + if len(buffer) == 0: + break + f.write(buffer) # type: ignore[possibly-undefined] + if hash_prefix is not None: + sha256.update(buffer) # type: ignore[possibly-undefined] + pbar.update(len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() # type: ignore[possibly-undefined] + if digest[: len(hash_prefix)] != hash_prefix: + raise RuntimeError( + f'invalid hash value (expected "{hash_prefix}", got "{digest}")' + ) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +# Hub used to support automatically extracts from zipfile manually compressed by users. +# The legacy zip format expects only one file from torch.save() < 1.6 in the zip. +# We should remove this support since zipfile is now default zipfile format for torch.save(). +def _is_legacy_zip_format(filename: str) -> bool: + if zipfile.is_zipfile(filename): + infolist = zipfile.ZipFile(filename).infolist() + return len(infolist) == 1 and not infolist[0].is_dir() + return False + + +@deprecated( + "Falling back to the old format < 1.6. This support will be " + "deprecated in favor of default zipfile format introduced in 1.6. " + "Please redo torch.save() to save it in the new zipfile format.", + category=FutureWarning, +) +def _legacy_zip_load( + filename: str, + model_dir: str, + map_location: MAP_LOCATION, + weights_only: bool, +) -> dict[str, Any]: + # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. + # We deliberately don't handle tarfile here since our legacy serialization format was in tar. + # E.g. resnet18-5c106cde.pth which is widely used. + with zipfile.ZipFile(filename) as f: + members = f.infolist() + if len(members) != 1: + raise RuntimeError("Only one file(not dir) is allowed in the zipfile") + f.extractall(model_dir) + extraced_name = members[0].filename + extracted_file = os.path.join(model_dir, extraced_name) + return torch.load( + extracted_file, map_location=map_location, weights_only=weights_only + ) + + +def load_state_dict_from_url( + url: str, + model_dir: Optional[str] = None, + map_location: MAP_LOCATION = None, + progress: bool = True, + check_hash: bool = False, + file_name: Optional[str] = None, + weights_only: bool = False, +) -> dict[str, Any]: + r"""Loads the Torch serialized object at the given URL. + + If downloaded file is a zip file, it will be automatically + decompressed. + + If the object is already present in `model_dir`, it's deserialized and + returned. + The default value of ``model_dir`` is ``/checkpoints`` where + ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + url (str): URL of the object to download + model_dir (str, optional): directory in which to save the object + map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) + progress (bool, optional): whether or not to display a progress bar to stderr. 
+ Default: True + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + Default: False + file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. + weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects. + Recommended for untrusted sources. See :func:`~torch.load` for more details. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) + >>> state_dict = torch.hub.load_state_dict_from_url( + ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" + ... ) + + """ + # Issue warning to move data if old env is set + if os.getenv("TORCH_MODEL_ZOO"): + warnings.warn( + "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead" + ) + + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stdout.write(f'Downloading: "{url}" to {cached_file}\n') + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + + if _is_legacy_zip_format(cached_file): + return _legacy_zip_load(cached_file, model_dir, map_location, weights_only) + return torch.load(cached_file, map_location=map_location, weights_only=weights_only) diff --git a/phivenv/Lib/site-packages/torch/library.py b/phivenv/Lib/site-packages/torch/library.py new file mode 100644 index 0000000000000000000000000000000000000000..914406e94afc9bddb411a430e7200017b38f9dd3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/library.py @@ -0,0 +1,1620 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import inspect +import re +import sys +import traceback +import weakref +from collections.abc import Sequence +from typing import ( + Any, + Callable, + Literal, + Optional, + overload, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import deprecated, ParamSpec + +import torch +import torch._library as _library +from torch._library.custom_ops import ( + _cast, + _maybe_get_opdef, + custom_op, + CustomOpDef, + device_types_t, +) +from torch._library.infer_schema import infer_schema # noqa: F401 +from torch._library.triton import triton_op, wrap_triton +from torch._ops import OpOverload +from torch.types import _dtype + + +__all__ = [ + "Library", + "impl", + "define", + "fallthrough_kernel", + "impl_abstract", + "register_autocast", + "register_fake", + "register_torch_dispatch", + "register_vmap", + "get_ctx", + "custom_op", + "triton_op", + "wrap_triton", + "infer_schema", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered +# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`. +# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid +# libraries calling into kernels not intended to be called. 
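# (Editor's note: illustrative example of the key format described above;
# my_sin_cpu is a hypothetical kernel function.) Registering a CPU kernel for
# "mylib::sin" via
#     Library("mylib", "IMPL").impl("sin", my_sin_cpu, "CPU")
# records the key "mylib/sin/CPU" in _impls; a second registration for the
# same key raises a RuntimeError unless allow_override=True is passed.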
+_impls: set[str] = set() +_defs: set[str] = set() + +# prim is reserved by TorchScript interpreter +_reserved_namespaces = ["prim"] + + +def fallthrough_kernel(): + """ + A dummy function to pass to ``Library.impl`` in order to register a fallthrough. + """ + raise NotImplementedError("fallthrough_kernel() should never be called.") + + +class Library: + """ + A class to create libraries that can be used to register new operators or + override operators in existing libraries from Python. + A user can optionally pass in a dispatch keyname if they only want to register + kernels corresponding to only one specific dispatch key. + + To create a library to override operators in an existing library (with name ns), set the kind to "IMPL". + To create a new library (with name ns) to register new operators, set the kind to "DEF". + To create a fragment of a possibly existing library to register operators (and bypass + the limitation that there is only one library for a given namespace), set the kind to + "FRAGMENT". + + Args: + ns: library name + kind: "DEF", "IMPL", "FRAGMENT" + dispatch_key: PyTorch dispatch key (default: "") + """ + + def __init__(self, ns, kind, dispatch_key=""): + from torch.fx.operator_schemas import _SCHEMA_TO_SIGNATURE_CACHE + + if kind not in ("IMPL", "DEF", "FRAGMENT"): + raise ValueError("Unsupported kind: ", kind) + + if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"): + raise ValueError( + ns, + " is a reserved namespace. Please try creating a library with another name.", + ) + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + frame = traceback.extract_stack(limit=3)[0] + filename, lineno = frame.filename, frame.lineno + self.m: Optional[Any] = torch._C._dispatch_library( + kind, ns, dispatch_key, filename, lineno + ) + self.ns = ns + self._op_defs: set[str] = set() + self._op_impls: set[str] = set() + self._registration_handles: list[torch._library.utils.RegistrationHandle] = [] + self.kind = kind + self.dispatch_key = dispatch_key + # Use a finalizer to setup the "destructor" instead of __del__. + # Python __del__ can lead to weird things (globals and locals may already + # be gone when __del__ actually gets called!). finalizers help the + # situation because it lets us capture references and keeps them alive + weakref.finalize( + self, + _del_library, + _impls, + self._op_impls, + _defs, + self._op_defs, + self._registration_handles, + self.m, + _SCHEMA_TO_SIGNATURE_CACHE, + ) + + def __repr__(self): + return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>" + + def define(self, schema, alias_analysis="", *, tags=()): + r"""Defines a new operator and its semantics in the ns namespace. + + Args: + schema: function schema to define a new operator. + alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be + inferred from the schema (default behavior) or not ("CONSERVATIVE"). + tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this + operator. Tagging an operator changes the operator's behavior + under various PyTorch subsystems; please read the docs for the + torch.Tag carefully before applying it. + + Returns: + name of the operator as inferred from the schema. 
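# (Editor's note: illustrative addition, not part of the upstream docstring;
# "mysin" is a made-up operator name.) tags accepts a single torch.Tag or a
# tuple of them, e.g.:
#
#     my_lib.define("mysin(Tensor x) -> Tensor", tags=torch.Tag.pt2_compliant_tag)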
+ + Example:: + + >>> my_lib = Library("mylib", "DEF") + >>> my_lib.define("sum(Tensor self) -> Tensor") + """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid + # AliasAnalysis type in C++ + if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: + raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}") + assert self.m is not None + if isinstance(tags, torch.Tag): + tags = (tags,) + + name = schema.split("(")[0] + packet_name = name.split(".")[0] if "." in name else name + has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr( + getattr(torch.ops, self.ns), packet_name + ) + + result = self.m.define(schema, alias_analysis, tuple(tags)) + name = schema.split("(")[0] + qualname = self.ns + "::" + name + + # If the OpOverloadPacket exists already, then this means we're adding a + # new OpOverload for it. Refresh the packet to include the new OpOverload. + if has_preexisting_packet: + ns = getattr(torch.ops, self.ns) + packet = getattr(ns, packet_name) + torch._ops._refresh_packet(packet) + + self._op_defs.add(qualname) + _defs.add(qualname) + return result + + def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): + r"""Registers the fake impl for an operator defined in the library.""" + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + source = torch._library.utils.get_source(_stacklevel + 1) + frame = sys._getframe(_stacklevel) + caller_module = inspect.getmodule(frame) + # Can be none if you call register_fake from somewhere there isn't a module + # (e.g. __main__) + caller_module_name = None if caller_module is None else caller_module.__name__ + + # TODO(rzou): We're gonna need to stage this change with torchvision, + # since torchvision is github first. + if caller_module_name is not None and caller_module_name.startswith( + "torchvision." + ): + caller_module_name = None + + qualname = f"{self.ns}::{op_name}" + entry = torch._library.simple_registry.singleton.find(qualname) + if caller_module_name is not None: + func_to_register = _check_pystubs_once(fn, qualname, caller_module_name) + else: + func_to_register = fn + + handle = entry.fake_impl.register( + func_to_register, source, lib=self, allow_override=allow_override + ) + self._registration_handles.append(handle) + + def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): + r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class. + + This allows for open registration to specify the behavior between the operator + and the torch_dispatch_class without needing to modify the torch_dispatch_class + or the operator directly. + + The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a + TorchDispatchMode. 
+ + If it is a Tensor subclass, we expect fn to have the following signature: + (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any + + If it is a TorchDispatchMode, we expect fn to have the following signature: + (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any + """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + qualname = f"{self.ns}::{op_name}" + entry = torch._library.simple_registry.singleton.find(qualname) + handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn) + self._registration_handles.append(handle) + + def _impl_with_aoti_compile(self, op_name, dispatch_key=""): + r"""Register the operator to use the AOTI-compiled implementation. + + Args: + op_name: operator name (along with the overload) or OpOverload object. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + + Example:: + + >>> my_lib = Library("aten", "IMPL") + >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") + """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + if dispatch_key == "": + dispatch_key = self.dispatch_key + assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense) + + if isinstance(op_name, str): + name = op_name + elif isinstance(op_name, OpOverload): + name = op_name._schema.name + overload_name = op_name._schema.overload_name + if overload_name != "": + name = name + "." + overload_name + else: + raise RuntimeError( + "_impl_with_aoti_compile should be passed either a name or an OpOverload object " + "as the first argument" + ) + + key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key + if key in _impls: + # TODO: in future, add more info about where the existing function is registered (this info is + # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that) + raise RuntimeError( + "This is not allowed since there's already a kernel registered from python overriding {}" + "'s behavior for {} dispatch key and {} namespace.".format( + name.split("::")[-1], dispatch_key, self.ns + ) + ) + + assert self.m is not None + impl_fn: Callable = self.m.impl_with_aoti_compile + impl_fn(self.ns, name.split("::")[-1], dispatch_key) + + _impls.add(key) + self._op_impls.add(key) + + def impl( + self, op_name, fn, dispatch_key="", *, with_keyset=False, allow_override=False + ): + r"""Registers the function implementation for an operator defined in the library. + + Args: + op_name: operator name (along with the overload) or OpOverload object. + fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel` + to register a fallthrough. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument + to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. + allow_override: Flag controlling if we want to override an + existing registered kernel implementation. This is by + default off, and will error you're trying to register a + kernel to a dispatch key with a kernel already + registered. 
+ + Example:: + + >>> my_lib = Library("aten", "IMPL") + >>> def div_cpu(self, other): + >>> return self * (1 / other) + >>> my_lib.impl("div.Tensor", div_cpu, "CPU") + """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + if not callable(fn): + raise TypeError( + f"Input function is required to be a callable but found type {type(fn)}" + ) + if dispatch_key == "": + dispatch_key = self.dispatch_key + + if isinstance(op_name, str): + name = op_name + elif isinstance(op_name, OpOverload): + name = op_name._schema.name + overload_name = op_name._schema.overload_name + if overload_name != "": + name = name + "." + overload_name + else: + raise RuntimeError( + "impl should be passed either a name or an OpOverload object as the first argument" + ) + + key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key + if (not allow_override) and key in _impls: + # TODO: in future, add more info about where the existing function is registered (this info is + # today already returned by the C++ warning when impl is called but we error out before that) + raise RuntimeError( + "This is not allowed since there's already a kernel registered from python overriding {}" + "'s behavior for {} dispatch key and {} namespace.".format( + name.split("::")[-1], dispatch_key, self.ns + ) + ) + + if dispatch_key == "Meta": + dispatcher_op_name = name + if "::" not in dispatcher_op_name: + dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}" + + # Internally, we shouldn't be registering meta kernels for any operators that + # have CompositeImplicitAutograd kernels. + # Instead, we should be letting those decompositions run, and writing meta kernels + # only for the base operators. + if torch._C._dispatch_has_kernel_for_dispatch_key( + dispatcher_op_name, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"We should not register a meta kernel directly to the operator '{name}'," + " because it has a CompositeImplicitAutograd kernel in core." + " Instead we should let the operator decompose, and ensure that we have meta kernels" + " for the base ops that it decomposes into." + ) + + assert self.m is not None + self.m.impl( + name, + dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", + fn, + with_keyset, + ) + + _impls.add(key) + self._op_impls.add(key) + + def fallback(self, fn, dispatch_key="", *, with_keyset=False): + r"""Registers the function implementation as the fallback for the given key. + + This function only works for a library with global namespace ("_"). + + Args: + fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel` + to register a fallthrough. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument + to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. + + Example:: + + >>> my_lib = Library("_", "IMPL") + >>> def fallback_kernel(op, *args, **kwargs): + >>> # Handle all autocast ops generically + >>> # ... 
+ >>> my_lib.fallback(fallback_kernel, "Autocast") + """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + if dispatch_key == "": + dispatch_key = self.dispatch_key + + if self.ns != "_": + raise RuntimeError( + f"""Fallback can only be registered using libary fragment on the global namespace "_" but it is {self.ns}""" + ) + + assert dispatch_key != "" + assert self.m is not None + + self.m.fallback(dispatch_key, fn, with_keyset) + + def _destroy(self): + if self.m is not None: + self.m.reset() + self.m = None + for handle in self._registration_handles: + handle.destroy() + self._registration_handles.clear() + global _impls + _impls -= self._op_impls + for name in self._op_defs: + # Delete the cached torch.ops.ns.foo if it was registered. + # Otherwise, accessing it leads to a segfault. + # It's possible that we only registered an overload in this Library + # and another library owns an alive overload. + # That's OK - the next time torch.ops.ns.foo gets called, it'll be + # recomputed to point at the right collection of overloads. + ns, name_with_overload = name.split("::") + name = name_with_overload.split(".")[0] + if not hasattr(torch.ops, ns): + continue + namespace = getattr(torch.ops, ns) + if not hasattr(namespace, name): + continue + delattr(namespace, name) + namespace._dir.remove(name) + + +def _del_library( + captured_impls, + op_impls, + captured_defs, + op_defs, + registration_handles, + m, + schema_to_signature_cache, +): + for op_def in op_defs: + name = op_def + overload_name = "" + if "." in op_def: + name, overload_name = op_def.split(".") + if ( + name, + overload_name, + ) in schema_to_signature_cache: + del schema_to_signature_cache[(name, overload_name)] + + captured_impls -= op_impls + captured_defs -= op_defs + for handle in registration_handles: + handle.destroy() + + if m is not None: + m.reset() + + +@contextlib.contextmanager +def _scoped_library(*args, **kwargs): + try: + lib = Library(*args, **kwargs) + yield lib + finally: + lib._destroy() + + +_keep_alive: list[Library] = [] + + +NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*") + + +@functools.singledispatch +def define(qualname, schema, *, lib=None, tags=()): + r"""Defines a new operator. + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define the op (by providing an operator name and schema) + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the custom operator (the first step) + you must then perform the second step by calling various + ``impl_*`` APIs, like :func:`torch.library.impl` or + :func:`torch.library.register_fake`. + + Args: + qualname (str): The qualified name for the operator. Should be + a string that looks like "namespace::name", e.g. "aten::sin". + Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. + schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor" + for an op that accepts one Tensor and returns one Tensor. It does + not contain the operator name (that is passed in ``qualname``). + lib (Optional[Library]): If provided, the lifetime of this operator + will be tied to the lifetime of the Library object. + tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this + operator. 
Tagging an operator changes the operator's behavior + under various PyTorch subsystems; please read the docs for the + torch.Tag carefully before applying it. + + Example:: + >>> import torch + >>> import numpy as np + >>> + >>> # Define the operator + >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") + >>> + >>> # Add implementations for the operator + >>> @torch.library.impl("mylib::sin", "cpu") + >>> def f(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Call the new operator from torch.ops. + >>> x = torch.randn(3) + >>> y = torch.ops.mylib.sin(x) + >>> assert torch.allclose(y, x.sin()) + + """ + if not isinstance(qualname, str): + raise ValueError( + f"define(qualname, schema): expected qualname " + f"to be instance of str, got {type(qualname)}" + ) + namespace, name = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + if not NAMELESS_SCHEMA.fullmatch(schema): + raise ValueError( + f"define(qualname, schema, ...): expected schema " + f'to look like e.g. "(Tensor x) -> Tensor" but ' + f'got "{schema}"' + ) + lib.define(name + schema, alias_analysis="", tags=tags) + + +@define.register +def _(lib: Library, schema, alias_analysis=""): + """The old torch.library.define. + We're keeping this around for BC reasons + """ + + def wrap(f): + name = lib.define(schema, alias_analysis) + lib.impl(name, f) + return f + + return wrap + + +@overload +def impl( + qualname: str, + types: Union[str, Sequence[str]], + func: Literal[None] = None, + *, + lib: Optional[Library] = None, +) -> Callable[[Callable[..., object]], None]: ... + + +@overload +def impl( + qualname: str, + types: Union[str, Sequence[str]], + func: Callable[..., object], + *, + lib: Optional[Library] = None, +) -> None: ... + + +# Deprecated BC API +@overload +def impl( + lib: Library, + name: str, + dispatch_key: str = "", +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + +@functools.singledispatch +def impl( + qualname: str, + types: Union[str, Sequence[str]], + func: Optional[Callable[_P, _T]] = None, + *, + lib: Optional[Library] = None, +) -> object: + """Register an implementation for a device type for this operator. + + You may pass "default" for ``types`` to register this implementation as the + default implementation for ALL device types. + Please only use this if the implementation truly supports all device types; + for example, this is true if it is a composition of built-in PyTorch operators. + + This API may be used as a decorator. You can use nested decorators + with this API provided they return a function and are placed inside + this API (see Example 2). + + Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + + Args: + qualname (str): Should be a string that looks like "namespace::operator_name". + types (str | Sequence[str]): The device types to register an impl to. + lib (Optional[Library]): If provided, the lifetime of this registration + will be tied to the lifetime of the Library object. + + Examples: + >>> import torch + >>> import numpy as np + >>> # Example 1: Register function. 
+ >>> # Define the operator + >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") + >>> + >>> # Add implementations for the cpu device + >>> @torch.library.impl("mylib::mysin", "cpu") + >>> def f(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> x = torch.randn(3) + >>> y = torch.ops.mylib.mysin(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example 2: Register function with decorator. + >>> def custom_decorator(func): + >>> def wrapper(*args, **kwargs): + >>> return func(*args, **kwargs) + 1 + >>> return wrapper + >>> + >>> # Define the operator + >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor") + >>> + >>> # Add implementations for the operator + >>> @torch.library.impl("mylib::sin_plus_one", "cpu") + >>> @custom_decorator + >>> def f(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Call the new operator from torch.ops. + >>> x = torch.randn(3) + >>> + >>> y1 = torch.ops.mylib.sin_plus_one(x) + >>> y2 = torch.sin(x) + 1 + >>> assert torch.allclose(y1, y2) + """ + return _impl(qualname, types, func, lib=lib, disable_dynamo=False) + + +if not TYPE_CHECKING: + + @impl.register + def _( + lib: Library, name: str, dispatch_key: str = "" + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Legacy torch.library.impl API. Kept around for BC""" + + def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]: + lib.impl(name, f, dispatch_key) + return f + + return wrap + + +@overload +def _impl( + qualname: str, + types: Union[str, Sequence[str]], + func: Literal[None] = None, + *, + lib: Optional[Library] = None, + disable_dynamo: bool = False, +) -> Callable[[Callable[..., object]], None]: ... + + +@overload +def _impl( + qualname: str, + types: Union[str, Sequence[str]], + func: Callable[..., object], + *, + lib: Optional[Library] = None, + disable_dynamo: bool = False, +) -> None: ... + + +def _impl( + qualname: str, + types: Union[str, Sequence[str]], + func: Optional[Callable[..., object]] = None, + *, + lib: Optional[Library] = None, + disable_dynamo: bool = False, +) -> Optional[Callable[[Callable[..., object]], None]]: + # See impl() + if isinstance(types, str): + types = (types,) + keys = set({}) + for typ in types: + is_dispatch_key = torch._C._parse_dispatch_key(typ) + if is_dispatch_key: + # We also support passing a DispatchKey to impl. Please prefer using + # the higher-level torch.library APIs and only pass DispatchKey to + # torch.library.impl with caution (or even better, don't use this + # option and file an issue on GitHub for what you need). + # We don't advertise this to users because + # it is very easy to shoot yourself in the foot. 
+ keys.add(typ) + else: + keys.add(_device_type_to_key(typ)) + + def register_(func: Callable[..., object]) -> None: + namespace, _ = torch._library.utils.parse_namespace(qualname) + + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + if disable_dynamo: + + @torch._disable_dynamo + def func_no_dynamo(*args, **kwargs): + return func(*args, **kwargs) + + for key in keys: + use_lib.impl(qualname, func_no_dynamo, key) + else: + for key in keys: + use_lib.impl(qualname, func, key) + + if func is None: + return register_ + else: + register_(func) + return None + + +def _device_type_to_key(device_type: str) -> str: + if device_type == "default": + # This is technically not correct, because although all device_type + # DispatchKeys are included in CompositeExplicitAutograd, + # not everything in CompositeExplicitAutograd is associated with a + # device_type. I don't really care that much about the difference. + return "CompositeExplicitAutograd" + return torch._C._dispatch_key_for_device(device_type) + + +@deprecated( + "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that " + "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.", + category=FutureWarning, +) +def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): + r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4. + Please use that instead. + """ + if func is not None: + _stacklevel = _stacklevel + 1 + return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel) + + +_op_identifier = Union[ + str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef" +] + + +def register_kernel( + op: _op_identifier, + device_types: device_types_t, + func: Optional[Callable] = None, + /, + *, + lib: Optional[Library] = None, +): + """Register an implementation for a device type for this operator. + + Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + This API may be used as a decorator. + + Args: + op (str | OpOverload): The operator to register an impl to. + device_types (None | str | Sequence[str]): The device_types to register an impl to. + If None, we will register to all device types -- please only use + this option if your implementation is truly device-type-agnostic. + func (Callable): The function to register as the implementation for + the given device types. 
+ lib (Optional[Library]): If provided, the lifetime of this registration + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> # Create a custom op that works on cpu + >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> # Add implementations for the cuda device + >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") + >>> def _(x): + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x_cpu = torch.randn(3) + >>> x_cuda = x_cpu.cuda() + >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) + >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) + + """ + + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_kernel({op}): got unexpected type for op: {type(op)}" + ) + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_kernel(device_types, func) + assert isinstance(op, str) + if device_types is None: + device_types = "CompositeExplicitAutograd" + + return _impl(op, device_types, func, lib=lib, disable_dynamo=True) + + +def register_autocast( + op: _op_identifier, + device_type: str, + cast_inputs: _dtype, + /, + *, + lib: Optional[Library] = None, +): + r"""Register an autocast dispatch rule for this custom op. + + Valid `device_type` include: "cpu" and "cuda". + + Args: + op (str | OpOverload): The operator to register an autocast dispatch rule to. + device_type(str): Device type to use. 'cuda' or 'cpu'. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region, + casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors + are not affected), then executes custom op with autocast disabled. 
+ lib (Optional[Library]): If provided, the lifetime of this registration + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> + >>> # Create a custom op that works on cuda + >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) + >>> def my_sin(x: Tensor) -> Tensor: + >>> return torch.sin(x) + >>> + >>> # Register autocast dispatch rule for the cuda device + >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) + >>> + >>> x = torch.randn(3, dtype=torch.float32, device="cuda") + >>> with torch.autocast("cuda", dtype=torch.float16): + >>> y = torch.ops.mylib.my_sin(x) + >>> assert y.dtype == torch.float16 + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_autocast({op}): got unexpected type for op: {type(op)}" + ) + if device_type not in ["cpu", "cuda"]: + raise ValueError(f"Unknown device type: {device_type}") + + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_autocast(device_type, cast_inputs) + + assert isinstance(op, str) + qualname = op + _op = torch._library.utils.lookup_op(qualname) + + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + + def kernel(_, *args, **kwargs): + assert len(kwargs) == 0, "Custom ops do not support kwargs yet." + autocast_keyset = torch._C.DispatchKeySet( + torch._C.DispatchKey.AutocastCPU + ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA) + with torch._C._ExcludeDispatchKeyGuard(autocast_keyset): + return _op(*_cast(args, device_type, cast_inputs)) + + if device_type == "cuda": + return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True) + else: + # device_type is "cpu" + return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True) + + +def register_fake( + op: _op_identifier, + func: Optional[Callable] = None, + /, + *, + lib: Optional[Library] = None, + _stacklevel: int = 1, + allow_override: bool = False, +): + r"""Register a FakeTensor implementation ("fake impl") for this operator. + + Also sometimes known as a "meta kernel", "abstract impl". + + An "FakeTensor implementation" specifies the behavior of this operator on + Tensors that carry no data ("FakeTensor"). Given some input Tensors with + certain properties (sizes/strides/storage_offset/device), it specifies + what the properties of the output Tensors are. + + The FakeTensor implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. To write a FakeTensor + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The FakeTensor implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html + + Args: + op_name: Operator name (along with the overload) or OpOverload object. + func: Fake tensor implementation. + lib (Optional[Library]): Library to register the fake tensor to. 
+ allow_override: Flag controlling if we want to override an + existing registered fake impl. This is by default off, + and will error you're trying to register a fake impl to + an operator that already has a fake impl. This also only + applies if the custom operator was not created via + torch.library.custom_op, as overriding and existing fake + impl is already allowed. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> raise NotImplementedError("Implementation goes here") + >>> + >>> @torch.library.register_fake("mylib::custom_linear") + >>> def _(x, weight, bias): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> with torch._subclasses.fake_tensor.FakeTensorMode(): + >>> x = torch.randn(2, 3) + >>> w = torch.randn(3, 3) + >>> b = torch.randn(3) + >>> y = torch.ops.mylib.custom_linear(x, w, b) + >>> + >>> assert y.shape == (2, 3) + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> x_np = x.numpy(force=True) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + >>> + >>> @torch.library.register_fake("mylib::custom_nonzero") + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an fake impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> + >>> x = torch.tensor([0, 1, 2, 3, 4, 0]) + >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) + >>> trace.print_readable() + >>> + >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x)) + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError(f"register_fake({op}): got unexpected type for op: {type(op)}") + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + if func is None: + return opdef.register_fake + else: + return opdef.register_fake(func) + assert isinstance(op, str) + + stacklevel = _stacklevel + + def register(func): + namespace, op_name = torch._library.utils.parse_namespace(op) + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + use_lib._register_fake( + op_name, func, _stacklevel=stacklevel + 1, allow_override=allow_override + ) + return func + + if func is None: + return register + else: + stacklevel += 1 + return register(func) + + +def register_autograd( + op: _op_identifier, + backward: Callable, + /, + *, + setup_context: Optional[Callable] = None, + lib=None, +) -> None: + r"""Register a backward formula for this custom op. 
+ + In order for an operator to work with autograd, you need to register + a backward formula: + 1. You must tell us how to compute gradients during the backward pass + by providing us a "backward" function. + 2. If you need any values from the forward to compute gradients, you can + use `setup_context` to save values for backward. + + ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``: + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + The ``ctx`` object is `the same ctx object `_ used by + :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the + same as :meth:`torch.autograd.Function.backward`. + + ``setup_context(ctx, inputs, output)`` runs during the forward pass. + Please save quantities needed for backward onto the ``ctx`` object via + either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` + or assigning them as attributes of ``ctx``. If your custom op has + kwarg-only arguments, we expect the signature of ``setup_context`` + to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. + + Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is, + they may not directly access :meth:`torch.Tensor.data_ptr` and they must + not depend on or mutate global state. If you need a non-traceable backward, + you can make it a separate custom_op that you call inside ``backward_fn``. + + If you need different autograd behavior on different devices, then we + recommend creating two different custom operators, one for each device + that needs different behavior, and switching between them at runtime. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, output) -> Tensor: + >>> x, = inputs + >>> ctx.save_for_backward(x) + >>> + >>> def backward(ctx, grad): + >>> x, = ctx.saved_tensors + >>> return grad * x.cos() + >>> + >>> torch.library.register_autograd( + ... "mylib::numpy_sin", backward, setup_context=setup_context + ... ) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_sin(x) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, x.cos()) + >>> + >>> # Example with a keyword-only arg + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = x_np * val + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: + >>> ctx.val = keyword_only_inputs["val"] + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.val + >>> + >>> torch.library.register_autograd( + ... "mylib::numpy_mul", backward, setup_context=setup_context + ... 
) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_mul(x, val=3.14) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_autograd({op}): got unexpected type for op: {type(op)}" + ) + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + opdef.register_autograd(backward, setup_context=setup_context) + return + + assert isinstance(op, str) + qualname = op + op = torch._library.utils.lookup_op(qualname) + schema = op._schema + if not _library.utils.is_functional_schema(schema): + raise RuntimeError( + f"Cannot register autograd formula for non-functional operator " + f"{op} with schema {schema}. Please create " + f"a functional operator and register an autograd formula for that." + ) + if _library.utils.has_kwarg_only_tensors(schema): + raise NotImplementedError( + f"register_autograd with kwarg-only Tensor args. In the original " + f"definition of the op, please make your tensors not kwarg-only. " + f"Got: {schema}" + ) + + info = _library.autograd.Info(backward, setup_context) + autograd_kernel = _library.autograd.make_autograd_impl(op, info) + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True) + + +def register_torch_dispatch( + op: _op_identifier, + torch_dispatch_class: Any, + func: Optional[Callable] = None, + /, + *, + lib: Optional[Library] = None, +): + r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. + + This allows for open registration to specify the behavior between the operator + and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` + or the operator directly. + + The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a + TorchDispatchMode. + + If it is a Tensor subclass, we expect ``func`` to have the following signature: + ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` + + If it is a TorchDispatchMode, we expect ``func`` to have the following signature: + ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` + + ``args`` and ``kwargs`` will have been normalized the same way they are + in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`). 
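+
+    For the Tensor subclass case, the rule receives the subclass itself as its
+    first argument. A minimal editor's sketch (``MySubclass`` stands for a
+    hypothetical Tensor subclass that defines ``__torch_dispatch__``)::
+
+        @torch.library.register_torch_dispatch("mylib::foo", MySubclass)
+        def _(cls, func, types, args, kwargs):
+            x, = args
+            return x + 1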
+ + Examples: + + >>> import torch + >>> + >>> @torch.library.custom_op("mylib::foo", mutates_args={}) + >>> def foo(x: torch.Tensor) -> torch.Tensor: + >>> return x.clone() + >>> + >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): + >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): + >>> return func(*args, **kwargs) + >>> + >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) + >>> def _(mode, func, types, args, kwargs): + >>> x, = args + >>> return x + 1 + >>> + >>> x = torch.randn(3) + >>> y = foo(x) + >>> assert torch.allclose(y, x) + >>> + >>> with MyMode(): + >>> y = foo(x) + >>> assert torch.allclose(y, x + 1) + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_torch_dispatch({op}): got unexpected type for op: {type(op)}" + ) + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_torch_dispatch(torch_dispatch_class, func) + assert isinstance(op, str) + + def register(func): + namespace, op_name = torch._library.utils.parse_namespace(op) + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func) + return func + + if func is None: + return register + else: + return register(func) + + +def register_vmap( + op: _op_identifier, + func: Optional[Callable] = None, + /, + *, + lib=None, +): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator (see examples). + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + We do not support kwarg-only Tensor args. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @torch.library.register_vmap("mylib::numpy_mul") + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + + .. note:: + The vmap function should aim to preserve the semantics of the entire custom operator. + That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``. + + If your custom operator has any custom behavior in the backward pass, please + keep this in mind. + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError(f"register_vmap({op}): got unexpected type for op: {type(op)}") + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_vmap(func) + assert isinstance(op, str) + qualname = op + op = torch._library.utils.lookup_op(qualname) + schema = op._schema + if _library.utils.has_kwarg_only_tensors(schema): + raise NotImplementedError( + f"register_vmap with kwarg-only Tensor args. In the original " + f"definition of the op, please make your tensors not kwarg-only. " + f"Got: {schema}" + ) + + def register(func): + nonlocal op, lib + + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, func, op, *args, **kwargs + ) + + lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True) + + if func is None: + return register + else: + return register(func) + + +# If the op was defined in C++, then we want to make sure there was an +# m.set_python_module(module, ...) call and that the module is the +# same as the module that called torch.library.register_fake. 
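+#
+# Editor's sketch (illustrative only; the module name "my_extension.ops" is
+# hypothetical): the C++ registration that satisfies this check looks roughly
+# like
+#
+#   TORCH_LIBRARY(mylib, m) {
+#       m.set_python_module("my_extension.ops");
+#       m.def("foo(Tensor x) -> Tensor");
+#   }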
+def _check_pystubs_once(func, qualname, actual_module_name): + checked = False + + def inner(*args, **kwargs): + nonlocal checked + if checked: + return func(*args, **kwargs) + + op = torch._library.utils.lookup_op(qualname) + if op._defined_in_python: + checked = True + return func(*args, **kwargs) + + maybe_pystub = torch._C._dispatch_pystub( + op._schema.name, op._schema.overload_name + ) + if maybe_pystub is None: + if torch._library.utils.requires_set_python_module(): + namespace = op.namespace + cpp_filename = op._handle.debug() + raise RuntimeError( + f"Operator '{qualname}' was defined in C++ and has a Python " + f"fake impl. In this situation, we require there to also be a " + f'companion C++ `m.set_python_module("{actual_module_name}")` ' + f"call, but we could not find one. Please add that to " + f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the " + f"operator was registered in ({cpp_filename})" + ) + else: + pystub_module = maybe_pystub[0] + if actual_module_name != pystub_module: + cpp_filename = op._handle.debug() + raise RuntimeError( + f"Operator '{qualname}' specified that its python fake impl " + f"is in the Python module '{pystub_module}' but it was actually found " + f"in '{actual_module_name}'. Please either move the fake impl " + f"or correct the m.set_python_module call ({cpp_filename})" + ) + checked = True + return func(*args, **kwargs) + + return inner + + +# NOTE [ctx inside the fake implementation] +# If a user has an operator with data-dependent output shape, then when writing +# a fake implementation they must query the current ctx and use methods on the +# ctx to construct a new unbacked symint. +# +# This is done via us setting the global_ctx_getter function every time a fake +# implementation is invoked. +def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": + """get_ctx() returns the current AbstractImplCtx object. + + Calling ``get_ctx()`` is only valid inside of an fake impl + (see :func:`torch.library.register_fake` for more usage details. + """ + return torch._library.fake_impl.global_ctx_getter() + + +_OPCHECK_DEFAULT_UTILS = ( + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", +) + + +def opcheck( + op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS, + raise_exception: bool = True, + atol=None, + rtol=None, +) -> dict[str, str]: + """Given an operator and some sample arguments, tests if the operator is + registered correctly. + + That is, when you use the torch.library/TORCH_LIBRARY APIs to create a + custom op, you specified metadata (e.g. mutability info) about the custom op + and these APIs require that the functions you pass them satisfy certain + properties (e.g. no data pointer access in the fake/meta/abstract kernel) + ``opcheck`` tests these metadata and properties. + + Concretely, we test the following: + + - test_schema: If the schema matches the implementation of + the operator. For example: if the schema specifies a Tensor is mutated, + then we check the implementation mutates the Tensor. If the schema + specifies that we return a new Tensor, then we check that the + implementation returns a new Tensor (instead of an existing one or + a view of an existing one). 
+ - test_autograd_registration: If the operator supports training + (autograd): we check that its autograd formula is registered via + torch.library.register_autograd or a manual registration to one + or more DispatchKey::Autograd keys. Any other DispatchKey-based + registrations may lead to undefined behavior. + - test_faketensor: If the operator has a FakeTensor kernel + (and if it is correct). The FakeTensor kernel is necessary ( + but not sufficient) for the operator to work with PyTorch compilation + APIs (torch.compile/export/FX). We check that a FakeTensor kernel + (also sometimes known as a meta kernel) was registered for the + operator and that it is correct. This test takes the result of + running the operator on real tensors and the result of running + the operator on FakeTensors and checks that they have the same + Tensor metadata (sizes/strides/dtype/device/etc). + - test_aot_dispatch_dynamic: If the operator has correct behavior + with PyTorch compilation APIs (torch.compile/export/FX). + This checks that the outputs (and gradients, if applicable) are the + same under eager-mode PyTorch and torch.compile. + This test is a superset of ``test_faketensor`` and is an e2e test; + other things it tests are that the operator supports + functionalization and that the backward pass (if it exists) also + supports FakeTensor and functionalization. + + For best results, please call ``opcheck`` multiple times with a + representative set of inputs. If your operator supports + autograd, please use ``opcheck`` with inputs with ``requires_grad = True``; + if your operator supports multiple devices (e.g. CPU and CUDA), please + use ``opcheck`` with inputs on all supported devices. + + Args: + op: The operator. Must either be a function decorated with + :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket + found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo) + args: The args to the operator + kwargs: The kwargs to the operator + test_utils: Tests that we should run. Default: all of them. + Example: ("test_schema", "test_faketensor") + raise_exception: If we should raise an exception on the first + error. If False, we will return a dict with information + on if each test passed or not. + rtol (Optional[float]): Relative tolerance for floating point comparisons. + If specified ``atol`` must also be specified. + If omitted, default values based on the ``dtype`` are selected + (see the table in :func:`torch.testing.assert_close`). + atol (Optional[float]): Absolute tolerance for floating point comparisons. + If specified ``rtol`` must also be specified. + If omitted, default values based on the ``dtype`` are selected + (see the table in :func:`torch.testing.assert_close`). + + .. warning:: + + opcheck and :func:`torch.autograd.gradcheck` test different things; + opcheck tests if your usage of torch.library APIs is correct while + :func:`torch.autograd.gradcheck` tests if your autograd formula is + mathematically correct. Use both to test custom ops that support + gradient computation. 
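+
+    .. note::
+        (Editor's addition, a sketch based on the arguments documented above.)
+        To run only some of the checks, pass a subset via ``test_utils``, e.g.
+        ``opcheck(op, args, test_utils=("test_schema", "test_faketensor"))``;
+        with ``raise_exception=False`` the per-test results are returned as a
+        dict instead of raising on the first failure.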
+ + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: float) -> Tensor: + >>> x_np = x.numpy(force=True) + >>> z_np = x_np * y + >>> return torch.from_numpy(z_np).to(x.device) + >>> + >>> @numpy_mul.register_fake + >>> def _(x, y): + >>> return torch.empty_like(x) + >>> + >>> def setup_context(ctx, inputs, output): + >>> y, = inputs + >>> ctx.y = y + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.y, None + >>> + >>> numpy_mul.register_autograd(backward, setup_context=setup_context) + >>> + >>> sample_inputs = [ + >>> (torch.randn(3), 3.14), + >>> (torch.randn(2, 3, device='cuda'), 2.718), + >>> (torch.randn(1, 10, requires_grad=True), 1.234), + >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), + >>> ] + >>> + >>> for args in sample_inputs: + >>> torch.library.opcheck(numpy_mul, args) + + """ + import torch.testing._internal.optests as optests + + return optests.opcheck( + op, + args, + kwargs, + test_utils=test_utils, + raise_exception=raise_exception, + rtol=rtol, + atol=atol, + ) diff --git a/phivenv/Lib/site-packages/torch/overrides.py b/phivenv/Lib/site-packages/torch/overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..6150f52d96af5e656cd19d7669e632f4f7b3eb9c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/overrides.py @@ -0,0 +1,2117 @@ +""" +Python implementation of ``__torch_function__`` + +While most of the torch API and handling for ``__torch_function__`` happens +at the C++ level, some of the torch API is written in Python so we need +python-level handling for ``__torch_function__`` overrides as well. The main +developer-facing functionality in this file are handle_torch_function and +has_torch_function. See torch/functional.py and test/test_overrides.py +for usage examples. + +Note +---- +heavily inspired by NumPy's ``__array_function__`` (see: +https://github.com/pytorch/pytorch/issues/24015 and +https://www.numpy.org/neps/nep-0018-array-function-protocol.html +) + +If changing this file in a way that can affect ``__torch_function__`` overhead, +please report the benchmarks in ``benchmarks/overrides_benchmark``. See the +instructions in the ``README.md`` in that directory. +""" + +import __future__ # noqa: F404 + +import collections +import contextlib +import functools +import types +import warnings +from collections.abc import Iterable +from functools import wraps +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch._C import ( + _add_docstr, + _get_function_stack_at, + _has_torch_function, + _has_torch_function_unary, + _has_torch_function_variadic, + _is_torch_function_mode_enabled, + _len_torch_function_stack, + _pop_torch_function_stack, + _push_on_torch_function_stack, +) + + +__all__ = [ + "get_ignored_functions", + "get_overridable_functions", + "get_testing_overrides", + "handle_torch_function", + "has_torch_function", + "resolve_name", + "is_tensor_like", + "is_tensor_method_or_property", + "wrap_torch_function", + "enable_reentrant_dispatch", +] + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def _disable_user_warnings( + func: Callable[_P, _R], + regex: str = ".*is deprecated, please use.*", + module: str = "torch", +) -> Callable[_P, _R]: + """ + Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the + given ``regex`` pattern. 
+ + Arguments + --------- + func : function + Function to disable the warnings for. + regex : str + A regex pattern compilable by ``re.compile``. This is used to match the ``UserWarning`` message. + module : str + The python module to which the filtering should be restricted. + + Returns + ------- + function + The wrapped function. + """ + + @wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=UserWarning, message=regex, module=module + ) + return func(*args, **kwargs) + + return wrapper + + +@functools.cache +@_disable_user_warnings +def get_ignored_functions() -> set[Callable]: + """ + Return public functions that cannot be overridden by ``__torch_function__``. + + Returns + ------- + set[Callable] + A tuple of functions that are publicly available in the torch API but cannot + be overridden with ``__torch_function__``. Mostly this is because none of the + arguments of these functions are tensors or tensor-likes. + + Examples + -------- + >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() + True + >>> torch.add in torch.overrides.get_ignored_functions() + False + """ + Tensor = torch.Tensor + return { + torch.typename, + torch.is_tensor, + torch.is_storage, + torch.set_default_tensor_type, + torch.set_default_device, + torch.get_default_device, + torch.set_rng_state, + torch.get_rng_state, + torch.manual_seed, + torch.initial_seed, + torch.seed, + torch.save, + torch.load, + torch.set_printoptions, + torch.fork, + torch.get_default_dtype, + torch.get_num_interop_threads, + torch.get_num_threads, + torch.init_num_threads, + torch.import_ir_module, + torch.import_ir_module_from_buffer, + torch.is_anomaly_enabled, + torch.is_anomaly_check_nan_enabled, + torch.is_grad_enabled, + torch.merge_type_from_type_comment, + torch.parse_ir, + torch.parse_schema, + torch.parse_type_comment, + torch.set_anomaly_enabled, + torch.set_flush_denormal, + torch.set_num_interop_threads, + torch.set_num_threads, + torch.wait, + torch.as_tensor, + torch.from_numpy, + torch.tensor, + torch.default_generator, + torch.has_cuda, + torch.has_cudnn, + torch.has_lapack, + torch.device, + torch.dtype, + torch.finfo, + torch.has_mkl, + torch.has_mps, + torch.has_mkldnn, + torch.has_openmp, + torch.iinfo, + torch.memory_format, + torch.qscheme, + torch.set_grad_enabled, + torch.no_grad, + torch.enable_grad, + torch.inference_mode, + torch.is_inference_mode_enabled, + torch.layout, + torch.align_tensors, + torch.arange, + torch.as_strided, + torch.bartlett_window, + torch.blackman_window, + torch.broadcast_shapes, + torch.can_cast, + torch.compile, + torch.cudnn_affine_grid_generator, + torch.cudnn_batch_norm, + torch.cudnn_convolution, + torch.cudnn_convolution_transpose, + torch.cudnn_convolution_relu, + torch.cudnn_convolution_add_relu, + torch.cudnn_grid_sampler, + torch.cudnn_is_acceptable, + torch.empty, + torch.empty_permuted, + torch.empty_strided, + torch.empty_quantized, + torch.export.export, + torch.export.load, + torch.export.register_dataclass, + torch.export.save, + torch.eye, + torch.fft.fftfreq, + torch.fft.rfftfreq, + torch.from_file, + torch.full, + torch.fill, + torch.hamming_window, + torch.hann_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.mkldnn_adaptive_avg_pool2d, + torch.mkldnn_convolution, + torch.mkldnn_max_pool2d, + torch.mkldnn_max_pool3d, + torch.mkldnn_linear_backward_weights, + torch.mkldnn_rnn_layer, + torch.normal, + torch.ones, + 
torch.promote_types, + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.range, + torch.result_type, + torch.scalar_tensor, + torch.sparse_coo_tensor, + torch.sparse_compressed_tensor, + torch.sparse_csr_tensor, + torch.sparse_csc_tensor, + torch.sparse_bsr_tensor, + torch.sparse_bsc_tensor, + torch.sym_constrain_range, + torch.sym_constrain_range_for_size, + torch.sym_fresh_size, + torch.tril_indices, + torch.triu_indices, + torch.vander, + torch.zeros, + torch._jit_internal.boolean_dispatch, + torch.nn.functional.assert_int_or_pair, + torch.nn.functional.upsample, + torch.nn.functional.upsample_bilinear, + torch.nn.functional.upsample_nearest, + torch.nn.functional.has_torch_function, + torch.nn.functional.has_torch_function_unary, + torch.nn.functional.has_torch_function_variadic, + torch.nn.functional.handle_torch_function, + torch.nn.functional.sigmoid, + torch.nn.functional.hardsigmoid, + torch.nn.functional.tanh, + torch.nn.functional._canonical_mask, + torch.nn.functional._none_or_dtype, + # Doesn't actually take or return tensor arguments + torch.nn.init.calculate_gain, + # These are deprecated; don't test them + torch.nn.init.uniform, + torch.nn.init.normal, + torch.nn.init.constant, + torch.nn.init.eye, + torch.nn.init.dirac, + torch.nn.init.xavier_uniform, + torch.nn.init.xavier_normal, + torch.nn.init.kaiming_uniform, + torch.nn.init.kaiming_normal, + torch.nn.init.orthogonal, + torch.nn.init.sparse, + torch.nested.to_padded_tensor, + has_torch_function, + handle_torch_function, + torch.set_autocast_enabled, + torch.is_autocast_enabled, + torch.set_autocast_dtype, + torch.get_autocast_dtype, + torch.clear_autocast_cache, + torch.set_autocast_cpu_enabled, + torch.is_autocast_cpu_enabled, + torch.set_autocast_xla_enabled, + torch.is_autocast_xla_enabled, + torch.set_autocast_ipu_enabled, + torch.is_autocast_ipu_enabled, + torch.set_autocast_cpu_dtype, + torch.get_autocast_cpu_dtype, + torch.set_autocast_ipu_dtype, + torch.get_autocast_ipu_dtype, + torch.get_autocast_gpu_dtype, + torch.set_autocast_gpu_dtype, + torch.get_autocast_xla_dtype, + torch.set_autocast_xla_dtype, + torch.autocast_increment_nesting, + torch.autocast_decrement_nesting, + torch.is_autocast_cache_enabled, + torch.set_autocast_cache_enabled, + torch.nn.functional.hardswish, + torch.is_vulkan_available, + torch.are_deterministic_algorithms_enabled, + torch.use_deterministic_algorithms, + torch.is_deterministic_algorithms_warn_only_enabled, + torch.set_deterministic_debug_mode, + torch.get_device_module, + torch.get_deterministic_debug_mode, + torch.set_float32_matmul_precision, + torch.get_float32_matmul_precision, + torch.unify_type_list, + torch.is_warn_always_enabled, + torch.set_warn_always, + torch.vitals_enabled, + torch.set_vital, + torch.read_vitals, + torch.vmap, + torch.cond, + torch.frombuffer, + torch.asarray, + torch._functional_sym_constrain_range, + torch._make_dep_token, + Tensor.__delitem__, + Tensor.__dir__, + Tensor.__getattribute__, + Tensor.__init__, + Tensor.__iter__, + Tensor.__init_subclass__, + Tensor.__delattr__, + Tensor.__setattr__, + Tensor.__torch_function__, + Tensor.__torch_dispatch__, + Tensor.__new__, + Tensor.__class__, + Tensor.__subclasshook__, + Tensor.__hash__, + Tensor.as_subclass, + Tensor.eig, + Tensor.lstsq, + Tensor.reinforce, + Tensor.new, + Tensor.new_tensor, + Tensor.new_empty, + Tensor.new_empty_strided, + Tensor.new_zeros, + Tensor.new_ones, + Tensor.new_full, + Tensor._make_subclass, + Tensor.solve, + Tensor.symeig, + Tensor.stride, + 
Tensor.unflatten, + Tensor.to_sparse_coo, + Tensor.to_sparse_csr, + Tensor.to_sparse_csc, + Tensor.to_sparse_bsr, + Tensor.to_sparse_bsc, + Tensor._to_sparse, + Tensor._to_sparse_csr, + Tensor._to_sparse_csc, + Tensor._to_sparse_bsr, + Tensor._to_sparse_bsc, + Tensor._typed_storage, + Tensor._reduce_ex_internal, + Tensor._fix_weakref, + Tensor._view_func, + Tensor._view_func_unsafe, + Tensor._rev_view_func_unsafe, + Tensor._make_wrapper_subclass, + Tensor._python_dispatch.__get__, + Tensor._has_symbolic_sizes_strides.__get__, + Tensor._conj, + Tensor._conj_physical, + Tensor._lazy_clone, + Tensor._neg_view, + Tensor._is_zerotensor, + Tensor._is_all_true, + Tensor._is_any_true, + Tensor._addmm_activation, + Tensor.to_padded_tensor, + Tensor._use_count, + } + + +@functools.cache +def get_default_nowrap_functions() -> set[Callable]: + """ + Return public functions that do not wrap in a subclass when invoked by + the default ``Tensor.__torch_function__`` that preserves subclasses. Typically, + these functions represent field accesses (i.e., retrieving a Tensor that + is stored somewhere on the Tensor) as opposed to computation. Users of + these functions expect object identity to be preserved over multiple accesses + (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on + the fly every time (furthermore, the tensor stored here might already be + the subclass, in which case wrapping really ought not to happen). + + Not ALL property accessors have this property; for example ``Tensor.T`` actually + just creates a new transposed tensor on the fly, and so we SHOULD interpose on + these calls (you need to check the implementation of the function to see if + this is the case or not). Additionally, if a property accessor doesn't return a Tensor, + it doesn't have to be on this list (though it is harmless if it is). + """ + Tensor = torch.Tensor + return { + Tensor._base.__get__, + Tensor.grad.__get__, + Tensor._grad.__get__, + } + + +@functools.cache +@_disable_user_warnings +def get_testing_overrides() -> dict[Callable, Callable]: + """Return a dict containing dummy overrides for all overridable functions + + Returns + ------- + Dict[Callable, Callable] + A dictionary that maps overridable functions in the PyTorch API to + lambda functions that have the same signature as the real function + and unconditionally return -1. These lambda functions are useful + for testing API coverage for a type that defines ``__torch_function__``. + + Examples + -------- + >>> import inspect + >>> my_add = torch.overrides.get_testing_overrides()[torch.add] + >>> inspect.signature(my_add) + + """ + # Every function in the PyTorchAPI that can be overridden needs an entry + # in this dict. + # + # Optimally we would use inspect to get the function signature and define + # the lambda function procedurally but that is blocked by generating + # function signatures for native kernels that can be consumed by inspect. + # See Issue #28233. 
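+    #
+    # Editor's sketch (illustrative only) of how the returned dict is meant to
+    # be consumed when testing __torch_function__ coverage:
+    #
+    #   overrides = torch.overrides.get_testing_overrides()
+    #   my_add = overrides[torch.add]
+    #   inspect.signature(my_add)  # mirrors the signature of torch.add
+    #   my_add(1, 2)               # dummy override: unconditionally returns -1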
+ Tensor = torch.Tensor + ret: dict[Callable, Callable] = { + torch.abs: lambda input, out=None: -1, + torch.absolute: lambda input, out=None: -1, + torch.adaptive_avg_pool1d: lambda input, output_size: -1, + torch.adaptive_max_pool1d: lambda inputs, output_size: -1, + torch.acos: lambda input, out=None: -1, + torch.adjoint: lambda input: -1, + torch.arccos: lambda input, out=None: -1, + torch.acosh: lambda input, out=None: -1, + torch.arccosh: lambda input, out=None: -1, + torch.add: lambda input, other, out=None: -1, + torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, + torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1, + torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1, + torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, + torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1, + torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1, + torch.affine_grid_generator: lambda theta, size, align_corners: -1, + torch.all: lambda input, dim=None: -1, + torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1, + torch.alpha_dropout: lambda input, p, train, inplace=False: -1, + torch.amax: lambda input, dim=None: -1, + torch.amin: lambda input, dim=None: -1, + torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1, + torch.angle: lambda input, out=None: -1, + torch.any: lambda input, dim=None, keepdim=False, out=None: -1, + torch.argmax: lambda input: -1, + torch.argmin: lambda input: -1, + torch.argsort: lambda input, dim=None: -1, + torch.asin: lambda input, out=None: -1, + torch._assert_async: lambda input, msg: -1, + torch.arcsin: lambda input, out=None: -1, + torch.asinh: lambda input, out=None: -1, + torch.arcsinh: lambda input, out=None: -1, + torch.atan: lambda input, out=None: -1, + torch.arctan: lambda input, out=None: -1, + torch.atan2: lambda input, other, out=None: -1, + torch.arctan2: lambda input, other, out=None: -1, + torch.atanh: lambda input, out=None: -1, + torch.arctanh: lambda input, out=None: -1, + torch.atleast_1d: lambda *tensors: -1, + torch.atleast_2d: lambda *tensors: -1, + torch.atleast_3d: lambda *tensors: -1, + torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1, + torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, + torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1, + torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1, + torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1, + torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1, + torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1, + torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1, + torch.batch_norm_stats: lambda input, eps: -1, + torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1, + torch.bernoulli: lambda input, generator=None, out=None: -1, + torch.bilinear: lambda input1, input2, weight, bias: -1, + torch.binary_cross_entropy_with_logits: ( + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 + ), + torch.bincount: lambda input, weights=None, minlength=0: -1, 
+ torch.binomial: lambda count, prob, generator=None: -1, + torch.bitwise_and: lambda input, other, out=None: -1, + torch.bitwise_not: lambda input, out=None: -1, + torch.bitwise_or: lambda input, other, out=None: -1, + torch.bitwise_xor: lambda input, other, out=None: -1, + torch.bitwise_left_shift: lambda input, other, out=None: -1, + torch.bitwise_right_shift: lambda input, other, out=None: -1, + torch.block_diag: lambda *tensors: -1, + torch.bmm: lambda input, mat2, out_dtype=None, out=None: -1, + torch.broadcast_tensors: lambda *tensors: -1, + torch.broadcast_to: lambda self, size: -1, + torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1, + torch.cartesian_prod: lambda *tensors: -1, + torch.cat: lambda tensors, dim=0, out=None: -1, + torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat + torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate + torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1, + torch.ceil: lambda input, out=None: -1, + torch.celu: lambda input, alpha=1.0, inplace=False: -1, + torch.chain_matmul: lambda *matrices, out=None: -1, + torch.channel_shuffle: lambda input, groups: -1, + torch.cholesky: lambda input, upper=False, out=None: -1, + torch.linalg.cholesky: lambda input, out=None: -1, + torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1, + torch.cholesky_inverse: lambda input, upper=False, out=None: -1, + torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1, + torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1, + torch.chunk: lambda input, chunks, dim=0: -1, + torch.clamp: lambda input, min=None, max=None, out=None: -1, + torch.clip: lambda input, min=None, max=None, out=None: -1, + torch.clamp_min: lambda input, min, out=None: -1, + torch.clamp_max: lambda input, max, out=None: -1, + torch.column_stack: lambda tensors, out=None: -1, + torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1, + torch.clone: lambda input: -1, + torch.combinations: lambda input, r=2, with_replacement=False: -1, + torch.complex: lambda real, imag: -1, + torch.copysign: lambda input, other, out=None: -1, + torch.polar: lambda abs, ang: -1, + torch.linalg.cond: lambda input, ord=None: -1, + torch.conj: lambda input, out=None: -1, + torch.conj_physical: lambda input, out=None: -1, + torch.resolve_conj: lambda input, out=None: -1, + torch.resolve_neg: lambda input, out=None: -1, + torch.constant_pad_nd: lambda input, pad, value=0: -1, + torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, + torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, + torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, + torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1, + torch.conv_tbc: lambda input, weight, bias, pad=0: -1, + torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, + torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, + torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, + torch.corrcoef: lambda input: -1, + torch.cos: lambda input, out=None: -1, + torch.cosine_embedding_loss: lambda input1, input2, 
target, margin=0, size_average=None, reduce=None, reduction="mean": -1, + torch.cosh: lambda input, out=None: -1, + torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1, + torch.count_nonzero: lambda input: -1, + torch.cross: lambda input, other, dim=None, out=None: -1, + torch.linalg.cross: lambda input, other, dim=-1, out=None: -1, + torch.ctc_loss: ( + lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1 + ), + torch.cummax: lambda input, dim, out=None: -1, + torch.cummin: lambda input, dim, out=None: -1, + torch.cumprod: lambda input, dim, out=None, dtype=None: -1, + torch.cumsum: lambda input, dim, out=None, dtype=None: -1, + torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1, + torch.logcumsumexp: lambda input, dim, out=None: -1, + torch.deg2rad: lambda input, out=None: -1, + torch.dequantize: lambda input: -1, + torch.det: lambda input: -1, + torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined] + torch.detach: lambda input: -1, + torch.diag: lambda input, diagonal=0, out=None: -1, + torch.diag_embed: lambda input, diagonal=0, out=None: -1, + torch.diagflat: lambda input, offset=0: -1, + torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1, + torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1, + torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1, + torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1, + torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1, + torch.digamma: lambda input, out=None: -1, + torch.dist: lambda input, other, p=2: -1, + torch.div: lambda input, other, rounding_mode=None, out=None: -1, + torch.divide: lambda input, other, rounding_mode=None, out=None: -1, + torch.dot: lambda input, other, out=None: -1, + torch.dropout: lambda input, p, train, inplace=False: -1, + torch.dsmm: lambda input, mat2, out_dtype=None: -1, + torch.hsmm: lambda mat1, mat2: -1, + torch.dsplit: lambda input, indices_or_sections: -1, + torch.dstack: lambda tensors, out=None: -1, + torch.linalg.eig: lambda input, out=None: -1, + torch.linalg.eigvals: lambda input, out=None: -1, + torch.linalg.eigh: lambda input, UPLO="L", out=None: -1, + torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1, + torch.einsum: lambda equation, *operands: -1, + torch.embedding: ( + lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950 + ), + torch.embedding_bag: ( + lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950 + ), + torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.eq: lambda input, other, out=None: -1, + torch.equal: lambda input, other: -1, + torch.erf: lambda input, out=None: -1, + torch.erfc: lambda input, out=None: -1, + torch.erfinv: lambda input, out=None: -1, + torch.exp: lambda input, out=None: -1, + torch.exp2: lambda input, out=None: -1, + torch.expm1: lambda input, out=None: -1, + torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1, + torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1, + torch.fused_moving_avg_obs_fake_quant: ( + lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, 
ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950 + ), + torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1, + torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1, + torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950 + torch.fbgemm_linear_int8_weight_fp32_activation: ( + lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1 + ), + torch.fbgemm_linear_quantize_weight: lambda input: -1, + torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1, + torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1, + torch.feature_alpha_dropout: lambda input, p, train: -1, + torch.feature_dropout: lambda input, p, train: -1, + torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1, + torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1, + torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.fftshift: lambda input, dim=None: -1, + torch.fft.ifftshift: lambda input, dim=None: -1, + torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fix: lambda input, out=None: -1, + torch.flatten: lambda input, start_dim=0, end_dim=-1: -1, + torch.flip: lambda input, dims: -1, + torch.fliplr: lambda input: -1, + torch.flipud: lambda input: -1, + torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1, + torch.floor: lambda input, out=None: -1, + torch.floor_divide: lambda input, other: -1, + torch.float_power: lambda input, exponent, out=None: -1, + torch.fmod: lambda input, other, out=None: -1, + torch.frac: lambda input, out=None: -1, + torch.frexp: lambda input, out=None: -1, + torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950 + torch._functional_assert_async: lambda input, msg, dep_token: -1, + torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1, + torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1, + torch.gcd: lambda input, other, out=None: -1, + torch.ge: lambda input, other, out=None: -1, + torch.get_device: lambda input: -1, + torch.greater_equal: lambda input, other, out=None: -1, + torch.geqrf: lambda input, out=None: -1, + torch.i0: lambda input, out=None: -1, + torch.inner: lambda input, other, out=None: -1, + torch.outer: lambda input, vec2, out=None: -1, + torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer + torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: 
-1, + torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, + torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, + torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, + torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1, + torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, + torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.gt: lambda input, other, out=None: -1, + torch.greater: lambda input, other, out=None: -1, + torch.hardshrink: lambda input, lambd=0.5: -1, + torch.heaviside: lambda input, values, out=None: -1, + torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950 + torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1, + torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1, + torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1, + torch.linalg.householder_product: lambda input, tau: -1, + torch.hspmm: lambda mat1, mat2, out=None: -1, + torch.hsplit: lambda input, indices_or_sections: -1, + torch.hstack: lambda tensors, out=None: -1, + torch.hypot: lambda input, other, out=None: -1, + torch.igamma: lambda input, other, out=None: -1, + torch.igammac: lambda input, other, out=None: -1, + torch.imag: lambda input, out=None: -1, + torch.index_add: lambda input, dim, index, source: -1, + torch.index_copy: lambda input, dim, index, source: -1, + torch.index_put: lambda input, indices, values, accumulate=False: -1, + torch.index_select: lambda input, dim, index, out=None: -1, + torch.index_fill: lambda input, dim, index, value: -1, + torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1, + torch.isfinite: lambda tensor: -1, + torch.isin: lambda e, te, assume_unique=False, invert=False: -1, + torch.isinf: lambda tensor: -1, + torch.isreal: lambda tensor: -1, + torch.isposinf: lambda input, out=None: -1, + torch.isneginf: lambda input, out=None: -1, + torch.instance_norm: ( + lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1 + ), + torch.int_repr: lambda input: -1, + torch.inverse: lambda input, out=None: -1, + torch.linalg.inv: lambda input, out=None: -1, + torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1, + torch.is_complex: lambda input: -1, + torch.is_conj: lambda input: -1, + torch.is_neg: lambda input: -1, + torch.is_distributed: lambda input: -1, + torch.is_inference: lambda input: -1, + torch.is_floating_point: lambda input: -1, + torch.is_nonzero: lambda input: -1, + torch.is_same_size: lambda input, other: -1, + torch.is_signed: lambda input: -1, + torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1, + torch.isnan: lambda input: -1, + torch.istft: ( + lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950 + ), + torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, + torch.kron: lambda input, other: -1, + torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1, + torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1, 
+ torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1, + torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1, + torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1, + torch.lcm: lambda input, other, out=None: -1, + torch.ldexp: lambda input, other, out=None: -1, + torch.le: lambda input, other, out=None: -1, + torch.less_equal: lambda input, other, out=None: -1, + torch.lerp: lambda input, end, weight, out=None: -1, + torch.lgamma: lambda input, out=None: -1, + torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950 + torch.log: lambda input, out=None: -1, + torch.log_softmax: lambda input, dim, dtype=None: -1, + torch.log10: lambda input, out=None: -1, + torch.log1p: lambda input, out=None: -1, + torch.log2: lambda input, out=None: -1, + torch.logaddexp: lambda input, other, out=None: -1, + torch.logaddexp2: lambda input, other, out=None: -1, + torch.logdet: lambda input: -1, + torch.xlogy: lambda x, y, out=None: -1, + torch.logical_and: lambda input, other, out=None: -1, + torch.logical_not: lambda input, out=None: -1, + torch.logical_or: lambda input, other, out=None: -1, + torch.logical_xor: lambda input, other, out=None: -1, + torch.logit: lambda input, eps=None: -1, + torch.logsumexp: lambda input, names, keepdim=False, out=None: -1, + torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1, + torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.lt: lambda input, other, out=None: -1, + torch.less: lambda input, other, out=None: -1, + torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, + torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, + torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950 + torch.masked_fill: lambda input, mask, value: -1, + torch.masked_scatter: lambda input, mask, source: -1, + torch.masked_select: lambda input, mask, out=None: -1, + torch.matmul: lambda input, other, out=None: -1, + torch.linalg.lu: lambda input, pivot=True, out=None: -1, + torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1, + torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1, + torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1, + torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul + torch.matrix_power: lambda input, n: -1, + torch.linalg.matrix_power: lambda input, n, out=None: -1, + torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1, + torch.linalg.multi_dot: lambda tensors, out=None: -1, + torch.matrix_exp: lambda input: -1, + torch.linalg.matrix_exp: lambda input: -1, + torch.max: lambda input, out=None: -1, + torch.maximum: lambda input, other, out=None: -1, + torch.fmax: lambda input, other, out=None: -1, + torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool1d_with_indices: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, 
return_indices=False, ceil_mode=False: -1 + ), + torch.mean: lambda input, dim=None: -1, + torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1, + torch.median: lambda input, dim=None: -1, + torch.nanmedian: lambda input, dim=None: -1, + torch.meshgrid: lambda *tensors, **kwargs: -1, + torch.min: lambda input, out=None: -1, + torch.minimum: lambda input, other, out=None: -1, + torch.fmin: lambda input, other, out=None: -1, + torch.miopen_batch_norm: ( + lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1 + ), + torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950 + torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1, + torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1, + torch.miopen_convolution_transpose: ( + lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1 + ), + torch.miopen_depthwise_convolution: ( + lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1 + ), + torch.miopen_rnn: ( + lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950 + ), + torch.mm: lambda input, mat2, out_dtype=None, out=None: -1, + torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1, + torch.movedim: lambda input, source, destination: -1, + torch.moveaxis: lambda input, source, destination: -1, + torch.msort: lambda input, descending=False, out=None: -1, + torch.mul: lambda input, other, out=None: -1, + torch.multiply: lambda input, other, out=None: -1, + torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1, + torch.mv: lambda input, vec, out=None: -1, + torch.mvlgamma: lambda input, p: -1, + torch.narrow: lambda input, dim, start, length: -1, + torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1, + torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1, + torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, + torch.native_dropout: lambda input, p, train: -1, + torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, + torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, + torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, + torch.native_channel_shuffle: lambda input, groups: -1, + torch.ne: lambda input, other, out=None: -1, + torch.not_equal: lambda input, other, out=None: -1, + torch.neg: lambda input, out=None: -1, + torch.negative: lambda input, out=None: -1, + torch.nextafter: lambda input, other, out=None: -1, + torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1, + torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1, + torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool3d: 
lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1, + torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, + torch.nn.functional.avg_pool2d: ( + lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950 + ), + torch.nn.functional.avg_pool3d: ( + lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950 + ), + torch.nn.functional.batch_norm: ( + lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1 + ), + torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1, + torch.nn.functional.binary_cross_entropy: ( + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.binary_cross_entropy_with_logits: ( + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 + ), + torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1, + torch.nn.functional.cosine_embedding_loss: ( + lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.cross_entropy: ( + lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950 + ), + torch.nn.functional.ctc_loss: ( + lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1 + ), + torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1, + torch.nn.functional.embedding: ( + lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950 + ), + torch.nn.functional.embedding_bag: ( + lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950 + ), + torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, + torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1, + torch.nn.functional.fractional_max_pool2d: ( + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + ), + torch.nn.functional.fractional_max_pool2d_with_indices: ( + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + ), + torch.nn.functional.fractional_max_pool3d: ( + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + ), + torch.nn.functional.fractional_max_pool3d_with_indices: ( + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: 
B950 + ), + torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1, + torch.nn.functional.gelu: lambda input, approximate="none": -1, + torch.nn.functional.glu: lambda input, dim=-1: -1, + torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950 + torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1, + torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1, + torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1, + torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1, + torch.nn.functional.hinge_embedding_loss: ( + lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.instance_norm: ( + lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950 + ), + torch.nn.functional.interpolate: ( + lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950 + ), + torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950 + torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1, + torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, + torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1, + torch.nn.functional.linear: lambda input, weight, bias=None: -1, + torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1, + torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, + torch.nn.functional.logsigmoid: lambda input: -1, + torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, + torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, + torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, + torch.nn.functional.margin_ranking_loss: ( + lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.max_pool1d: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1 + ), + torch.nn.functional.max_pool1d_with_indices: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 + ), + torch.nn.functional.max_pool2d: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1 + ), + torch.nn.functional.max_pool2d_with_indices: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 + ), + torch.nn.functional.max_pool3d: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 + ), + torch.nn.functional.max_pool3d_with_indices: ( + lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 + ), + torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 + torch.nn.functional.max_unpool2d: lambda 
input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 + torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 + torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1, + torch.nn.functional.multi_head_attention_forward: ( + lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950 + ), + torch.nn.functional.multi_margin_loss: ( + lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.multilabel_margin_loss: ( + lambda input, target, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.multilabel_soft_margin_loss: ( + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.nll_loss: ( + lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1 + ), + torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1, + torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1, + torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1, + torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, + torch.nn.functional.poisson_nll_loss: ( + lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950 + ), + torch.nn.functional.prelu: lambda input, weight: -1, + torch.nn.functional.relu: lambda input, inplace=False: -1, + torch.nn.functional.relu6: lambda input, inplace=False: -1, + torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1, + torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950 + torch.nn.functional.selu: lambda input, inplace=False: -1, + torch.nn.functional.silu: lambda input, inplace=False: -1, + torch.nn.functional.mish: lambda input, inplace=False: -1, + torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1, + torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950 + torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0, weight=None: -1, + torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950 + torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, + torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1, + torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1, + torch.nn.functional.softshrink: lambda input, lambd=0.5: -1, + torch.nn.functional.softsign: lambda input: -1, + torch.nn.functional.tanhshrink: lambda input: -1, + torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1, + torch.nn.functional.triplet_margin_loss: ( + lambda anchor, positive, negative, margin=1.0, p=2, 
eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950 + ), + torch.nn.functional.triplet_margin_with_distance_loss: ( + lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1 + ), + torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1, + torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1, + torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1, + torch.nn.init.constant_: lambda tensor, val: -1, + torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950 + torch.nonzero: lambda input, as_tuple=False: -1, + torch.nonzero_static: lambda input, *, size, fill_value=-1: -1, + torch.argwhere: lambda input: -1, + torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1, + torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1, + torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1, + torch.linalg.matrix_norm: lambda input, ord="fro", dim=( + -2, + -1, + ), keepdim=False, out=None, dtype=None: -1, + torch.norm_except_dim: lambda v, pow=2, dim=0: -1, + torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1, + torch.numel: lambda input: -1, + torch.orgqr: lambda input, tau: -1, + torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1, + torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, + torch.permute: lambda self, dim: -1, + torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1, + torch.pdist: lambda input, p=2: -1, + torch.pinverse: lambda input, rcond=1e-15: -1, + torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1, + torch.pixel_shuffle: lambda input, upscale_factor: -1, + torch.pixel_unshuffle: lambda input, downscale_factor: -1, + torch.poisson: lambda input, generator=None: -1, + torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1, + torch.polygamma: lambda input, n, out=None: -1, + torch.positive: lambda input, out=None: -1, + torch.prelu: lambda input, weight: -1, + torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.pow: lambda input, exponent, out=None: -1, + torch.prod: lambda input, dtype=None: -1, + torch.put: lambda input, index, source, accumulate=False: -1, + torch.q_per_channel_axis: lambda input: -1, + torch.q_per_channel_scales: lambda input: -1, + torch.q_per_channel_zero_points: lambda input: -1, + torch.q_scale: lambda input: -1, + torch.q_zero_point: lambda input: -1, + torch.qr: lambda input, some=True, out=None: -1, + torch.linalg.qr: lambda input, mode="reduced", out=None: -1, + torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1, + torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1, + torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1, + torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1, + torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1, + torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1, + torch.quantized_gru_cell: ( + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, 
scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + ), + torch.quantized_lstm_cell: ( + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + ), + torch.quantized_max_pool1d: ( + lambda input, kernel_size, stride=(), padding=(0,), dilation=( + 1, + ), ceil_mode=False: -1 + ), + torch.quantized_max_pool2d: ( + lambda input, kernel_size, stride=(), padding=(0, 0), dilation=( + 1, + 1, + ), ceil_mode=False: -1 + ), + torch.quantized_max_pool3d: ( + lambda input, kernel_size, stride=(), padding=(0, 0, 0), dilation=( + 1, + 1, + 1, + ), ceil_mode=False: -1 + ), + torch.quantized_rnn_relu_cell: ( + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + ), + torch.quantized_rnn_tanh_cell: ( + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + ), + torch.rad2deg: lambda input, out=None: -1, + torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, + torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.ravel: lambda input: -1, + torch.real: lambda input, out=None: -1, + torch.vdot: lambda input, other, out=None: -1, + torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1, + torch.view_as_real: lambda input: -1, + torch.view_as_complex: lambda input: -1, + torch.reciprocal: lambda input, out=None: -1, + torch.relu: lambda input, inplace=False: -1, + torch.remainder: lambda input, other, out=None: -1, + torch.renorm: lambda input, p, dim, maxnorm, out=None: -1, + torch.repeat_interleave: lambda input, dim=None: -1, + torch.reshape: lambda input, shape: -1, + torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1, + torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950 + torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950 + torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.roll: lambda input, shifts, dims=None: -1, + torch.rot90: lambda input, k=1, dims=(0, 1): -1, + torch.round: lambda input, out=None: -1, + torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack + torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1), + torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1, + torch.rsqrt: lambda input, out=None: -1, + torch.rsub: lambda input, other, alpha=1: -1, + torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, + torch.scatter: lambda input, dim, index, src: -1, + torch.scatter_add: lambda input, dim, index, src: -1, + torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1, + torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1, + torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950 + torch.select: lambda input, dim, index: -1, + 
torch.select_scatter: lambda input, src, dim, index: -1, + torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1, + torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1, + torch.selu: lambda input, inplace=False: -1, + torch.sigmoid: lambda input, out=None: -1, + torch.sign: lambda input, out=None: -1, + torch.signbit: lambda input, out=None: -1, + torch.sgn: lambda input, out=None: -1, + torch.sin: lambda input, out=None: -1, + torch.sinc: lambda input, out=None: -1, + torch.sinh: lambda input, out=None: -1, + torch.slogdet: lambda input: -1, + torch.linalg.slogdet: lambda input: -1, + torch.smm: lambda input, mat2, out_dtype=None: -1, + torch.spmm: lambda input, mat2, out_dtype=None: -1, + torch.softmax: lambda input, dim, dtype=None: -1, + torch.linalg.solve: lambda A, B, left=True, out=None: -1, + torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1, + torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1, + torch.split: lambda tensor, split_size_or_sections, dim=0: -1, + torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, + torch.sqrt: lambda input, out=None: -1, + torch.square: lambda input, out=None: -1, + torch.squeeze: lambda input, dim=None, out=None: -1, + torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, + torch.stack: lambda tensors, dim=0, out=None: -1, + torch.std: lambda input, dim=None: -1, + torch.std_mean: lambda input, dim=None: -1, + torch.stft: ( + lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950 + ), + torch.sub: lambda input, other, out=None: -1, + torch.subtract: lambda input, other, out=None: -1, + torch.sum: lambda input, dim=None: -1, + torch.sym_float: lambda input: -1, + torch.sym_int: lambda input: -1, + torch.sym_max: lambda a, b: -1, + torch.sym_min: lambda a, b: -1, + torch.sym_not: lambda input: -1, + torch.sym_ite: lambda a, b, c: -1, + torch.sym_sum: lambda args: -1, + torch._sym_sqrt: lambda input: -1, + torch._sym_cos: lambda input: -1, + torch._sym_cosh: lambda input: -1, + torch._sym_sin: lambda input: -1, + torch._sym_sinh: lambda input: -1, + torch._sym_tan: lambda input: -1, + torch._sym_tanh: lambda input: -1, + torch._sym_asin: lambda input: -1, + torch._sym_acos: lambda input: -1, + torch._sym_atan: lambda input: -1, + torch.nansum: lambda input, dim=None: -1, + torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, + torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, + torch.linalg.svd: lambda input, full_matrices=True, out=None: -1, + torch.linalg.svdvals: lambda input, out=None: -1, + torch.swapaxes: lambda input, dim0, dim1: -1, + torch.swapdims: lambda input, axis0, axis1: -1, + torch.special.airy_ai: lambda input: -1, + torch.special.bessel_j0: lambda input: -1, + torch.special.bessel_j1: lambda input: -1, + torch.special.bessel_y0: lambda input: -1, + torch.special.bessel_y1: lambda input: -1, + torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1, + torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1, + torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1, + torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1, + torch.special.digamma: lambda input: -1, + torch.special.entr: lambda input: -1, + torch.special.erf: lambda input: -1, + torch.special.erfc: lambda 
input: -1, + torch.special.erfcx: lambda input: -1, + torch.special.erfinv: lambda input: -1, + torch.special.exp2: lambda input: -1, + torch.special.expit: lambda input: -1, + torch.special.expm1: lambda input: -1, + torch.special.gammainc: lambda input, other, out=None: -1, + torch.special.gammaincc: lambda input, other, out=None: -1, + torch.special.gammaln: lambda input: -1, + torch.special.hermite_polynomial_h: lambda input, n, out=None: -1, + torch.special.hermite_polynomial_he: lambda input, n, out=None: -1, + torch.special.i0: lambda input: -1, + torch.special.i0e: lambda input: -1, + torch.special.i1: lambda input: -1, + torch.special.i1e: lambda input: -1, + torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1, + torch.special.legendre_polynomial_p: lambda input, n, out=None: -1, + torch.special.log1p: lambda input: -1, + torch.special.log_ndtr: lambda input: -1, + torch.special.log_softmax: lambda input, dim, dtype=None: -1, + torch.special.logit: lambda input: -1, + torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1, + torch.special.modified_bessel_i0: lambda input: -1, + torch.special.modified_bessel_i1: lambda input: -1, + torch.special.modified_bessel_k0: lambda input: -1, + torch.special.modified_bessel_k1: lambda input: -1, + torch.special.multigammaln: lambda input, p: -1, + torch.special.ndtr: lambda input: -1, + torch.special.ndtri: lambda input: -1, + torch.special.polygamma: lambda input, n, out=None: -1, + torch.special.psi: lambda input: -1, + torch.special.round: lambda input: -1, + torch.special.scaled_modified_bessel_k0: lambda input: -1, + torch.special.scaled_modified_bessel_k1: lambda input: -1, + torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1, + torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1, + torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1, + torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1, + torch.special.sinc: lambda input: -1, + torch.special.softmax: lambda input, dim, dtype=None: -1, + torch.special.spherical_bessel_j0: lambda input: -1, + torch.special.xlog1py: lambda input, other, out=None: -1, + torch.special.xlogy: lambda input, other, out=None: -1, + torch.special.zeta: lambda self, other, out=None: -1, + torch.t: lambda input: -1, + torch.take: lambda input, index: -1, + torch.take_along_dim: lambda input, indices, dim=None, out=None: -1, + torch.tan: lambda input, out=None: -1, + torch.tanh: lambda input, out=None: -1, + torch.linalg.tensorinv: lambda a, ind=2: -1, + torch.linalg.tensorsolve: lambda a, b, dims=None: -1, + torch.tensordot: lambda a, b, dims=2, out=None: -1, + torch.tensor_split: lambda input, indices_or_sections, dim=0: -1, + torch.threshold: lambda input, threshold, value, inplace=False: -1, + torch.tile: lambda input, dims: -1, + torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, + torch.trace: lambda input: -1, + torch.transpose: lambda input, dim0, dim1: -1, + torch.trapz: lambda y, x=None, dim=-1: -1, + torch.trapezoid: lambda y, x=None, dim=-1: -1, + torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1, + torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1, + torch.tril: lambda input, diagonal=0, out=None: -1, + torch.triplet_margin_loss: ( + lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950 + 
), + torch.triu: lambda input, diagonal=0, out=None: -1, + torch.true_divide: lambda input, other: -1, + torch.trunc: lambda input, out=None: -1, + torch.unbind: lambda input, dim=0: -1, + torch.unflatten: lambda input, dim, sizes, names: -1, + torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1, + torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1, + torch.unravel_index: lambda indices, shape: -1, + torch.unsafe_chunk: lambda input, chunks, dim=0: -1, + torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1, + torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, + torch.unsqueeze: lambda input, dim, out=None: -1, + torch.linalg.vander: lambda x, N=None: -1, + torch.var: lambda input, dim=None: -1, + torch.var_mean: lambda input, dim=None: -1, + torch.vsplit: lambda input, indices_or_sections: -1, + torch.vstack: lambda tensors, out=None: -1, + torch.where: lambda condition, x=None, y=None: -1, + torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, + torch._wrapped_quantized_linear_prepacked: ( + lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950 + ), + torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch._fw_primal_copy: lambda self, level: -1, + torch._make_dual_copy: lambda primal, tangent, level: -1, + torch.view_as_real_copy: lambda self: -1, + torch.view_as_complex_copy: lambda self: -1, + torch._conj_copy: lambda self: -1, + torch._neg_view_copy: lambda self: -1, + torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1, + torch._sparse_broadcast_to_copy: lambda self, size: -1, + torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1, + torch.expand_copy: lambda self, size, *, implicit=False: -1, + torch.narrow_copy: lambda self, dim, start, length: -1, + torch.permute_copy: lambda self, dims: -1, + torch._reshape_alias_copy: lambda self, size, stride: -1, + torch.select_copy: lambda self, dim, index: -1, + torch.detach_copy: lambda self: -1, + torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1, + torch.split_copy: lambda self, split_size, dim=0: -1, + torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1, + torch.squeeze_copy: lambda self, dim: -1, + torch.t_copy: lambda self: -1, + torch.transpose_copy: lambda self, dim0, dim1: -1, + torch.unsqueeze_copy: lambda self, dim: -1, + torch._indices_copy: lambda self: -1, + torch._values_copy: lambda self: -1, + torch.indices_copy: lambda self: -1, + torch.values_copy: lambda self: -1, + torch.crow_indices_copy: lambda self: -1, + torch.col_indices_copy: lambda self: -1, + torch.ccol_indices_copy: lambda self: -1, + torch.row_indices_copy: lambda self: -1, + torch.unbind_copy: lambda self, dim=0: -1, + torch.view_copy: lambda self, dtype: -1, + torch.unfold_copy: lambda self, dimension, size, step: -1, + torch.alias_copy: lambda self: -1, + Tensor.__floordiv__: lambda self, other: -1, + Tensor.__rfloordiv__: lambda self, other: -1, + Tensor.__ifloordiv__: lambda self, other: -1, + Tensor.__truediv__: lambda self, other: -1, + Tensor.__rtruediv__: lambda self, other: -1, + Tensor.__itruediv__: lambda self, other: -1, + Tensor.__lshift__: lambda self, other: -1, + Tensor.__rlshift__: lambda self, other: -1, + Tensor.__ilshift__: lambda self, other: -1, + Tensor.__rshift__: lambda self, other: -1, + 
Tensor.__rrshift__: lambda self, other: -1, + Tensor.__irshift__: lambda self, other: -1, + Tensor.__and__: lambda self, other: -1, + Tensor.__or__: lambda self, other: -1, + Tensor.__xor__: lambda self, other: -1, + Tensor.__float__: lambda self: -1, + Tensor.__complex__: lambda self: -1, + Tensor.__array__: lambda self, dtype: -1, + Tensor.__bool__: lambda self: -1, + Tensor.__contains__: lambda self, other: -1, + Tensor.__neg__: lambda self: -1, + Tensor.__invert__: lambda self: -1, + Tensor.__mod__: lambda self, other: -1, + Tensor.__rmod__: lambda self, other: -1, + Tensor.__imod__: lambda self, other: -1, + Tensor.__array_wrap__: lambda self, array: -1, + Tensor.__getitem__: lambda self, idx: -1, + Tensor.__deepcopy__: lambda self, memo: -1, + Tensor.__int__: lambda self: -1, + Tensor.__long__: lambda self: -1, + Tensor.__index__: lambda self: -1, + Tensor.__len__: lambda self: -1, + Tensor.__format__: lambda self, format_spec: -1, + Tensor.__reduce_ex__: lambda self, proto: -1, + Tensor.__reversed__: lambda self: -1, + Tensor.__repr__: lambda self, *, tensor_contents=None: -1, + Tensor.__setitem__: lambda self, k, v: -1, + Tensor.__setstate__: lambda self, d: -1, + Tensor.T.__get__: lambda self: -1, + Tensor.H.__get__: lambda self: -1, + Tensor.mT.__get__: lambda self: -1, + Tensor.mH.__get__: lambda self: -1, + Tensor._backward_hooks.__get__: lambda self: -1, + Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1, + Tensor._base.__get__: lambda self: -1, + Tensor._cdata.__get__: lambda self: -1, + Tensor.grad.__get__: lambda self: -1, + Tensor._grad.__get__: lambda self: -1, + Tensor._grad_fn.__get__: lambda self: -1, + Tensor.grad_fn.__get__: lambda self: -1, + Tensor._version.__get__: lambda self: -1, + Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1, + Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1, + Tensor._clear_non_serializable_cached_data: lambda self: -1, + Tensor.data.__get__: lambda self: -1, + Tensor.device.__get__: lambda self: -1, + Tensor.dtype.__get__: lambda self: -1, + Tensor.is_cuda.__get__: lambda self: -1, + Tensor.is_cpu.__get__: lambda self: -1, + Tensor.is_xla.__get__: lambda self: -1, + Tensor.is_xpu.__get__: lambda self: -1, + Tensor.is_ipu.__get__: lambda self: -1, + Tensor.is_leaf.__get__: lambda self: -1, + Tensor.retains_grad.__get__: lambda self: -1, + Tensor.is_meta.__get__: lambda self: -1, + Tensor.is_mps.__get__: lambda self: -1, + Tensor.is_mtia.__get__: lambda self: -1, + Tensor.is_nested.__get__: lambda self: -1, + Tensor.is_maia.__get__: lambda self: -1, + Tensor.is_mkldnn.__get__: lambda self: -1, + Tensor.is_quantized.__get__: lambda self: -1, + Tensor.is_sparse.__get__: lambda self: -1, + Tensor.is_sparse_csr.__get__: lambda self: -1, + Tensor.is_vulkan.__get__: lambda self: -1, + Tensor.itemsize.__get__: lambda self: -1, + Tensor.layout.__get__: lambda self: -1, + Tensor.name.__get__: lambda self: -1, + Tensor.names.__get__: lambda self: -1, + Tensor.nbytes.__get__: lambda self: -1, + Tensor.ndim.__get__: lambda self: -1, + Tensor.output_nr.__get__: lambda self: -1, + Tensor.requires_grad.__get__: lambda self: -1, + Tensor.shape.__get__: lambda self: -1, + Tensor.volatile.__get__: lambda self: -1, + Tensor.real.__get__: lambda self: -1, + Tensor.imag.__get__: lambda self: -1, + Tensor.__cuda_array_interface__.__get__: lambda self: -1, + Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1, + Tensor._dimI: lambda self: -1, + 
Tensor._dimV: lambda self: -1, + Tensor._indices: lambda self: -1, + Tensor._is_view: lambda self: -1, + Tensor._nnz: lambda self: -1, + Tensor.crow_indices: lambda self: -1, + Tensor.col_indices: lambda self: -1, + Tensor.ccol_indices: lambda self: -1, + Tensor.row_indices: lambda self: -1, + Tensor._update_names: lambda self, names, inplace: -1, + Tensor._values: lambda self: -1, + Tensor.adjoint: lambda self: -1, + Tensor.align_as: lambda self, other: -1, + Tensor.align_to: lambda self, order, ellipsis_idx: -1, + Tensor.apply_: lambda self, callable: -1, + Tensor.as_strided: lambda self, size, stride: -1, + Tensor.as_strided_: lambda self, size, stride: -1, + Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1, + Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1, + Tensor.bool: lambda self, memory_format=torch.preserve_format: -1, + Tensor.byte: lambda self, memory_format=torch.preserve_format: -1, + Tensor.char: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1, + Tensor.coalesce: lambda self: -1, + Tensor._coalesced_: lambda self, coalesced: -1, + Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1, + Tensor.copy_: lambda self, src, non_blocking=False: -1, + Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1, + Tensor.mtia: lambda self, memory_format=torch.preserve_format: -1, + Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1, + Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1, + Tensor.data_ptr: lambda self: -1, + Tensor.dense_dim: lambda self: -1, + Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1, + Tensor.dim: lambda self: -1, + Tensor.dim_order: lambda self, ambiguity_check=False: -1, + Tensor.double: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1, + Tensor.element_size: lambda self: -1, + Tensor.expand: lambda self, size: -1, + Tensor.expand_as: lambda self, other: -1, + Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1, + Tensor.fill_: lambda self, value: -1, + Tensor.fill_diagonal_: lambda self, value: -1, + Tensor.float: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1, + Tensor.geometric_: lambda self, p, *, generator=None: -1, + Tensor.get_device: lambda self: -1, + Tensor.half: lambda self, memory_format=torch.preserve_format: -1, + Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1, + Tensor.has_names: lambda self: -1, + Tensor.indices: lambda self: -1, + Tensor.int: lambda self, memory_format=torch.preserve_format: -1, + Tensor.is_coalesced: lambda self: -1, + Tensor.is_contiguous: lambda self: -1, + Tensor.is_inference: lambda self: -1, + Tensor.is_pinned: lambda self: -1, + Tensor.is_set_to: lambda self, tensor: -1, + Tensor.is_shared: lambda self: -1, + Tensor.item: lambda self: -1, + Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1, + Tensor.log_softmax: lambda self, dim: -1, + Tensor.long: lambda self, memory_format=torch.preserve_format: -1, + Tensor.map_: lambda self, tensor, callable: -1, + Tensor.map2_: lambda self, x, y, callable: -1, + Tensor.mm: lambda self, mat2, out_dtype=None: -1, + Tensor.module_load: lambda self, other, assign=False: -1, + 
Tensor.narrow_copy: lambda self, dimension, start, length: -1, + Tensor.ndimension: lambda self: -1, + Tensor.nelement: lambda self: -1, + Tensor._nested_tensor_size: lambda self: -1, + Tensor._nested_tensor_storage_offsets: lambda self: -1, + Tensor._nested_tensor_strides: lambda self: -1, + Tensor.normal_: lambda self: -1, + Tensor.numpy: lambda self: -1, + Tensor.permute: lambda self, dim: -1, + Tensor.pin_memory: lambda self: -1, + Tensor.put_: lambda self, indices, tensor, accumulate=False: -1, + Tensor.qscheme: lambda self: -1, + Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1, + Tensor.record_stream: lambda self, stream: -1, + Tensor.refine_names: lambda self, names: -1, + Tensor.register_hook: lambda self, hook: -1, + Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1, + Tensor.rename: lambda self, name: -1, + Tensor.repeat: lambda self, *size: -1, + Tensor.requires_grad_: lambda self, requires_grad=True: -1, + Tensor.reshape_as: lambda self, other: -1, + Tensor.resize: lambda self, *size: -1, + Tensor.resize_: lambda self, size: -1, + Tensor.resize_as: lambda self, other: -1, + Tensor.resize_as_sparse_: lambda self, other: -1, + Tensor.retain_grad: lambda self: -1, + Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1, + Tensor.select_scatter: lambda self, src, dim, index: -1, + Tensor.share_memory_: lambda self: -1, + Tensor.short: lambda self, memory_format=torch.preserve_format: -1, + Tensor.size: lambda self: -1, + Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1, + Tensor.sparse_dim: lambda self: -1, + Tensor.sparse_mask: lambda self, mask: -1, + Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1, + Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1, + Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1, + Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1, + Tensor.storage: lambda self: -1, + Tensor.untyped_storage: lambda self: -1, + Tensor.storage_offset: lambda self: -1, + Tensor.storage_type: lambda self: -1, + Tensor.sum_to_size: lambda self, size: -1, + Tensor.tile: lambda self, *reps: -1, + Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1, + Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1, + Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1, + Tensor.to_sparse: lambda self: -1, + Tensor.tolist: lambda self: -1, + Tensor.to_mkldnn: lambda self: -1, + Tensor.type_as: lambda self, other: -1, + Tensor.unfold: lambda self, dimension, size, step: -1, + Tensor.uniform_: lambda self, from_=0, to=1: -1, + Tensor.values: lambda self: -1, + Tensor.view: lambda self, shape: -1, + Tensor.view_as: lambda self, other: -1, + Tensor.zero_: lambda self: -1, + Tensor.__dlpack__: lambda self, stream=None: -1, + Tensor.__dlpack_device__: lambda self: -1, + torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, + } # fmt: skip + + privateuse1_backend_name = ( + torch.utils.backend_registration._privateuse1_backend_name + ) + if hasattr(Tensor, privateuse1_backend_name): + ret[getattr(Tensor, privateuse1_backend_name)] = ( + lambda self, device=None, non_blocking=False, **kwargs: -1 + ) + ret[getattr(Tensor, f"is_{privateuse1_backend_name}").__get__] = lambda self: -1 + + ret2 = {} + ignored = get_ignored_functions() + + for k, v in ret.items(): + # Generate methods like __add__ and add_ by default from add + names = 
[ + k.__name__, # Default method + k.__name__ + "_", # Inplace variant + "__" + k.__name__ + "__", # Dunder method + "__i" + k.__name__ + "__", # Inplace dunder method + "__r" + k.__name__ + "__", # Reverse dunder method + ] + + if k.__name__.startswith("bitwise_"): + # bitwise_ have dunder methods of the form ____ + # And so on. + subname = k.__name__[len("bitwise_") :] + names.extend( + ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"] + ) + + for name in names: + func = getattr(Tensor, name, None) + if callable(func) and func not in ret and func not in ignored: + ret2[func] = v + + ret.update(ret2) + return ret + + +def wrap_torch_function(dispatcher: Callable): + """Wraps a given function with ``__torch_function__`` -related functionality. + + Parameters + ---------- + dispatcher: Callable + A callable that returns an iterable of Tensor-likes passed into the function. + + Note + ---- + This decorator may reduce the performance of your code. Generally, it's enough to express + your code as a series of functions that, themselves, support __torch_function__. If you + find yourself in the rare situation where this is not the case, e.g. if you're wrapping a + low-level library and you also need it to work for Tensor-likes, then this function is available. + + Examples + -------- + >>> def dispatcher(a): # Must have the same signature as func + ... return (a,) + >>> @torch.overrides.wrap_torch_function(dispatcher) + >>> def func(a): # This will make func dispatchable by __torch_function__ + ... return a + 0 + """ + + def inner(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + if has_torch_function(relevant_args): + return handle_torch_function(wrapped, relevant_args, *args, **kwargs) + + return func(*args, **kwargs) + + return wrapped + + return inner + + +def _get_overloaded_args( + relevant_args: Iterable[Any], + get_type_fn: Optional[Callable[[Any], type]] = None, +) -> list[Any]: + """Returns a list of arguments on which to call __torch_function__. + + Checks arguments in relevant_args for __torch_function__ implementations, + storing references to the arguments and their types in overloaded_args and + overloaded_types in order of calling precedence. Only distinct types are + considered. If a type is a subclass of another type it will have higher + precedence, otherwise the precedence order is the same as the order of + arguments in relevant_args, that is, from left-to-right in the argument list. + + The precedence-determining algorithm implemented in this function is + described in `NEP-0018`_. + + See torch::append_overloaded_arg for the equivalent function in the C++ + implementation. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of array-like arguments to check for __torch_function__ + methods. + + get_type_fn : callable, optional + Function to call on each argument in relevant_args to get its type. + + Returns + ------- + overloaded_args : list + Arguments from relevant_args on which to call __torch_function__ + methods, in the order in which they should be called. + + .. 
_NEP-0018: + https://numpy.org/neps/nep-0018-array-function-protocol.html + """ + if get_type_fn is None: + get_type_fn = type + + # If torch function is not enabled, there are no overloaded types + if not torch._C._is_torch_function_enabled(): + return [] + # Runtime is O(num_arguments * num_unique_types) + overloaded_types: set[type] = set() + overloaded_args: list[Any] = [] + for arg in relevant_args: + arg_type = get_type_fn(arg) + # We only collect arguments if they have a unique type, which ensures + # reasonable performance even with a long list of possibly overloaded + # arguments. + # + # NB: Important to exclude _disabled_torch_function_impl, otherwise + # https://github.com/pytorch/pytorch/issues/64687 + if ( + arg_type not in overloaded_types + and hasattr(arg_type, "__torch_function__") + and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl + ): + # Create lists explicitly for the first type (usually the only one + # done) to avoid setting up the iterator for overloaded_args. + if overloaded_types: + overloaded_types.add(arg_type) + # By default, insert argument at the end, but if it is + # subclass of another argument, insert it before that argument. + # This ensures "subclasses before superclasses". + index = len(overloaded_args) + for i, old_arg in enumerate(overloaded_args): + if issubclass(arg_type, get_type_fn(old_arg)): + index = i + break + overloaded_args.insert(index, arg) + else: + overloaded_types = {arg_type} + overloaded_args = [arg] + return overloaded_args + + +def handle_torch_function( + public_api: Callable, + relevant_args: Iterable[Any], + *args, + **kwargs, +) -> Any: + """Implement a function with checks for ``__torch_function__`` overrides. + + See torch::autograd::handle_torch_function for the equivalent of this + function in the C++ implementation. + + Arguments + --------- + public_api : function + Function exposed by the public torch API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being + checked. + relevant_args : iterable + Iterable of arguments to check for __torch_function__ methods. + args : tuple + Arbitrary positional arguments originally passed into ``public_api``. + kwargs : tuple + Arbitrary keyword arguments originally passed into ``public_api``. + + Returns + ------- + object + Result from calling ``implementation`` or an ``__torch_function__`` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + + Example + ------- + >>> def func(a): + ... if has_torch_function_unary(a): + ... return handle_torch_function(func, (a,), a) + ... return a + 0 + """ + # Check for __torch_function__ methods. + overloaded_args = _get_overloaded_args(relevant_args) + # overloaded_args already have unique types. + types = tuple(map(type, overloaded_args)) + + # Check for __torch_function__ mode. + if _is_torch_function_mode_enabled(): + # if we're here, the mode must be set to a TorchFunctionStackMode + # this unsets it and calls directly into TorchFunctionStackMode's torch function + with _pop_mode_temporarily() as mode: + result = mode.__torch_function__(public_api, types, args, kwargs) + if result is not NotImplemented: + return result + + # Call overrides + for overloaded_arg in overloaded_args: + # This call needs to become a classmethod call in the future. 
+ # See https://github.com/pytorch/pytorch/issues/63767 + torch_func_method = overloaded_arg.__torch_function__ + if ( + hasattr(torch_func_method, "__self__") + and torch_func_method.__self__ is overloaded_arg + and torch_func_method is not torch._C._disabled_torch_function_impl + ): + warnings.warn( + "Defining your `__torch_function__ as a plain method is deprecated and " + "will be an error in future, please define it as a classmethod.", + DeprecationWarning, + ) + + # Use `public_api` instead of `implementation` so __torch_function__ + # implementations can do equality/identity comparisons. + result = torch_func_method(public_api, types, args, kwargs) + + if result is not NotImplemented: + return result + + func_name = f"{public_api.__module__}.{public_api.__name__}" + msg = ( + f"no implementation found for '{func_name}' on types that implement " + f"__torch_function__: {[type(arg) for arg in overloaded_args]}" + ) + if _is_torch_function_mode_enabled(): + msg += f" nor in mode {_get_current_function_mode()}" + raise TypeError(msg) + + +has_torch_function = _add_docstr( + _has_torch_function, + r"""Check for __torch_function__ implementations in the elements of an iterable + or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s + and ``Parameter`` s non-dispatchable. Use this to guard a call to + :func:`handle_torch_function`; don't use it to test if something + is Tensor-like, use :func:`is_tensor_like` instead. + Arguments + --------- + relevant_args : iterable + Iterable or arguments to check for __torch_function__ methods. + Returns + ------- + bool + True if any of the elements of relevant_args have __torch_function__ + implementations, False otherwise. + See Also + ________ + torch.is_tensor_like + Checks if something is a Tensor-like, including an exact ``Tensor``. + """, +) + +has_torch_function_unary = _add_docstr( + _has_torch_function_unary, + r"""Special case of `has_torch_function` for single inputs. + Instead of: + `has_torch_function((t,))` + call: + `has_torch_function_unary(t)` + which skips unnecessary packing and unpacking work. + """, +) + +has_torch_function_variadic = _add_docstr( + _has_torch_function_variadic, + r"""Special case of `has_torch_function` that skips tuple creation. + + This uses the METH_FASTCALL protocol introduced in Python 3.7 + + Instead of: + `has_torch_function((a, b))` + call: + `has_torch_function_variadic(a, b)` + which skips unnecessary packing and unpacking work. 
+ """, +) + + +@functools.cache +def _get_overridable_functions() -> tuple[ + dict[Any, list[Callable]], dict[Callable, str] +]: + overridable_funcs = collections.defaultdict(list) + index = {} + tested_namespaces = [ + ("torch", torch, torch.__all__), + ("torch.functional", torch.functional, torch.functional.__all__), + ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)), + ("torch.nn.init", torch.nn.init, dir(torch.nn.init)), + ("torch.Tensor", torch.Tensor, dir(torch.Tensor)), + ("torch.linalg", torch.linalg, dir(torch.linalg)), + ("torch.fft", torch.fft, dir(torch.fft)), + ("torch.special", torch.special, dir(torch.special)), + ] + for namespace_str, namespace, ns_funcs in tested_namespaces: + for func_name in ns_funcs: + ignore = False + # ignore private functions or functions that are deleted in torch.__init__ + if namespace is not torch.Tensor: + if func_name.startswith("__"): + continue + elif func_name.startswith("_"): + ignore = True + elif func_name.endswith("_"): + ignore = True + elif not func_name[0].islower(): + ignore = True + elif func_name == "unique_dim": + continue + else: + func = getattr(namespace, func_name) + if getattr(object, func_name, None) == func: + continue + if func_name == "__weakref__": + continue + func = getattr(namespace, func_name) + if namespace is torch.Tensor and getattr(object, func_name, None) == func: + continue + # ignore re-exported modules + if isinstance(func, types.ModuleType): + continue + # ignore __future__ imports + if isinstance(func, __future__._Feature): + continue + + if not callable(func) and hasattr(func, "__get__"): + index[func.__get__] = f"{namespace_str}.{func_name}.__get__" + index[func.__set__] = f"{namespace_str}.{func_name}.__set__" + if ignore: + continue + if func.__get__ in get_ignored_functions(): + msg = ( + "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " + "but still has an explicit override" + ) + assert func.__get__ not in get_testing_overrides(), msg.format( + namespace, func.__name__ + ) + continue + else: + overridable_funcs[func].append(func.__get__) + continue + + if not callable(func): + continue + + index[func] = f"{namespace_str}.{func_name}" + + if ignore: + continue + + # cannot be overridden by __torch_function__ + if func in get_ignored_functions(): + msg = ( + "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " + "but still has an explicit override" + ) + assert func not in get_testing_overrides(), msg.format( + namespace, func.__name__ + ) + continue + overridable_funcs[namespace].append(func) + return overridable_funcs, index + + +@_disable_user_warnings +def get_overridable_functions() -> dict[Any, list[Callable]]: + """List functions that are overridable via __torch_function__ + + Returns + ------- + Dict[Any, List[Callable]] + A dictionary that maps namespaces that contain overridable functions + to functions in that namespace that can be overridden. + """ + return _get_overridable_functions()[0] + + +@_disable_user_warnings +def resolve_name(f): + """Get a human readable string name for a function passed to + __torch_function__ + + Arguments + --------- + f : Callable + Function to resolve the name of. + + Returns + ------- + str + Name of the function; if eval'ed it should give back the input + function. 
+ """ + if isinstance(f, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return str(f) + return _get_overridable_functions()[1].get(f) + + +@functools.cache +def _get_tensor_methods() -> set[Callable]: + """Returns a set of the overridable methods on ``torch.Tensor``""" + overridable_funcs = get_overridable_functions() + methods = set(overridable_funcs[torch.Tensor]) + return methods + + +@_disable_user_warnings +def is_tensor_method_or_property(func: Callable) -> bool: + """ + Returns True if the function passed in is a handler for a + method or property belonging to ``torch.Tensor``, as passed + into ``__torch_function__``. + + .. note:: + For properties, their ``__get__`` method must be passed in. + + This may be needed, in particular, for the following reasons: + + 1. Methods/properties sometimes don't contain a `__module__` slot. + 2. They require that the first passed-in argument is an instance + of ``torch.Tensor``. + + Examples + -------- + >>> is_tensor_method_or_property(torch.Tensor.add) + True + >>> is_tensor_method_or_property(torch.add) + False + """ + return func in _get_tensor_methods() or func.__name__ == "__get__" + + +def is_tensor_like(inp): + """ + Returns ``True`` if the passed-in input is a Tensor-like. + + Currently, this occurs whenever there's a ``__torch_function__`` + attribute on the type of the input. + + Examples + -------- + A subclass of tensor is generally a Tensor-like. + + >>> class SubTensor(torch.Tensor): ... + >>> is_tensor_like(SubTensor([0])) + True + + Built-in or user types aren't usually Tensor-like. + + >>> is_tensor_like(6) + False + >>> is_tensor_like(None) + False + >>> class NotATensor: ... + >>> is_tensor_like(NotATensor()) + False + + But, they can be made Tensor-like by implementing __torch_function__. + + >>> class TensorLike: + ... @classmethod + ... def __torch_function__(cls, func, types, args, kwargs): + ... return -1 + >>> is_tensor_like(TensorLike()) + True + """ + return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__") + + +class TorchFunctionMode: + """ + A ``TorchFunctionMode`` allows you to override the meaning of all + ``__torch_function__`` overrideable functions within a dynamic scope, + without having to actually create a tensor subclass or manually + monkey-patch functions in the PyTorch API. Some common situations + where you should use a mode: + + * You want to override the meaning of factory functions, or other + functions that do not otherwise take a tensor as an argument + (these cannot be overridden with tensor subclasses). + + * You want to override the behavior of all functions without needing + to wrap your inputs in tensor subclasses; e.g., if you are just + interested in logging intermediate computations. + + * You want to control the order of execution of various tensor + subclasses explicitly, rather than implicitly via the return of + ``NotImplemented``. + + Independent subclasses of :class:`TorchFunctionMode` are compositional: + modes can be pushed onto a stack using ``with MyMode():``. + When you call functions in the PyTorch API inside your + ``__torch_function__`` implementation, by default, they will forward on to + the next mode on the mode stack. If you want recursively call back into + your current ``__torch_function__`` implementation, either explicitly + invoke ``self.__torch_function__(...)``, or use the context manager + ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch + API self-referential (beware of infinite loops, in this case!) 
+ """ + + inner: "TorchFunctionMode" + + # Force metaclass to generate constructor at the base of the hierarchy + def __init__(self) -> None: + pass + + def __torch_function__(self, func, types, args=(), kwargs=None): + raise NotImplementedError + + def __enter__(self): + _push_mode(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _pop_mode() + + @classmethod + def push(cls, *args, **kwargs): + warnings.warn( + "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`" + ) + instance = cls(*args, **kwargs) + return instance + + +def _get_current_function_mode(): + stack_len = _len_torch_function_stack() + return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None + + +def _get_current_function_mode_stack(): + stack_len = _len_torch_function_stack() + return [_get_function_stack_at(i) for i in range(stack_len)] + + +def _push_mode(mode): + _push_on_torch_function_stack(mode) + + +def _pop_mode(): + old = _pop_torch_function_stack() + return old + + +@contextlib.contextmanager +def _pop_mode_temporarily(): + old = _pop_mode() + try: + yield old + finally: + _push_mode(old) + + +class BaseTorchFunctionMode(TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _enable_torch_function(): + old_state = torch._C._get_torch_function_state() + try: + torch._C._set_torch_function_state(torch._C._TorchFunctionState.ENABLED) + yield + finally: + torch._C._set_torch_function_state(old_state) + + +@contextlib.contextmanager +def enable_reentrant_dispatch(): + # NB: this can't simply be + # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot` + # because: + # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file + # initially gets imported. Probably an import order thing. + # 2. enable_reentrant_dispatch is technically public API; assigning + # it the object would change the __module__ to look private. + with torch._C._RestorePythonTLSSnapshot(): + try: + yield + finally: + pass diff --git a/phivenv/Lib/site-packages/torch/py.typed b/phivenv/Lib/site-packages/torch/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/quasirandom.py b/phivenv/Lib/site-packages/torch/quasirandom.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf051b6dd500ff19c3666905f006d862969d14b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quasirandom.py @@ -0,0 +1,217 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch + + +class SobolEngine: + r""" + The :class:`torch.quasirandom.SobolEngine` is an engine for generating + (scrambled) Sobol sequences. Sobol sequences are an example of low + discrepancy quasi-random sequences. + + This implementation of an engine for Sobol sequences is capable of + sampling sequences up to a maximum dimension of 21201. It uses direction + numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the + search criterion D(6) up to the dimension 21201. This is the recommended + choice by the authors. + + References: + - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points. + Journal of Complexity, 14(4):466-489, December 1998. + + - I. M. Sobol. The distribution of points in a cube and the accurate + evaluation of integrals. + Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967. 
+ + Args: + dimension (Int): The dimensionality of the sequence to be drawn + scramble (bool, optional): Setting this to ``True`` will produce + scrambled Sobol sequences. Scrambling is + capable of producing better Sobol + sequences. Default: ``False``. + seed (Int, optional): This is the seed for the scrambling. The seed + of the random number generator is set to this, + if specified. Otherwise, it uses a random seed. + Default: ``None`` + + Examples:: + + >>> # xdoctest: +SKIP("unseeded random state") + >>> soboleng = torch.quasirandom.SobolEngine(dimension=5) + >>> soboleng.draw(3) + tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.5000, 0.5000, 0.5000], + [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]]) + """ + + MAXBIT = 30 + MAXDIM = 21201 + + def __init__(self, dimension, scramble=False, seed=None): + if dimension > self.MAXDIM or dimension < 1: + raise ValueError( + "Supported range of dimensionality " + f"for SobolEngine is [1, {self.MAXDIM}]" + ) + + self.seed = seed + self.scramble = scramble + self.dimension = dimension + + cpu = torch.device("cpu") + + self.sobolstate = torch.zeros( + dimension, self.MAXBIT, device=cpu, dtype=torch.long + ) + torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension) + + if not self.scramble: + self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long) + else: + self._scramble() + + self.quasi = self.shift.clone(memory_format=torch.contiguous_format) + self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1) + self.num_generated = 0 + + def draw( + self, + n: int = 1, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + r""" + Function to draw a sequence of :attr:`n` points from a Sobol sequence. + Note that the samples are dependent on the previous samples. The size + of the result is :math:`(n, dimension)`. + + Args: + n (Int, optional): The length of sequence of points to draw. + Default: 1 + out (Tensor, optional): The output tensor + dtype (:class:`torch.dtype`, optional): the desired data type of the + returned tensor. + Default: ``None`` + """ + if dtype is None: + dtype = torch.get_default_dtype() + + if self.num_generated == 0: + if n == 1: + result = self._first_point.to(dtype) + else: + result, self.quasi = torch._sobol_engine_draw( + self.quasi, + n - 1, + self.sobolstate, + self.dimension, + self.num_generated, + dtype=dtype, + ) + result = torch.cat((self._first_point.to(dtype), result), dim=-2) + else: + result, self.quasi = torch._sobol_engine_draw( + self.quasi, + n, + self.sobolstate, + self.dimension, + self.num_generated - 1, + dtype=dtype, + ) + + self.num_generated += n + + if out is not None: + out.resize_as_(result).copy_(result) + return out + + return result + + def draw_base2( + self, + m: int, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + r""" + Function to draw a sequence of :attr:`2**m` points from a Sobol sequence. + Note that the samples are dependent on the previous samples. The size + of the result is :math:`(2**m, dimension)`. + + Args: + m (Int): The (base2) exponent of the number of points to draw. + out (Tensor, optional): The output tensor + dtype (:class:`torch.dtype`, optional): the desired data type of the + returned tensor. + Default: ``None`` + """ + n = 2**m + total_n = self.num_generated + n + if not (total_n & (total_n - 1) == 0): + raise ValueError( + "The balance properties of Sobol' points require " + f"n to be a power of 2. 
{self.num_generated} points have been " + f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. " + "If you still want to do this, please use " + "'SobolEngine.draw()' instead." + ) + return self.draw(n=n, out=out, dtype=dtype) + + def reset(self): + r""" + Function to reset the ``SobolEngine`` to base state. + """ + self.quasi.copy_(self.shift) + self.num_generated = 0 + return self + + def fast_forward(self, n): + r""" + Function to fast-forward the state of the ``SobolEngine`` by + :attr:`n` steps. This is equivalent to drawing :attr:`n` samples + without using the samples. + + Args: + n (Int): The number of steps to fast-forward by. + """ + if self.num_generated == 0: + torch._sobol_engine_ff_( + self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated + ) + else: + torch._sobol_engine_ff_( + self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1 + ) + self.num_generated += n + return self + + def _scramble(self): + g: Optional[torch.Generator] = None + if self.seed is not None: + g = torch.Generator() + g.manual_seed(self.seed) + + cpu = torch.device("cpu") + + # Generate shift vector + shift_ints = torch.randint( + 2, (self.dimension, self.MAXBIT), device=cpu, generator=g + ) + self.shift = torch.mv( + shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)) + ) + + # Generate lower triangular matrices (stacked across dimensions) + ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT) + ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril() + + torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension) + + def __repr__(self): + fmt_string = [f"dimension={self.dimension}"] + if self.scramble: + fmt_string += ["scramble=True"] + if self.seed is not None: + fmt_string += [f"seed={self.seed}"] + return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")" diff --git a/phivenv/Lib/site-packages/torch/random.py b/phivenv/Lib/site-packages/torch/random.py new file mode 100644 index 0000000000000000000000000000000000000000..adc2b33fe506f2981ca7cc3f24873f3ff897c973 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/random.py @@ -0,0 +1,203 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +from collections.abc import Generator + +import torch +from torch._C import default_generator + + +def set_rng_state(new_state: torch.Tensor) -> None: + r"""Sets the random number generator state. + + .. note:: This function only works for CPU. For CUDA, please use + :func:`torch.manual_seed`, which works for both CPU and CUDA. + + Args: + new_state (torch.ByteTensor): The desired state + """ + default_generator.set_state(new_state) + + +def get_rng_state() -> torch.Tensor: + r"""Returns the random number generator state as a `torch.ByteTensor`. + + .. note:: The returned state is for the default generator on CPU only. + + See also: :func:`torch.random.fork_rng`. + """ + return default_generator.get_state() + + +def manual_seed(seed) -> torch._C.Generator: + r"""Sets the seed for generating random numbers on all devices. Returns a + `torch.Generator` object. + + Args: + seed (int): The desired seed. Value must be within the inclusive range + `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError + is raised. Negative inputs are remapped to positive values with the formula + `0xffff_ffff_ffff_ffff + seed`. 
+ """ + seed = int(seed) + import torch.cuda + + if not torch.cuda._is_in_bad_fork(): + torch.cuda.manual_seed_all(seed) + + import torch.mps + + if not torch.mps._is_in_bad_fork(): + torch.mps.manual_seed(seed) + + import torch.xpu + + if not torch.xpu._is_in_bad_fork(): + torch.xpu.manual_seed_all(seed) + + _seed_custom_device(seed) + + return default_generator.manual_seed(seed) + + +def seed() -> int: + r"""Sets the seed for generating random numbers to a non-deterministic + random number on all devices. Returns a 64 bit number used to seed the RNG. + """ + seed = default_generator.seed() + import torch.cuda + + if not torch.cuda._is_in_bad_fork(): + torch.cuda.manual_seed_all(seed) + + import torch.mps + + if not torch.mps._is_in_bad_fork(): + torch.mps.manual_seed(seed) + + import torch.xpu + + if not torch.xpu._is_in_bad_fork(): + torch.xpu.manual_seed_all(seed) + + _seed_custom_device(seed) + + return seed + + +def _seed_custom_device(seed) -> None: + r"""Sets the seed to generate random numbers for custom device. + + Args: + seed (int): The desired seed. + + See [Note: support the custom device with privateuse1] + """ + seed = int(seed) + custom_backend_name = torch._C._get_privateuse1_backend_name() + if hasattr(torch, custom_backend_name): + custom_device_mod = getattr(torch, custom_backend_name) + _bad_fork_name = "_is_in_bad_fork" + _seed_all_name = "manual_seed_all" + if hasattr(custom_device_mod, _bad_fork_name) and hasattr( + custom_device_mod, _seed_all_name + ): + if not getattr(custom_device_mod, _bad_fork_name)(): + getattr(custom_device_mod, _seed_all_name)(seed) + else: + message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's " + message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module." + warnings.warn(message, UserWarning, stacklevel=3) + + +def initial_seed() -> int: + r"""Returns the initial seed for generating random numbers as a + Python `long`. + + .. note:: The returned seed is for the default generator on CPU only. + """ + return default_generator.initial_seed() + + +_fork_rng_warned_already = False + + +@contextlib.contextmanager +def fork_rng( + devices=None, + enabled=True, + _caller="fork_rng", + _devices_kw="devices", + device_type="cuda", +) -> Generator: + """ + Forks the RNG, so that when you return, the RNG is reset + to the state that it was previously in. + + Args: + devices (iterable of Device IDs): devices for which to fork + the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates + on all devices, but will emit a warning if your machine has a lot + of devices, since this function will run very slowly in that case. + If you explicitly specify devices, this warning will be suppressed + enabled (bool): if ``False``, the RNG is not forked. This is a convenience + argument for easily disabling the context manager without having + to delete it and unindent your Python code under it. + device_type (str): device type str, default is `cuda`. As for supported device, + see details in :ref:`accelerator` + """ + + if device_type == "meta": + yield + return + + device_type = torch.device(device_type).type + device_mod = getattr(torch, device_type, None) + if device_mod is None: + raise RuntimeError( + f"torch has no module of `{device_type}`, you should register " + + "a module by `torch._register_device_module`." 
+ ) + global _fork_rng_warned_already + + # Internal arguments: + # _caller: the function which called fork_rng, which the user used + # _devices_kw: the devices keyword of _caller + + if not enabled: + yield + return + + if devices is None: + num_devices = device_mod.device_count() + if num_devices > 1 and not _fork_rng_warned_already: + message = ( + f"{device_type.upper()} reports that you have {num_devices} available devices, and " + f"you have used {_caller} without explicitly specifying which devices are being used. " + f"For safety, we initialize *every* {device_type.upper()} device by default, which can " + f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only" + f" making use of a few {device_type.upper()} devices, set the environment variable " + f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} " + "with the set of devices you are actually using. For example, if you are using CPU only, " + "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, " + f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices " + f"and suppress this warning, set the '{_devices_kw}' keyword argument to " + f"`range(torch.{device_type}.device_count())`." + ) + warnings.warn(message) + _fork_rng_warned_already = True + devices = list(range(num_devices)) + else: + # Protect against user passing us a generator; we need to traverse this + # multiple times but a generator will be exhausted upon first traversal + devices = list(devices) + + cpu_rng_state = torch.get_rng_state() + device_rng_states = [device_mod.get_rng_state(device) for device in devices] + + try: + yield + finally: + torch.set_rng_state(cpu_rng_state) + for device, device_rng_state in zip(devices, device_rng_states): + device_mod.set_rng_state(device_rng_state, device) diff --git a/phivenv/Lib/site-packages/torch/return_types.py b/phivenv/Lib/site-packages/torch/return_types.py new file mode 100644 index 0000000000000000000000000000000000000000..950406023fa161d0896fdd12823fd451f1dba7a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/return_types.py @@ -0,0 +1,51 @@ +import inspect + +import torch +from torch.utils._pytree import register_pytree_node, SequenceKey + + +__all__ = ["pytree_register_structseq", "all_return_types"] + +all_return_types = [] + +# error: Module has no attribute "_return_types" +return_types = torch._C._return_types # type: ignore[attr-defined] + + +def pytree_register_structseq(cls): + def structseq_flatten(structseq): + return list(structseq), None + + def structseq_flatten_with_keys(structseq): + values, context = structseq_flatten(structseq) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + def structseq_unflatten(values, context): + return cls(values) + + register_pytree_node( + cls, + structseq_flatten, + structseq_unflatten, + flatten_with_keys_fn=structseq_flatten_with_keys, + ) + + +for name in dir(return_types): + if name.startswith("__"): + continue + + _attr = getattr(return_types, name) + globals()[name] = _attr + + if not name.startswith("_"): + __all__.append(name) + all_return_types.append(_attr) + + # Today everything in torch.return_types is a structseq, aka a "namedtuple"-like + # thing defined by the Python C-API. We're going to need to modify this when that + # is no longer the case. 
+ # NB: I don't know how to check that something is a "structseq" so we do a fuzzy + # check for tuple + if inspect.isclass(_attr) and issubclass(_attr, tuple): + pytree_register_structseq(_attr) diff --git a/phivenv/Lib/site-packages/torch/return_types.pyi b/phivenv/Lib/site-packages/torch/return_types.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1b4b9814d0fc59e9660c2dcb2a42362489aeab00 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/return_types.pyi @@ -0,0 +1,605 @@ +# @generated by tools/pyi/gen_pyi.py from torch/_C/return_types.pyi.in +# mypy: allow-untyped-defs + +from typing import Final, NoReturn +from typing_extensions import Self + +from torch import SymInt, Tensor +from torch.types import ( # noqa: F401 + _bool, + _device, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Number, +) + +__all__ = [ + "pytree_register_structseq", + "all_return_types", + "_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "_fused_moving_avg_obs_fq_helper", + "_linalg_det", + "_linalg_eigh", + "_linalg_slogdet", + "_linalg_solve_ex", + "_linalg_svd", + "_lu_with_info", + "_scaled_dot_product_cudnn_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_flash_attention_for_cpu", + "_unpack_dual", + "aminmax", + "cummax", + "cummin", + "frexp", + "geqrf", + "histogram", + "histogramdd", + "kthvalue", + "lu_unpack", + "max", + "median", + "min", + "mode", + "nanmedian", + "qr", + "slogdet", + "sort", + "svd", + "topk", + "triangular_solve", +] + +def pytree_register_structseq(cls: type) -> None: ... + +class _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(tuple[Tensor, Tensor]): # fmt: skip + @property + def output(self) -> Tensor: ... + @property + def mask(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _fused_moving_avg_obs_fq_helper(tuple[Tensor, Tensor]): # fmt: skip + @property + def output(self) -> Tensor: ... + @property + def mask(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_det(tuple[Tensor, Tensor, Tensor]): # fmt: skip + @property + def result(self) -> Tensor: ... + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 3 + n_sequence_fields: Final[_int] = 3 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_eigh(tuple[Tensor, Tensor]): # fmt: skip + @property + def eigenvalues(self) -> Tensor: ... + @property + def eigenvectors(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_slogdet(tuple[Tensor, Tensor, Tensor, Tensor]): # fmt: skip + @property + def sign(self) -> Tensor: ... 
+ @property + def logabsdet(self) -> Tensor: ... + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 4 + n_sequence_fields: Final[_int] = 4 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_solve_ex(tuple[Tensor, Tensor, Tensor, Tensor]): # fmt: skip + @property + def result(self) -> Tensor: ... + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + @property + def info(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 4 + n_sequence_fields: Final[_int] = 4 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_svd(tuple[Tensor, Tensor, Tensor]): # fmt: skip + @property + def U(self) -> Tensor: ... + @property + def S(self) -> Tensor: ... + @property + def Vh(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 3 + n_sequence_fields: Final[_int] = 3 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _lu_with_info(tuple[Tensor, Tensor, Tensor]): # fmt: skip + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + @property + def info(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 3 + n_sequence_fields: Final[_int] = 3 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_cudnn_attention(tuple[Tensor, Tensor, Tensor, Tensor, _int | SymInt, _int | SymInt, Tensor, Tensor, Tensor]): # fmt: skip + @property + def output(self) -> Tensor: ... + @property + def logsumexp(self) -> Tensor: ... + @property + def cum_seq_q(self) -> Tensor: ... + @property + def cum_seq_k(self) -> Tensor: ... + @property + def max_q(self) -> _int | SymInt: ... + @property + def max_k(self) -> _int | SymInt: ... + @property + def philox_seed(self) -> Tensor: ... + @property + def philox_offset(self) -> Tensor: ... + @property + def debug_attn_mask(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor, Tensor, _int | SymInt, _int | SymInt, Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 9 + n_sequence_fields: Final[_int] = 9 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_efficient_attention(tuple[Tensor, Tensor, Tensor, Tensor]): # fmt: skip + @property + def output(self) -> Tensor: ... + @property + def log_sumexp(self) -> Tensor: ... + @property + def philox_seed(self) -> Tensor: ... + @property + def philox_offset(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 4 + n_sequence_fields: Final[_int] = 4 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... 
# prohibit subclassing + +class _scaled_dot_product_flash_attention(tuple[Tensor, Tensor, Tensor, Tensor, _int | SymInt, _int | SymInt, Tensor, Tensor, Tensor]): # fmt: skip + @property + def output(self) -> Tensor: ... + @property + def logsumexp(self) -> Tensor: ... + @property + def cum_seq_q(self) -> Tensor: ... + @property + def cum_seq_k(self) -> Tensor: ... + @property + def max_q(self) -> _int | SymInt: ... + @property + def max_k(self) -> _int | SymInt: ... + @property + def rng_state(self) -> Tensor: ... + @property + def unused(self) -> Tensor: ... + @property + def debug_attn_mask(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor, Tensor, _int | SymInt, _int | SymInt, Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 9 + n_sequence_fields: Final[_int] = 9 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_flash_attention_for_cpu(tuple[Tensor, Tensor]): # fmt: skip + @property + def output(self) -> Tensor: ... + @property + def logsumexp(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _unpack_dual(tuple[Tensor, Tensor]): # fmt: skip + @property + def primal(self) -> Tensor: ... + @property + def tangent(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class aminmax(tuple[Tensor, Tensor]): # fmt: skip + @property + def min(self) -> Tensor: ... + @property + def max(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class cummax(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class cummin(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class frexp(tuple[Tensor, Tensor]): # fmt: skip + @property + def mantissa(self) -> Tensor: ... + @property + def exponent(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class geqrf(tuple[Tensor, Tensor]): # fmt: skip + @property + def a(self) -> Tensor: ... + @property + def tau(self) -> Tensor: ... 
+ def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class histogram(tuple[Tensor, Tensor]): # fmt: skip + @property + def hist(self) -> Tensor: ... + @property + def bin_edges(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class histogramdd(tuple[Tensor, tuple[Tensor, ...]]): # fmt: skip + @property + def hist(self) -> Tensor: ... + @property + def bin_edges(self) -> tuple[Tensor, ...]: ... + def __new__( + cls, + sequence: tuple[Tensor, tuple[Tensor, ...]], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class kthvalue(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class lu_unpack(tuple[Tensor, Tensor, Tensor]): # fmt: skip + @property + def P(self) -> Tensor: ... + @property + def L(self) -> Tensor: ... + @property + def U(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 3 + n_sequence_fields: Final[_int] = 3 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class max(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class median(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class min(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class mode(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... 
# prohibit subclassing + +class nanmedian(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class qr(tuple[Tensor, Tensor]): # fmt: skip + @property + def Q(self) -> Tensor: ... + @property + def R(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class slogdet(tuple[Tensor, Tensor]): # fmt: skip + @property + def sign(self) -> Tensor: ... + @property + def logabsdet(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class sort(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class svd(tuple[Tensor, Tensor, Tensor]): # fmt: skip + @property + def U(self) -> Tensor: ... + @property + def S(self) -> Tensor: ... + @property + def V(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 3 + n_sequence_fields: Final[_int] = 3 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class topk(tuple[Tensor, Tensor]): # fmt: skip + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class triangular_solve(tuple[Tensor, Tensor]): # fmt: skip + @property + def solution(self) -> Tensor: ... + @property + def cloned_coefficient(self) -> Tensor: ... + def __new__( + cls, + sequence: tuple[Tensor, Tensor], + ) -> Self: # fmt: skip + ... + n_fields: Final[_int] = 2 + n_sequence_fields: Final[_int] = 2 + n_unnamed_fields: Final[_int] = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +all_return_types: list[type] = ... 
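The return-type stubs above mirror the structseq ("named tuple"-like) objects that operators such as torch.max return, so their fields can be read either by position or by name, and they flatten like plain tuples for pytree purposes. A minimal usage sketch, assuming only a standard torch install (nothing here beyond torch.return_types.max comes from the vendored files themselves):

    import torch

    x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
    result = torch.max(x, dim=1)       # a torch.return_types.max structseq
    values, indices = result           # unpacks like a plain tuple
    assert torch.equal(values, result.values)    # named-field access works too
    assert torch.equal(indices, result.indices)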
diff --git a/phivenv/Lib/site-packages/torch/serialization.py b/phivenv/Lib/site-packages/torch/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..8874d7f7bec12456a8736a6388c3af1193bbd86b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/serialization.py @@ -0,0 +1,2130 @@ +# mypy: allow-untyped-defs +import copyreg +import difflib +import functools +import io +import os +import pickle +import re +import shutil +import struct +import sys +import tarfile +import tempfile +import threading +import warnings +from contextlib import closing, contextmanager +from enum import Enum +from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union +from typing_extensions import TypeAlias, TypeIs + +import torch +import torch._weights_only_unpickler as _weights_only_unpickler +from torch._sources import get_source_lines_and_file +from torch._utils import _import_dotted_name +from torch.storage import _get_dtype_from_pickle_storage_type +from torch.types import FileLike, Storage + + +__all__ = [ + "SourceChangeWarning", + "mkdtemp", + "register_package", + "check_module_version_greater_or_equal", + "validate_cuda_device", + "validate_hpu_device", + "location_tag", + "default_restore_location", + "normalize_storage_type", + "storage_to_tensor_type", + "save", + "load", + "StorageType", + "LoadEndianness", + "get_crc32_options", + "set_crc32_options", + "get_default_load_endianness", + "set_default_load_endianness", + "get_default_mmap_options", + "set_default_mmap_options", + "clear_safe_globals", + "get_safe_globals", + "add_safe_globals", + "safe_globals", + "get_unsafe_globals_in_checkpoint", + "skip_data", +] + +DEFAULT_PROTOCOL = 2 + +LONG_SIZE = struct.Struct("=l").size +INT_SIZE = struct.Struct("=i").size +SHORT_SIZE = struct.Struct("=h").size + +MAGIC_NUMBER = 0x1950A86A20F9469CFC6C +PROTOCOL_VERSION = 1001 +STORAGE_KEY_SEPARATOR = "," + +MAP_LOCATION: TypeAlias = Optional[ + Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]] +] +STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] + +IS_WINDOWS = sys.platform == "win32" + +UNSAFE_MESSAGE = ( + "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " + "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " + "but it can result in arbitrary code execution. Do it only if you got the file from a " + "trusted source." 
+) + +if not IS_WINDOWS: + from mmap import MAP_PRIVATE, MAP_SHARED +else: + MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] + + +def _default_to_weights_only(pickle_module): + is_fbcode = not hasattr(torch.version, "git_version") + return pickle_module is None and not is_fbcode + + +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: Optional[MAP_LOCATION] = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + +class SourceChangeWarning(Warning): + pass + + +@contextmanager +def mkdtemp(): + path = tempfile.mkdtemp() + try: + yield path + finally: + shutil.rmtree(path) + + +_package_registry: list[ + tuple[ + int, + Callable[[STORAGE], Optional[str]], + Callable[[STORAGE, str], Optional[STORAGE]], + ] +] = [] + + +class LoadEndianness(Enum): + NATIVE = 1 + LITTLE = 2 + BIG = 3 + + +def get_default_load_endianness() -> Optional[LoadEndianness]: + """ + Get fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Returns: + default_load_endian: Optional[LoadEndianness] + """ + from torch.utils.serialization import config + + return config.load.endianness + + +def set_default_load_endianness(endianness): + """ + Set fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Args: + endianness: the new fallback byte order + """ + if not isinstance(endianness, LoadEndianness) and endianness is not None: + raise TypeError("Invalid argument type in function set_default_load_endianness") + from torch.utils.serialization import config + + config.load.endianness = endianness + + +def get_crc32_options() -> bool: + """ + Get whether :func:`torch.save` computes and writes crc32 for each record. + + Defaults to ``True``. + """ + from torch.utils.serialization import config + + return config.save.compute_crc32 + + +def set_crc32_options(compute_crc32: bool): + """ + Set whether :func:`torch.save` computes and writes crc32 for each record. + + .. note:: + Setting this to ``False`` may make unzipping of the ``torch.save`` output + fail or warn due to corrupted CRC32. However ``torch.load`` will be + able to load the file. + + Args: + compute_crc32 (bool): set crc32 compuation flag + """ + from torch.utils.serialization import config + + config.save.compute_crc32 = compute_crc32 + + +def get_default_mmap_options() -> Optional[int]: + """ + Get default mmap options for :func:`torch.load` with ``mmap=True``. + + Defaults to ``mmap.MAP_PRIVATE``. + + + Returns: + default_mmap_options: int + """ + from torch.utils.serialization import config + + return config.load.mmap_flags + + +def _get_storage_alignment() -> int: + """ + Gets alignment for storages in torch.save files/ + + Defaults to 64. 
+ + Returns: + storage_alginment: int + """ + from torch.utils.serialization import config + + return config.save.storage_alignment + + +class set_default_mmap_options: + """ + Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. + + For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported. + Please open an issue if you need any other option to be added here. + + .. note:: + This feature is currently not supported for Windows. + + Args: + flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` + """ + + def __init__(self, flags: int) -> None: + if IS_WINDOWS: + raise RuntimeError( + "Changing the default mmap options is currently not supported for Windows" + ) + if flags != MAP_PRIVATE and flags != MAP_SHARED: + raise ValueError( + "Invalid argument in function set_default_mmap_options, " + f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" + ) + # global config + from torch.utils.serialization import config + + self.prev = config.load.mmap_flags + config.load.mmap_flags = flags + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + from torch.utils.serialization import config + + config.load.mmap_flags = self.prev + + +def clear_safe_globals() -> None: + """ + Clears the list of globals that are safe for ``weights_only`` load. + """ + _weights_only_unpickler._clear_safe_globals() + + +def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: + """ + Returns the list of user-added globals that are safe for ``weights_only`` load. + """ + return _weights_only_unpickler._get_safe_globals() + + +def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None: + """ + Marks the given globals as safe for ``weights_only`` load. For example, functions + added to this list can be called during unpickling, classes could be instantiated + and have state set. + + Each item in the list can either be a function/class or a tuple of the form + (function/class, string) where string is the full path of the function/class. + + Within the serialized format, each function is identified with its full + path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this + full path that should match the one in the checkpoint otherwise the default + ``{fn.__module__}.{fn.__qualname__}`` will be used. + + Args: + safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe + + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile + >>> class MyTensor(torch.Tensor): + ... pass + >>> t = MyTensor(torch.randn(2, 3)) + >>> with tempfile.NamedTemporaryFile() as f: + ... torch.save(t, f.name) + # Running `torch.load(f.name, weights_only=True)` will fail with + # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. + # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. + ... torch.serialization.add_safe_globals([MyTensor]) + ... torch.load(f.name, weights_only=True) + # MyTensor([[-0.5024, -1.8152, -0.5455], + # [-0.8234, 2.0500, -0.3657]]) + """ + _weights_only_unpickler._add_safe_globals(safe_globals) + + +class safe_globals(_weights_only_unpickler._safe_globals): + r"""Context-manager that adds certain globals as safe for ``weights_only`` load. + + Args: + safe_globals: List of globals for weights_only load. 
+ + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile + >>> class MyTensor(torch.Tensor): + ... pass + >>> t = MyTensor(torch.randn(2, 3)) + >>> with tempfile.NamedTemporaryFile() as f: + ... torch.save(t, f.name) + # Running `torch.load(f.name, weights_only=True)` will fail with + # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. + # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. + ... with torch.serialization.safe_globals([MyTensor]): + ... torch.load(f.name, weights_only=True) + # MyTensor([[-0.5024, -1.8152, -0.5455], + # [-0.8234, 2.0500, -0.3657]]) + >>> assert torch.serialization.get_safe_globals() == [] + """ + + +def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]: + """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``. + + For a given function or class ``f``, the corresponding string will be of the form + ``{f.__module__}.{f.__name__}``. + + This function will return any GLOBALs in the checkpoint that are not in the set marked safe + for ``weights_only`` (either via :func:`add_safe_globals` or :class:`safe_globals` context or + allowlisted by ``torch`` by default). + + .. note:: + This function will statically disassemble the pickle file in the checkpoint. + The implication is any classes dynamically pushed onto the stack during unpickling + will not be included in the output. + + Args: + f: File-like object or string containing the checkpoint object saved via ``torch.save`` + + Returns: + A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``. + """ + default_safe_globals_strings = set( + _weights_only_unpickler._get_allowed_globals().keys() + ) + user_safe_global_strings = set( + _weights_only_unpickler._get_user_allowed_globals().keys() + ) + safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings) + + with _open_file_like(f, "rb") as opened_file: + if not _is_zipfile(opened_file): + raise ValueError("Expected input to be a checkpoint returned by torch.save") + with _open_zipfile_reader(opened_file) as zip_file: + if _is_torchscript_zip(zip_file): + raise ValueError( + "Expected input to be a checkpoint returned by torch.save but got a torchscript checkpoint" + ) + data_file = io.BytesIO(zip_file.get_record("data.pkl")) + all_globals = _weights_only_unpickler.get_globals_in_pkl(data_file) + return list(all_globals.difference(safe_global_strings)) + + +class skip_data: + """ + Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls. + + For the save path, storages will still be saved, but the space that their bytes would usually be written to + will be empty space. The storage bytes can then be populated in a separate pass. + + For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path. + + Example: + >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") + >>> import tempfile + >>> t = torch.randn(2, 3) + >>> with tempfile.NamedTemporaryFile() as f: + ... with torch.serialization.skip_data(): + ... torch.save(t, f.name) + ... 
torch.load(f.name, weights_only=True) + tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + +def _is_zipfile(f) -> bool: + # This is a stricter implementation than zipfile.is_zipfile(). + # zipfile.is_zipfile() is True if the magic number appears anywhere in the + # binary. Since we expect the files here to be generated by torch.save or + # torch.jit.save, it's safe to only check the start bytes and avoid + # collisions and assume the zip has only 1 file. + # See bugs.python.org/issue28494. + + start = f.tell() + # Read the first few bytes and match against the ZIP file signature + local_header_magic_number = b"PK\x03\x04" + read_bytes = f.read(len(local_header_magic_number)) + f.seek(start) + return read_bytes == local_header_magic_number + + +def register_package( + priority: int, + tagger: Callable[[STORAGE], Optional[str]], + deserializer: Callable[[STORAGE, str], Optional[STORAGE]], +): + """ + Registers callables for tagging and deserializing storage objects with an associated priority. + Tagging associates a device with a storage object at save time while deserializing moves a + storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer` + are run in the order given by their :attr:`priority` until a tagger/deserializer returns a + value that is not `None`. + + To override the deserialization behavior for a device in the global registry, one can register a + tagger with a higher priority than the existing tagger. + + This function can also be used to register a tagger and deserializer for new devices. + + Args: + priority: Indicates the priority associated with the tagger and deserializer, where a lower + value indicates higher priority. + tagger: Callable that takes in a storage object and returns its tagged device as a string + or None. + deserializer: Callable that takes in storage object and a device string and returns a storage + object on the appropriate device or None. + + Returns: + `None` + + Example: + >>> def ipu_tag(obj): + >>> if obj.device.type == 'ipu': + >>> return 'ipu' + >>> def ipu_deserialize(obj, location): + >>> if location.startswith('ipu'): + >>> ipu = getattr(torch, "ipu", None) + >>> assert ipu is not None, "IPU device module is not loaded" + >>> assert torch.ipu.is_available(), "ipu is not available" + >>> return obj.ipu(location) + >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) + """ + queue_elem = (priority, tagger, deserializer) + _package_registry.append(queue_elem) + _package_registry.sort() + + +def check_module_version_greater_or_equal( + module, + req_version_tuple, + error_if_malformed=True, +): + """ + Check if a module's version satisfies requirements + + Usually, a module's version string will be like 'x.y.z', which would be represented + as a tuple (x, y, z), but sometimes it could be an unexpected format. 
If the version + string does not match the given tuple's format up to the length of the tuple, then + error and exit or emit a warning. + + Args: + module: the module to check the version of + req_version_tuple: tuple (usually of ints) representing the required version + error_if_malformed: whether we should exit if module version string is malformed + + Returns: + requirement_is_met: bool + """ + try: + version_strs = module.__version__.split(".") + # Cast module version fields to match the types of the required version + module_version = tuple( + type(req_field)(version_strs[idx]) + for idx, req_field in enumerate(req_version_tuple) + ) + requirement_is_met = module_version >= req_version_tuple + + except Exception as e: + message = ( + f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared" + f" with tuple {str(req_version_tuple)}" + ) + if error_if_malformed: + raise RuntimeError(message) from e + else: + warnings.warn(message + ", but continuing assuming that requirement is met") + requirement_is_met = True + + return requirement_is_met + + +def _cpu_tag(obj): + if obj.device.type == "cpu": + return "cpu" + + +def _mps_tag(obj): + if obj.device.type == "mps": + return "mps" + + +def _meta_tag(obj): + if obj.device.type == "meta": + return "meta" + + +def _backend_tag(backend_name, obj): + if backend_name == "privateuse1": + backend_name = torch._C._get_privateuse1_backend_name() + if obj.device.type == backend_name: + if obj.device.index is None: + return backend_name + else: + return backend_name + ":" + str(obj.device.index) + + +def _cpu_deserialize(obj, location): + if location == "cpu": + return obj + + +def _mps_deserialize(obj, location): + if location.startswith("mps"): + return obj.mps() + + +def _meta_deserialize(obj, location): + if location == "meta": + return torch.UntypedStorage(obj.nbytes(), device="meta") + + +def _validate_device(location, backend_name): + """ + Check whether the device index of specified backend is valid + + In case of privateuse1 backend, your must first register a device_module for + privateuse1 using torch._register_device_module. Implement the following + methods in device_module like cuda: device_module._utils._get_device_index(location, True), + device_module.device_count(). + + Args: + location: string of device + backend_name: the backend name or the name of privateuse1, which can be renamed + + Returns: + device_index: int + """ + if not hasattr(torch, backend_name): + raise RuntimeError( + f"The {backend_name.upper()} device module is not registered. " + "If you are running on a CPU-only machine, " + "please use torch.load with map_location=torch.device('cpu') " + "to map your storages to the CPU." + ) + device_module = getattr(torch, backend_name) + if hasattr(device_module, "_utils") and hasattr( + device_module._utils, "_get_device_index" + ): + device_index = device_module._utils._get_device_index(location, True) + device = torch.device(backend_name, device_index) + else: + device = torch.device(location) + device_index = device.index if device.index else 0 + if hasattr(device_module, "is_available") and not device_module.is_available(): + raise RuntimeError( + f"Attempting to deserialize object on a {backend_name.upper()} " + f"device but torch.{backend_name}.is_available() is False. " + "If you are running on a CPU-only machine, " + "please use torch.load with map_location=torch.device('cpu') " + "to map your storages to the CPU." 
+ ) + if hasattr(device_module, "device_count"): + device_count = device_module.device_count() + if device_index >= device_count: + raise RuntimeError( + f"Attempting to deserialize object on {backend_name.upper()} device " + f"{device_index} but torch.{backend_name}.device_count() is {device_count}. " + "Please use torch.load with map_location to map your storages " + "to an existing device." + ) + return device + + +def validate_cuda_device(location): + return _validate_device(location, "cuda").index + + +def validate_hpu_device(location): + return _validate_device(location, "hpu").index + + +def _deserialize(backend_name, obj, location): + if backend_name == "privateuse1": + backend_name = torch._C._get_privateuse1_backend_name() + if location.startswith(backend_name): + device = _validate_device(location, backend_name) + return obj.to(device=device) + + +register_package(10, _cpu_tag, _cpu_deserialize) +register_package( + 20, + functools.partial(_backend_tag, "cuda"), + functools.partial(_deserialize, "cuda"), +) +register_package(21, _mps_tag, _mps_deserialize) +register_package(22, _meta_tag, _meta_deserialize) +register_package( + 23, + functools.partial(_backend_tag, "privateuse1"), + functools.partial(_deserialize, "privateuse1"), +) +register_package( + 24, + functools.partial(_backend_tag, "hpu"), + functools.partial(_deserialize, "hpu"), +) +register_package( + 25, + functools.partial(_backend_tag, "xpu"), + functools.partial(_deserialize, "xpu"), +) + + +def location_tag( + storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage], +): + for _, tagger, _ in _package_registry: + location = tagger(storage) + if location: + return location + raise RuntimeError( + "don't know how to determine data location of " + torch.typename(storage) + ) + + +def default_restore_location(storage, location): + """ + Restores `storage` using a deserializer function registered for the `location`. + + This function looks in the registry for deserializer functions that match the `location`. + If found, it attempts to use them, in priority order, to restore `storage` until one + returns a not `None` result. If no deserializer can be found in the registry, or all found fail + to bear a result, it raises a `RuntimeError`. + + Args: + storage (STORAGE): the storage object to restore + location (str): the location tag associated with the storage object + + Returns: + storage: Optional[STORAGE] + + Raises: + RuntimeError: If no deserializer matching `location` is found in the registry or if + all matching ones return `None`. 
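+
+    Example:
+        >>> # A minimal illustrative sketch, assuming a CPU storage: the
+        >>> # built-in "cpu" deserializer returns the object unchanged.
+        >>> s = torch.UntypedStorage(4)
+        >>> torch.serialization.default_restore_location(s, "cpu") is s
+        True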
+ """ + for _, _, fn in _package_registry: + result = fn(storage, location) + if result is not None: + return result + raise RuntimeError( + "don't know how to restore data location of " + + torch.typename(storage) + + " (tagged with " + + location + + ")" + ) + + +def normalize_storage_type(storage_type): + return getattr(torch, storage_type.__name__) + + +def storage_to_tensor_type(storage): + storage_type = type(storage) + module = _import_dotted_name(storage_type.__module__) + return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) + + +def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]: + return isinstance(name_or_buffer, (str, os.PathLike)) + + +T = TypeVar("T") + + +class _opener(Generic[T]): + def __init__(self, file_like: T) -> None: + self.file_like: T = file_like + + def __enter__(self): + return self.file_like + + def __exit__(self, *args): + pass + + +class _open_file(_opener[IO[bytes]]): + def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None: + super().__init__(open(name, mode)) + + def __exit__(self, *args): + self.file_like.close() + + +class _open_buffer_reader(_opener[IO[bytes]]): + def __init__(self, buffer: IO[bytes]) -> None: + super().__init__(buffer) + _check_seekable(buffer) + + +class _open_buffer_writer(_opener[IO[bytes]]): + def __exit__(self, *args): + self.file_like.flush() + + +def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]: + if _is_path(name_or_buffer): + return _open_file(name_or_buffer, mode) + else: + if "w" in mode: + return _open_buffer_writer(name_or_buffer) + elif "r" in mode: + return _open_buffer_reader(name_or_buffer) + else: + raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") + + +class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): + def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None: + super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) + + +class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]): + def __init__(self, name: str) -> None: + self.file_stream = None + self.name = name + try: + self.name.encode("ascii") + except UnicodeEncodeError: + # PyTorchFileWriter only supports ascii filename. + # For filenames with non-ascii characters, we rely on Python + # for writing out the file. 
+ self.file_stream = io.FileIO(self.name, mode="w") + super().__init__( + torch._C.PyTorchFileWriter( + self.file_stream, get_crc32_options(), _get_storage_alignment() + ) + ) + else: + super().__init__( + torch._C.PyTorchFileWriter( + self.name, get_crc32_options(), _get_storage_alignment() + ) + ) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + if self.file_stream is not None: + self.file_stream.close() + + +class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]): + def __init__(self, buffer: IO[bytes]) -> None: + if not callable(getattr(buffer, "write", None)): + msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" + if not hasattr(buffer, "write"): + raise AttributeError(msg) + raise TypeError(msg) + self.buffer = buffer + super().__init__( + torch._C.PyTorchFileWriter( + buffer, get_crc32_options(), _get_storage_alignment() + ) + ) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + self.buffer.flush() + + +def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener: + container: type[_opener] + if _is_path(name_or_buffer): + container = _open_zipfile_writer_file + else: + container = _open_zipfile_writer_buffer + return container(name_or_buffer) # type: ignore[arg-type] + + +def _is_compressed_file(f) -> bool: + compress_modules = ["gzip"] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + + +def _should_read_directly(f): + """ + Checks if f is a file that should be read directly. It should be read + directly if it is backed by a real file (has a fileno) and is not a + a compressed file (e.g. gzip) + """ + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + + +def _check_seekable(f) -> bool: + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = ( + str(e) + + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead." + ) + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + + +def _check_dill_version(pickle_module) -> None: + """Checks if using dill as the pickle module, and if so, checks if it is the correct version. + If dill version is lower than 0.3.1, a ValueError is raised. + + Args: + pickle_module: module used for pickling metadata and objects + + """ + if pickle_module is not None and pickle_module.__name__ == "dill": + required_dill_version = (0, 3, 1) + if not check_module_version_greater_or_equal( + pickle_module, required_dill_version, False + ): + raise ValueError( + ( + "'torch' supports dill >= {}, but you have dill {}." 
+ " Please upgrade dill or switch to 'pickle'" + ).format( + ".".join([str(num) for num in required_dill_version]), + pickle_module.__version__, + ) + ) + + +def _check_save_filelike(f): + if not _is_path(f) and not hasattr(f, "write"): + raise AttributeError( + "expected 'f' to be string, path, or a file-like object with " + "a 'write' attribute" + ) + + +def save( + obj: object, + f: FileLike, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, + _use_new_zipfile_serialization: bool = True, + _disable_byteorder_record: bool = False, +) -> None: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("makes cwd dirty") + >>> # Save to file + >>> x = torch.tensor([0, 1, 2, 3, 4]) + >>> torch.save(x, "tensor.pt") + >>> # Save to io.BytesIO buffer + >>> buffer = io.BytesIO() + >>> torch.save(x, buffer) + """ + torch._C._log_api_usage_once("torch.save") + _check_dill_version(pickle_module) + _check_save_filelike(f) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + if _use_new_zipfile_serialization: + with _open_zipfile_writer(f) as opened_zipfile: + _save( + obj, + opened_zipfile, + pickle_module, + pickle_protocol, + _disable_byteorder_record, + ) + return + else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError( + "Cannot use skip_data=True with _use_new_zipfile_serialization=False" + ) + with _open_file_like(f, "wb") as opened_file: + _legacy_save(obj, opened_file, pickle_module, pickle_protocol) + + +def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: + import torch.nn as nn + + serialized_container_types = {} + serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: dict[int, torch.dtype] = {} + + def persistent_id(obj: Any) -> Optional[tuple]: + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, type) and issubclass(obj, nn.Module): + if obj in serialized_container_types: + return None + serialized_container_types[obj] = True + source_file = source = None + try: + source_lines, _, source_file = get_source_lines_and_file(obj) + source = "".join(source_lines) + except ( + Exception + ): # saving the source is optional, so we can ignore any errors + warnings.warn( + "Couldn't retrieve source code for container of " + "type " + obj.__name__ + ". It won't be checked " + "for correctness upon loading." 
+ ) + return ("module", obj, source_file, source) + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + storage: torch.UntypedStorage + + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + dtype = obj.dtype + storage_numel = obj._size() + + elif isinstance(obj, torch.UntypedStorage): + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + dtype = torch.uint8 + storage_numel = storage.nbytes() + else: + raise TypeError(f"type not recognized: {type(obj)}") + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that " + "view the same data as different types" + ) + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + view_metadata: Optional[tuple[str, int, int]] + + # Offset is always 0, but we keep it for backwards compatibility + # with the old serialization format (which supported storage views) + offset = 0 + storage_key = str(storage._cdata) + location = location_tag(storage) + + # TODO: There's an issue here with FC. It might be impossible to + # solve, but it's worth noting. Imagine we save a list `[storage, + # tensor]`, where `tensor.storage()` is the same as `storage`, and + # `tensor.element_size() > 1`. Let's say that `tensor.dtype == + # torch.float`. The storage will be serialized with element size + # of 1, since we're choosing to serialize the first occurance of + # a duplicate storage. Since this legacy serialization format saves + # the numel of the storage, rather than nbytes directly, we'll be + # effectively saving nbytes in this case. We'll be able to load it + # and the tensor back up with no problems in _this_ and future + # versions of pytorch, but in older versions, here's the problem: + # the storage will be loaded up as a UntypedStorage, and then the + # FloatTensor will loaded and the UntypedStorage will be assigned to + # it. Since the storage dtype does not match the tensor dtype, this + # will cause an error. If we reverse the list, like `[tensor, + # storage]`, then we will save the `tensor.storage()` as a faked + # `FloatStorage`, and the saved size will be the correct + # dtype-specific numel count that old versions expect. `tensor` + # will be able to load up properly in old versions, pointing to + # a FloatStorage. However, `storage` is still being translated to + # a UntypedStorage, and it will try to resolve to the same + # FloatStorage that `tensor` contains. This will also cause an + # error. It doesn't seem like there's any way around this. + # Probably, we just cannot maintain FC for the legacy format if the + # saved list contains both a tensor and a storage that point to the + # same data. We should still be able to maintain FC for lists of + # just tensors, as long as all views share the same dtype as the + # tensor they are viewing. 
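+            # Hypothetical sketch of the pattern described above ("ckpt.pt" is a
+            # made-up path, shown only to illustrate the ordering issue):
+            #   t = torch.randn(4)  # float tensor, element_size() > 1
+            #   torch.save([t.untyped_storage(), t], "ckpt.pt",
+            #              _use_new_zipfile_serialization=False)
+            # Older PyTorch versions may fail to load such a checkpoint, because the
+            # shared storage is serialized once, as the first occurrence, with an
+            # element size of 1.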
+ + if storage_key not in serialized_storages: + serialized_storages[storage_key] = (storage, dtype) + is_view = storage._cdata != storage._cdata + if is_view: + view_metadata = (str(storage._cdata), offset, storage.nbytes()) + else: + view_metadata = None + + res = ( + "storage", + storage_type, + storage_key, + location, + storage_numel, + view_metadata, + ) + return res + return None + + sys_info = dict( + protocol_version=PROTOCOL_VERSION, + little_endian=sys.byteorder == "little", + type_sizes=dict( + short=SHORT_SIZE, + int=INT_SIZE, + long=LONG_SIZE, + ), + ) + + pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) + pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) + pickle_module.dump(sys_info, f, protocol=pickle_protocol) + + class PyTorchLegacyPickler(pickle_module.Pickler): + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol) + pickler.dump(obj) + + serialized_storage_keys = sorted(serialized_storages.keys()) + pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) + f.flush() + for key in serialized_storage_keys: + storage, dtype = serialized_storages[key] + storage._write_file( + f, _should_read_directly(f), True, torch._utils._element_size(dtype) + ) + + +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): + serialized_storages = {} + id_map: dict[int, str] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: dict[int, torch.dtype] = {} + + def persistent_id(obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. 
If storage is + # not allocated, don't perform this check + if str(storage.device) != "meta" and storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that " + "view the same data as different types" + ) + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) + serialized_storages[storage_key] = storage + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + + class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined] + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchPickler(data_buf, protocol=pickle_protocol) + pickler.dump(obj) + data_value = data_buf.getvalue() + zip_file.write_record("data.pkl", data_value, len(data_value)) + # .format_version is used to track + # 1. version 1 represents the order of storages being changed from + # lexicographical based on keys to numerically ordered based on keys + # 2. version 2 represents including storage_alignment as a record + # within the zipfile + zip_file.write_record(".format_version", "1", len("1")) + storage_alignment = str(_get_storage_alignment()) + zip_file.write_record( + ".storage_alignment", storage_alignment, len(storage_alignment) + ) + + # Write byte order marker + if not _disable_byteorder_record: + if sys.byteorder not in ["little", "big"]: + raise ValueError("Unknown endianness type: " + sys.byteorder) + + zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder)) + + # Write each tensor to a file named tensor/the_tensor_key in the zip archive + for key in serialized_storages.keys(): + name = f"data/{key}" + storage = serialized_storages[key] + num_bytes = storage.nbytes() + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + from torch.utils.serialization import config + + if ( + config.save.use_pinned_memory_for_d2h + and ( + acc := torch.accelerator.current_accelerator( + check_available=True + ) + ) + is not None + and acc.type == storage.device.type + ): + new_storage = torch.empty( + num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True + ).untyped_storage() + new_storage.copy_(storage) + torch.accelerator.current_stream(storage.device.index).synchronize() + storage = new_storage + else: + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) + + +def load( + f: FileLike, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: Optional[bool] = None, + mmap: Optional[bool] = None, + **pickle_load_args: Any, +) -> Any: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. 
`>> # xdoctest: +SKIP("undefined filepaths") + >>> torch.load("tensors.pt", weights_only=True) + # Load all tensors onto the CPU + >>> torch.load( + ... "tensors.pt", + ... map_location=torch.device("cpu"), + ... weights_only=True, + ... ) + # Load all tensors onto the CPU, using a function + >>> torch.load( + ... "tensors.pt", + ... map_location=lambda storage, loc: storage, + ... weights_only=True, + ... ) + # Load all tensors onto GPU 1 + >>> torch.load( + ... "tensors.pt", + ... map_location=lambda storage, loc: storage.cuda(1), + ... weights_only=True, + ... ) # type: ignore[attr-defined] + # Map tensors from GPU 1 to GPU 0 + >>> torch.load( + ... "tensors.pt", + ... map_location={"cuda:1": "cuda:0"}, + ... weights_only=True, + ... ) + # Load tensor from io.BytesIO object + # Loading from a buffer setting weights_only=False, warning this can be unsafe + >>> with open("tensor.pt", "rb") as f: + ... buffer = io.BytesIO(f.read()) + >>> torch.load(buffer, weights_only=False) + # Load a module with 'ascii' encoding for unpickling + # Loading from a module setting weights_only=False, warning this can be unsafe + >>> torch.load("module.pt", encoding="ascii", weights_only=False) + """ + torch._C._log_api_usage_once("torch.load") + DOCS_MESSAGE = ( + "\n\nCheck the documentation of torch.load to learn more about types accepted by default with " + "weights_only https://pytorch.org/docs/stable/generated/torch.load.html." + ) + + def _get_wo_message(message: str) -> str: + unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default." + has_unsafe_global = re.search(unsafe_global_pattern, message) is not None + blocklist_pattern = r"whose module (\S+) is blocked" + has_blocklist = re.search(blocklist_pattern, message) is not None + import_pattern = r"(\S+) must be (\S+) to load" + has_import = re.search(import_pattern, message) is not None + if has_unsafe_global: + updated_message = ( + "Weights only load failed. This file can still be loaded, to do so you have two options, " + "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. " + f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check " + "the recommended steps in the following error message.\n\tWeightsUnpickler error: " + + message + ) + else: + if has_import: + return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n" + else: + updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n" + if not has_blocklist: + updated_message += ( + "Please file an issue with the following so that we can make " + "`weights_only=True` compatible with your use case: WeightsUnpickler error: " + ) + updated_message += message + return updated_message + DOCS_MESSAGE + + weights_only_not_set = weights_only is None + + if weights_only_not_set: + weights_only = _default_to_weights_only(pickle_module) + + true_values = ["1", "y", "yes", "true"] + # Add ability to force safe only or non-safe weight loads via environment variables + force_weights_only_load = ( + os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values + ) + force_no_weights_only_load = ( + os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values + ) + + if force_weights_only_load and force_no_weights_only_load: + raise RuntimeError( + "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` " + "should be set, but both were set." 
+ ) + elif force_weights_only_load: + weights_only = True + elif force_no_weights_only_load: + # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only + if weights_only_not_set: + warnings.warn( + "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the" + "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.", + UserWarning, + stacklevel=2, + ) + weights_only = False + + if weights_only: + if pickle_module is not None: + raise RuntimeError( + "Can not safely load weights when explicit pickle_module is specified" + ) + else: + if pickle_module is None: + pickle_module = pickle + + # make flipping default BC-compatible + if mmap is None: + from torch.utils.serialization import config + + mmap = config.load.mmap + + _check_dill_version(pickle_module) + + if "encoding" not in pickle_load_args.keys(): + pickle_load_args["encoding"] = "utf-8" + + with _open_file_like(f, "rb") as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to torch.jit.load, we need to + # reset back to the original position. + orig_position = opened_file.tell() + overall_storage = None + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): + warnings.warn( + "'torch.load' received a zip file that looks like a TorchScript archive" + " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" + " silence this warning)", + UserWarning, + ) + if weights_only: + raise RuntimeError( + "Cannot use ``weights_only=True`` with TorchScript archives passed to " + "``torch.load``. " + UNSAFE_MESSAGE + ) + opened_file.seek(orig_position) + return torch.jit.load(opened_file, map_location=map_location) + if mmap: + if not _is_path(f): + raise ValueError( + "f must be a file path in order to use the mmap argument" + ) + size = os.path.getsize(f) + if not IS_WINDOWS: + shared = get_default_mmap_options() == MAP_SHARED + else: + shared = False + overall_storage = torch.UntypedStorage.from_file( + os.fspath(f), shared, size + ) + if weights_only: + try: + return _load( + opened_zipfile, + map_location, + _weights_only_unpickler, + overall_storage=overall_storage, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None + return _load( + opened_zipfile, + map_location, + pickle_module, + overall_storage=overall_storage, + **pickle_load_args, + ) + if mmap: + f_name = "" if not isinstance(f, str) else f"{f}, " + raise RuntimeError( + "mmap can only be used with files saved with " + f"`torch.save({f_name}_use_new_zipfile_serialization=True), " + "please torch.save your checkpoint with this option in order to use mmap." 
+ ) + if weights_only: + try: + return _legacy_load( + opened_file, + map_location, + _weights_only_unpickler, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None + return _legacy_load( + opened_file, map_location, pickle_module, **pickle_load_args + ) + + +# Register pickling support for layout instances such as +# torch.sparse_coo, etc +def _get_layout(name): + """Get layout extension object from its string representation.""" + cache = _get_layout.cache # type: ignore[attr-defined] + if not cache: + for v in torch.__dict__.values(): + if isinstance(v, torch.layout): + cache[str(v)] = v + return cache[name] + + +# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 +_get_layout.cache = {} # type: ignore[attr-defined] +copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) + + +def _legacy_load(f, map_location, pickle_module, **pickle_load_args): + deserialized_objects: dict[int, Any] = {} + + restore_location = _get_restore_location(map_location) + + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: + pass + return super().find_class(mod_name, name) + + def _check_container_source(container_type, source_file, original_source): + try: + current_source = "".join(get_source_lines_and_file(container_type)[0]) + except Exception: # saving the source is optional, so we can ignore any errors + warnings.warn( + "Couldn't retrieve source code for container of " + "type " + container_type.__name__ + ". It won't be checked " + "for correctness upon loading." + ) + return + if original_source != current_source: + if container_type.dump_patches: + file_name = container_type.__name__ + ".patch" + diff = difflib.unified_diff( + current_source.split("\n"), + original_source.split("\n"), + source_file, + source_file, + lineterm="", + ) + lines = "\n".join(diff) + try: + with open(file_name, "a+") as f: + file_size = f.seek(0, 2) + f.seek(0) + if file_size == 0: + f.write(lines) + elif file_size != len(lines) or f.read() != lines: + raise OSError + msg = ( + "Saved a reverse patch to " + file_name + ". " + "Run `patch -p0 < " + file_name + "` to revert your " + "changes." + ) + except OSError: + msg = ( + "Tried to save a patch, but couldn't create a " + "writable file " + file_name + ". Make sure it " + "doesn't exist and your working directory is " + "writable." + ) + else: + msg = ( + "you can retrieve the original source code by " + "accessing the object's source attribute or set " + "`torch.nn.Module.dump_patches = True` and use the " + "patch tool to revert the changes." + ) + msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}" + warnings.warn(msg, SourceChangeWarning) + + def legacy_load(f): + deserialized_objects: dict[int, Any] = {} + + def persistent_load(saved_id): + if isinstance(saved_id, tuple): + # Ignore containers that don't have any sources saved + if all(saved_id[1:]): + _check_container_source(*saved_id) + return saved_id[0] + return deserialized_objects[int(saved_id)] + + with ( + closing( + tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT) + ) as tar, + mkdtemp() as tmpdir, + ): + if pickle_module is _weights_only_unpickler: + raise RuntimeError( + "Cannot use ``weights_only=True`` with files saved in the " + "legacy .tar format. 
" + UNSAFE_MESSAGE + ) + tar.extract("storages", path=tmpdir) + with open(os.path.join(tmpdir, "storages"), "rb", 0) as f: + num_storages = pickle_module.load(f, **pickle_load_args) + for _ in range(num_storages): + args = pickle_module.load(f, **pickle_load_args) + key, location, storage_type = args + dtype = storage_type._dtype + obj = cast(Storage, torch.UntypedStorage)._new_with_file( + f, torch._utils._element_size(dtype) + ) + obj = restore_location(obj, location) + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[key] = torch.storage.TypedStorage( + wrap_storage=obj, dtype=dtype, _internal=True + ) + + storage_views = pickle_module.load(f, **pickle_load_args) + for target_cdata, root_cdata, offset, numel in storage_views: + root = deserialized_objects[root_cdata] + element_size = torch._utils._element_size(root.dtype) + offset_bytes = offset * element_size + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[target_cdata] = torch.storage.TypedStorage( + wrap_storage=root._untyped_storage[ + offset_bytes : offset_bytes + numel * element_size + ], + dtype=root.dtype, + _internal=True, + ) + + tar.extract("tensors", path=tmpdir) + with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f: + num_tensors = pickle_module.load(f, **pickle_load_args) + for _ in range(num_tensors): + args = pickle_module.load(f, **pickle_load_args) + key, storage_id, _original_tensor_type = args + storage = deserialized_objects[storage_id] + (ndim,) = struct.unpack(" str: + # When using encoding='bytes' in Py3, some **internal** keys stored as + # strings in Py2 are loaded as bytes. This function decodes them with + # ascii encoding, one that Py3 uses by default. + # + # NOTE: This should only be used on internal keys (e.g., `typename` and + # `location` in `persistent_load` below! 
+ if isinstance(bytes_str, bytes): + return bytes_str.decode("ascii") + return bytes_str + + +def _get_restore_location(map_location): + if map_location is None: + restore_location = default_restore_location + elif isinstance(map_location, dict): + + def restore_location(storage, location): + location = map_location.get(location, location) + return default_restore_location(storage, location) + + elif isinstance(map_location, (str, bytes)): + + def restore_location(storage, location): + return default_restore_location(storage, map_location) + + elif isinstance(map_location, torch.device): + + def restore_location(storage, location): + return default_restore_location(storage, str(map_location)) + + else: + + def restore_location(storage, location): + result = map_location(storage, location) + if result is None: + result = default_restore_location(storage, location) + return result + + return restore_location + + +class StorageType: + def __init__(self, name): + self._dtype = _get_dtype_from_pickle_storage_type(name) + + @property + def dtype(self): + return self._dtype + + def __str__(self): + return f"StorageType(dtype={self.dtype})" + + +def _load( + zip_file, + map_location, + pickle_module, + pickle_file="data.pkl", + overall_storage=None, + **pickle_load_args, +): + restore_location = _get_restore_location(map_location) + + loaded_storages = {} + + can_calculate_storage_offsets = False + if zip_file.has_record(".format_version"): + version = zip_file.get_record(".format_version") + can_calculate_storage_offsets = version >= b"1" + + # check if byteswapping is needed + byteordername = "byteorder" + byteorderdata = None + if zip_file.has_record(byteordername): + byteorderdata = zip_file.get_record(byteordername) + if byteorderdata not in [b"little", b"big"]: + raise ValueError("Unknown endianness type: " + byteorderdata.decode()) + elif ( + get_default_load_endianness() == LoadEndianness.LITTLE + or get_default_load_endianness() is None + ): + byteorderdata = b"little" + elif get_default_load_endianness() == LoadEndianness.BIG: + byteorderdata = b"big" + elif get_default_load_endianness() == LoadEndianness.NATIVE: + pass + else: + raise ValueError("Invalid load endianness type") + + storage_alignment = 64 + if zip_file.has_record(".storage_alignment"): + storage_alignment = int(zip_file.get_record(".storage_alignment")) + + if ( + not zip_file.has_record(byteordername) + and get_default_load_endianness() is None + and sys.byteorder == "big" + ): + # Default behaviour was changed + # See https://github.com/pytorch/pytorch/issues/101688 + warnings.warn( + "The default load endianness for checkpoints without a byteorder mark " + "on big endian machines was changed from 'native' to 'little' endian, " + "to avoid this behavior please use " + "torch.serialization.set_default_load_endianness to set " + "the desired default load endianness", + UserWarning, + ) + + from torch.utils.serialization import config + + calculate_storage_offsets = config.load.calculate_storage_offsets + run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1" + current_offset = None + # constants from miniz.h/miniz.c + data_descripter_size64 = 24 + data_descripter_size32 = 16 + mz_uint32_max = 0xFFFFFFFF + offsets: dict[str, int] = dict() + + def _get_offset(key, name, numel): + """ + Return the offset of the storage associated with key with record name `name` and size numel. + It is expected that the zipfile header of this storage starts at current_offset. 
+ + WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular, + the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept + in sync with that of miniz! + + After reading a storage of size numel that starts at storage_offset + if it is the first time that storage was read, update nonlocal variable + current_offset to the start of the next zipfile header by incrementing + it by numel and the data descriptor size. + """ + nonlocal current_offset, offsets + if name in offsets: + storage_offset = offsets[name] + return storage_offset + + if current_offset is None: + assert key == "0" + current_offset = zip_file.get_record_offset(name) + local_header_offset = zip_file.get_record_header_offset(name) + storage_offset = current_offset + else: + storage_offset = zip_file.get_record_offset_no_read( + current_offset, name, numel, storage_alignment + ) + local_header_offset = current_offset + + # This is only actually needed for storages that have typed_storage._data_ptr() == 0 + # after being read. Otherwise persistent_load would never "re-call" load_tensor + # for a given key. + offsets[name] = storage_offset + + # Increment current_offset of offset where next zipfile header starts + current_offset = storage_offset + numel + # add size of data descriptor after payload + if numel > 0: + if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max: + current_offset += data_descripter_size64 + else: + current_offset += data_descripter_size32 + + return storage_offset + + def load_tensor(dtype, numel, key, location): + name = f"data/{key}" + if torch._guards.detect_fake_mode(None) is not None: + nbytes = numel * torch._utils._element_size(dtype) + storage = torch.UntypedStorage(nbytes, device="meta") + storage._checkpoint_offset = zip_file.get_record_offset(name) + elif _serialization_tls.skip_data: + nbytes = numel * torch._utils._element_size(dtype) + storage = torch.UntypedStorage(nbytes) + elif overall_storage is not None: + if can_calculate_storage_offsets and calculate_storage_offsets: + storage_offset = _get_offset(key, name, numel) + if run_debug_asserts: + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) + else: + storage_offset = zip_file.get_record_offset(name) + storage = overall_storage[storage_offset : storage_offset + numel] + else: + if can_calculate_storage_offsets and run_debug_asserts: + # This is debug code that we use to test the validity of + # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI + storage_offset = _get_offset(key, name, numel) + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) + storage = ( + zip_file.get_storage_from_record(name, numel, torch.UntypedStorage) + ._typed_storage() + ._untyped_storage + ) + # swap here if byteswapping is needed + if byteorderdata is not None: + if byteorderdata.decode() != sys.byteorder: + storage.byteswap(dtype) + + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + + if torch._guards.detect_fake_mode(None) is None: + 
wrap_storage = restore_location(storage, location) + else: + storage._fake_device = location + wrap_storage = storage + + typed_storage = torch.storage.TypedStorage( + wrap_storage=wrap_storage, + dtype=dtype, + _internal=True, + ) + + if typed_storage._data_ptr() != 0: + loaded_storages[key] = typed_storage + + return typed_storage + + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + assert typename == "storage", ( + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + ) + storage_type, key, location, numel = data + if storage_type is torch.UntypedStorage: + dtype = torch.uint8 + else: + dtype = storage_type.dtype + + if key in loaded_storages: + typed_storage = loaded_storages[key] + else: + nbytes = numel * torch._utils._element_size(dtype) + typed_storage = load_tensor( + dtype, nbytes, key, _maybe_decode_ascii(location) + ) + + return typed_storage + + load_module_mapping: dict[str, str] = { + # See https://github.com/pytorch/pytorch/pull/51633 + "torch.tensor": "torch._tensor" + } + + # Need to subclass Unpickler instead of directly monkey-patching the find_class method + # because it's marked readonly in pickle. + # The type: ignore is because mypy can't statically determine the type of this class. + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 + # Lets us override the imports that pickle uses when unpickling an object. + # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. + def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: + pass + mod_name = load_module_mapping.get(mod_name, mod_name) + return super().find_class(mod_name, name) + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(zip_file.get_record(pickle_file)) + + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + # Needed for tensors where storage device and rebuild tensor device are + # not connected (wrapper subclasses and tensors rebuilt using numpy) + global _serialization_tls + _serialization_tls.map_location = map_location + result = unpickler.load() + _serialization_tls.map_location = None + + torch._utils._validate_loaded_sparse_tensors() + torch._C._log_api_usage_metadata( + "torch.load.metadata", {"serialization_id": zip_file.serialization_id()} + ) + return result + + +def _is_torchscript_zip(zip_file): + return "constants.pkl" in zip_file.get_all_records() diff --git a/phivenv/Lib/site-packages/torch/storage.py b/phivenv/Lib/site-packages/torch/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..d49e3dcce0114397aef2de85569b5a4197152ba1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/storage.py @@ -0,0 +1,1551 @@ +# mypy: allow-untyped-defs + +from __future__ import annotations + +import collections +import copy +import functools +import io +import threading +import warnings +from typing import Any, cast, Optional as _Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Self + +import torch +from torch._utils import _to, _type +from torch.types import _bool, _int, Storage + + +if TYPE_CHECKING: + from torch._prims_common import DeviceLikeType + + +__all__ = 
["TypedStorage", "UntypedStorage"] + + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + + +_share_memory_lock = threading.Lock() +_share_memory_map: dict[int, threading.RLock] = {} + +T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]") + + +class _StorageBase: + _cdata: Any + is_sparse: _bool = False + is_sparse_csr: _bool = False + device: torch.device + # Used when + # (1) stashing FakeTensor device onto storage in torch.serialization.skip_data + # (2) stashing device onto storage to propagate to FakeTensor when torch.load under FakeTensorMode + _fake_device: _Optional[torch.device] = None + # Used when loading with FakeTensorMode to give information about offset of storage in torch.saved-file + _checkpoint_offset: _Optional[int] = None + + def __init__(self, *args, **kwargs): + pass + + def __len__(self) -> _int: + raise NotImplementedError + + def __getitem__(self, idx): + raise NotImplementedError + + def __setitem__(self, *args, **kwargs): + raise NotImplementedError + + def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: + raise NotImplementedError + + def new(self) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def nbytes(self) -> _int: + raise NotImplementedError + + def size(self) -> _int: + return self.nbytes() + + def type( + self, dtype: _Optional[str] = None, non_blocking: _bool = False + ) -> Union[_StorageBase, TypedStorage]: + return _type(self, dtype, non_blocking) + + def cuda( + self, device=None, non_blocking=False + ) -> Union[_StorageBase, TypedStorage]: + """Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination GPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = torch.device("cuda", device) if device else torch.device("cuda") + return self.to(device=device2, non_blocking=non_blocking) + + def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: + """Returns a copy of this object in HPU memory. + + If this object is already in HPU memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination HPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. 
+ """ + device2 = torch.device("hpu", device) if device else torch.device("hpu") + return self.to(device=device2, non_blocking=non_blocking) + + def element_size(self) -> _int: + raise NotImplementedError + + def get_device(self) -> _int: + return self.device.index + + def data_ptr(self) -> _int: + raise NotImplementedError + + def resizable(self) -> _bool: + raise NotImplementedError + + # Defined in torch/csrc/generic/StorageSharing.cpp + def _share_filename_cpu_(self, *args, **kwargs): + raise NotImplementedError + + def _share_fd_cpu_(self, *args, **kwargs): + raise NotImplementedError + + @classmethod + def _new_using_filename_cpu(cls, size: _int) -> Self: + raise NotImplementedError + + @classmethod + def _new_using_fd_cpu(cls, size: _int) -> Self: + raise NotImplementedError + + @classmethod + def from_buffer(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + @classmethod + def _new_shared_filename_cpu( + cls, + manager, + obj, + size, + *, + device=None, + dtype=None, + ) -> Self: + raise NotImplementedError + + @classmethod + def _release_ipc_counter(cls, *args, device=None, **kwargs): + return cls._release_ipc_counter_cuda(*args, **kwargs) + + @classmethod + def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + @classmethod + def _new_with_weak_ptr(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def _write_file(self, *args, **kwargs): + raise NotImplementedError + + def resize_(self, size: _int): + raise NotImplementedError + + def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def _set_from_file(self, *args, **kwargs): + raise NotImplementedError + + def _set_cdata(self, *args, **kwargs): + raise NotImplementedError + + def _share_cuda_(self, *args, **kwargs): + raise NotImplementedError + + def is_shared(self) -> _bool: + raise NotImplementedError + + @classmethod + def _new_shared_cuda(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + def _shared_incref(self, *args, **kwargs): + raise NotImplementedError + + @classmethod + def _free_weak_ref(cls, *args, **kwargs): + raise NotImplementedError + + @property + def is_cuda(self): + raise NotImplementedError + + @property + def is_hpu(self): + raise NotImplementedError + + @classmethod + def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + @classmethod + def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + raise NotImplementedError + + def _byteswap(self, *args, **kwargs): + raise NotImplementedError + + def _get_filename(self, *args, **kwargs) -> _Optional[str]: + raise NotImplementedError + + def __repr__(self): + info_str = f"[{torch.typename(self)}(device={self.device}) of size {len(self)}]" + if self.device.type == "meta": + return "...\n" + info_str + data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) + return data_str + "\n" + info_str + + def __iter__(self): + return iter(self[i] for i in range(self.size())) + + def __copy__(self): + return self.clone() + + def __deepcopy__(self, memo): + memo = memo.setdefault("torch", {}) + if self._cdata in memo: + return memo[self._cdata] + new_storage = self.clone() + memo[self._cdata] = new_storage + return new_storage + + def __reduce__(self): + b = io.BytesIO() + torch.save(self, b, _use_new_zipfile_serialization=False) + return (_load_from_bytes, 
(b.getvalue(),)) + + def __sizeof__(self): + return super().__sizeof__() + self.size() + + def clone(self): + """Return a copy of this storage.""" + return type(self)(self.nbytes(), device=self.device).copy_(self) + + def tolist(self): + """Return a list containing the elements of this storage.""" + return list(self) + + def cpu(self): + """Return a CPU copy of this storage if it's not already on the CPU.""" + if self.device.type != "cpu": + return torch.UntypedStorage(self.size()).copy_(self, False) + return self + + def mps(self): + """Return a MPS copy of this storage if it's not already on the MPS.""" + if self.device.type != "mps": + return torch.UntypedStorage(self.size(), device="mps").copy_(self, False) + return self + + def _to(self, dtype): + if not isinstance(dtype, torch.dtype): + raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") + storage = ( + torch.tensor([], dtype=torch.uint8, device=self.device) + .set_(cast(Storage, self)) + .to(dtype) + ._typed_storage() + ) + if storage.data_ptr() == self.data_ptr(): + storage = storage.clone() + return storage + + def to(self, *, device: DeviceLikeType, non_blocking: _bool = False): + if not isinstance(device, torch.device): + device = torch.device(device) + return _to(self, device, non_blocking) + + def double(self): + """Casts this storage to double type.""" + return self._to(torch.double) + + def float(self): + """Casts this storage to float type.""" + return self._to(torch.float) + + def half(self): + """Casts this storage to half type.""" + return self._to(torch.half) + + def long(self): + """Casts this storage to long type.""" + return self._to(torch.long) + + def int(self): + """Casts this storage to int type.""" + return self._to(torch.int) + + def short(self): + """Casts this storage to short type.""" + return self._to(torch.short) + + def char(self): + """Casts this storage to char type.""" + return self._to(torch.int8) + + def byte(self): + """Casts this storage to byte type.""" + return self._to(torch.uint8) + + def bool(self): + """Casts this storage to bool type.""" + return self._to(torch.bool) + + def bfloat16(self): + """Casts this storage to bfloat16 type.""" + return self._to(torch.bfloat16) + + def complex_double(self): + """Casts this storage to complex double type.""" + return self._to(torch.cdouble) + + def complex_float(self): + """Casts this storage to complex float type.""" + return self._to(torch.cfloat) + + def float8_e5m2(self): + """Casts this storage to float8_e5m2 type""" + return self._to(torch.float8_e5m2) + + def float8_e4m3fn(self): + """Casts this storage to float8_e4m3fn type""" + return self._to(torch.float8_e4m3fn) + + def float8_e5m2fnuz(self): + """Casts this storage to float8_e5m2fnuz type""" + return self._to(torch.float8_e5m2fnuz) + + def float8_e4m3fnuz(self): + """Casts this storage to float8_e4m3fnuz type""" + return self._to(torch.float8_e4m3fnuz) + + def is_pinned(self, device: Union[str, torch.device] = "cuda"): + r"""Determine whether the CPU storage is already pinned on device. + + Args: + device (str or torch.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A boolean variable. + """ + return ( + torch.tensor([], dtype=torch.uint8, device=self.device) + .set_(cast(Storage, self)) + .is_pinned(device) + ) + + def pin_memory(self, device: Union[str, torch.device] = "cuda"): + r"""Copy the CPU storage to pinned memory, if it's not already pinned. 
+ + Args: + device (str or torch.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A pinned CPU storage. + """ + if self.device.type != "cpu": + raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned") + + pinned_tensor = ( + torch.tensor([], dtype=torch.uint8, device=self.device) + .set_(cast(Storage, self)) + .pin_memory(device) + ) + return pinned_tensor.untyped_storage() + + def share_memory_(self): + """See :meth:`torch.UntypedStorage.share_memory_`""" + from torch.multiprocessing import get_sharing_strategy + + if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + pass # CUDA or PrivateUse1 doesn't use POSIX shared memory + elif get_sharing_strategy() == "file_system": + self._share_filename_cpu_() + else: + self._share_fd_cpu_() + return self + + @classmethod + def _new_shared(cls, size, *, device="cpu"): + """Create a new storage in shared memory with the same data type.""" + from torch.multiprocessing import get_sharing_strategy + + device = torch.device(device) + if device.type in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]: + return cls(size, device=device) + elif get_sharing_strategy() == "file_system": + return cls._new_using_filename_cpu(size) + else: + return cls._new_using_fd_cpu(size) + + def untyped(self): + return self + + def byteswap(self, dtype): + """Swap bytes in underlying data.""" + elem_size = torch._utils._element_size(dtype) + # for complex types, don't swap first and second numbers + if dtype.is_complex: + elem_size = max(int(elem_size / 2), 1) + self._byteswap(elem_size) + + +def _share_memory_lock_protected(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + to_free = None + to_wait = None + with _share_memory_lock: + key = self._cdata + if key in _share_memory_map: + to_wait = _share_memory_map[key] + else: + _share_memory_map[key] = threading.RLock() + _share_memory_map[key].acquire() + to_free = key + + # If we're already in the process of sharing the storage, wait + # for it to be done. + if to_wait is not None: + with to_wait: + pass + + try: + return fn(self, *args, **kwargs) + finally: + # If we acquired the storage lock here and we're done working on it + # we can now release it and free the entry. + if to_free is not None: + # Ensure that the cdata from the storage didn't change and only + # the data_ptr did. + assert self._cdata == to_free + with _share_memory_lock: + _share_memory_map[to_free].release() + del _share_memory_map[to_free] + + return wrapper + + +class UntypedStorage(torch._C.StorageBase, _StorageBase): + def __getitem__(self, *args, **kwargs): + if self.device.type == "meta": + raise NotImplementedError("Not available for 'meta' device type") + return super().__getitem__(*args, **kwargs) + + @property + def is_cuda(self): + return self.device.type == "cuda" + + @property + def is_hpu(self): + return self.device.type == "hpu" + + @property + def filename(self) -> _Optional[str]: + """Returns the file name associated with this storage. + + The file name will be a string if the storage is on CPU and was created via + :meth:`~torch.from_file()` with ``shared`` as ``True``. This attribute is ``None`` otherwise. + """ + return self._get_filename() + + @_share_memory_lock_protected + def share_memory_(self, *args, **kwargs): + """ + Moves the storage to shared memory. 
+ + This is a no-op for storages already in shared memory and for CUDA + storages, which do not need to be moved for sharing across processes. + Storages in shared memory cannot be resized. + + Note that to mitigate issues like `this `_ + it is thread safe to call this function from multiple threads on the same object. + It is NOT thread safe though to call any other function on self without proper + synchronization. Please see :doc:`/notes/multiprocessing` for more details. + + .. note:: + When all references to a storage in shared memory are deleted, the associated shared memory + object will also be deleted. PyTorch has a special cleanup process to ensure that this happens + even if the current process exits unexpectedly. + + It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True`` + + #. ``share_memory_`` uses `shm_open(3) `_ to create a + POSIX shared memory object while :meth:`from_file` uses + `open(2) `_ to open the filename passed by the user. + #. Both use an `mmap(2) call `_ with ``MAP_SHARED`` + to map the file/object into the current virtual address space + #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory + object is freed when no process has the object open. ``torch.from_file(shared=True)`` does not unlink the + file. This file is persistent and will remain until it is deleted by the user. + + Returns: + ``self`` + """ + return super().share_memory_(*args, **kwargs) + + @_share_memory_lock_protected + def _share_fd_cpu_(self, *args, **kwargs): + return super()._share_fd_cpu_(*args, **kwargs) + + @_share_memory_lock_protected + def _share_filename_cpu_(self, *args, **kwargs): + return super()._share_filename_cpu_(*args, **kwargs) + + +def _load_from_bytes(b): + return torch.load(io.BytesIO(b), weights_only=False) + + +@functools.cache +def _new_dtypes(): + # These are dtypes serialized as UntypedStorage unlike those in + # _dtype_to_storage_type_map + return { + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, + torch.float8_e8m0fnu, + torch.float4_e2m1fn_x2, + torch.bits8, + torch.bits16, + torch.bits1x8, + torch.bits2x4, + torch.bits4x2, + torch.complex32, + torch.uint16, + torch.uint32, + torch.uint64, + } + + +@functools.cache +def _dtype_to_storage_type_map(): + # NOTE: We should no longer add dtypes to this map. This map + # is only used for BC/FC with older PyTorch versions. Going forward, + # new dtypes of TypedStorage should not translate to a legacy + # Storage class. 
Instead, new dtypes of TypedStorage should + # be serialized as an UntypedStorage paired with a torch.dtype + return { + torch.double: "DoubleStorage", + torch.float: "FloatStorage", + torch.half: "HalfStorage", + torch.long: "LongStorage", + torch.int: "IntStorage", + torch.int16: "ShortStorage", + torch.int8: "CharStorage", + torch.uint8: "ByteStorage", + torch.bool: "BoolStorage", + torch.bfloat16: "BFloat16Storage", + torch.cdouble: "ComplexDoubleStorage", + torch.cfloat: "ComplexFloatStorage", + torch.qint8: "QInt8Storage", + torch.qint32: "QInt32Storage", + torch.quint8: "QUInt8Storage", + torch.quint4x2: "QUInt4x2Storage", + torch.quint2x4: "QUInt2x4Storage", + } + + +@functools.cache +def _storage_type_to_dtype_map(): + dtype_map = {val: key for key, val in _dtype_to_storage_type_map().items()} + return dtype_map + + +def _get_storage_from_sequence(sequence, dtype, device): + if dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + interpret_dtypes = { + torch.quint8: torch.uint8, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.qint32: torch.int32, + torch.qint8: torch.int8, + } + tmp_tensor = torch.tensor( + sequence, dtype=interpret_dtypes[dtype], device=device + ) + + else: + tmp_tensor = torch.tensor(sequence, dtype=dtype, device=device) + + return tmp_tensor._typed_storage()._untyped_storage + + +def _isint(x): + if HAS_NUMPY: + return isinstance(x, (int, np.integer)) + else: + return isinstance(x, int) + + +_always_warn_typed_storage_removal = False + + +def _get_always_warn_typed_storage_removal(): + return _always_warn_typed_storage_removal + + +def _set_always_warn_typed_storage_removal(always_warn): + global _always_warn_typed_storage_removal + assert isinstance(always_warn, bool) + _always_warn_typed_storage_removal = always_warn + + +def _warn_typed_storage_removal(stacklevel=2): + global _always_warn_typed_storage_removal + + def is_first_time(): + if not hasattr(_warn_typed_storage_removal, "has_warned"): + return True + else: + return not _warn_typed_storage_removal.__dict__["has_warned"] + + if _get_always_warn_typed_storage_removal() or is_first_time(): + message = ( + "TypedStorage is deprecated. It will be removed in the future and " + "UntypedStorage will be the only storage class. This should only matter " + "to you if you are using storages directly. To access UntypedStorage " + "directly, use tensor.untyped_storage() instead of tensor.storage()" + ) + warnings.warn(message, UserWarning, stacklevel=stacklevel + 1) + _warn_typed_storage_removal.__dict__["has_warned"] = True + + +def _reset_warn_typed_storage_removal(): + _warn_typed_storage_removal.__dict__["has_warned"] = False + + +def _get_device_from_module(module: str): + last_part = module.rsplit(".", 1)[-1] + if last_part in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]: + return last_part + else: + return "cpu" + + +class TypedStorage: + is_sparse: _bool = False + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None + + dtype: torch.dtype + + @property + def _dtype(self): + return self.dtype + + @property + def filename(self) -> _Optional[str]: + """Returns the file name associated with this storage if the storage was memory mapped from a file. 
+ or ``None`` if the storage was not created by memory mapping a file.""" + return self._untyped_storage.filename + + def fill_(self, value): + _warn_typed_storage_removal() + self._setitem(slice(0, self._size()), value) + return self + + def __new__( + cls, + *args, + wrap_storage=None, + dtype=None, + device=None, + _internal=False, + ): + if not _internal: + _warn_typed_storage_removal() + + if cls == torch.storage._LegacyStorage: + raise RuntimeError( + "Only child classes of _LegacyStorage can be instantiated" + ) + + if cls == TypedStorage: + return super().__new__(cls) + + else: + arg_error_msg = ( + f"{cls}.__new__ received an invalid combination " + f"of arguments. Expected one of:\n" + " * no arguments\n" + " * (int size)\n" + " * (Sequence data)\n" + " * (*, UntypedStorage wrap_storage)" + ) + + if device is not None: + raise RuntimeError( + arg_error_msg + "\nKeyword argument 'device' cannot be specified" + ) + + if dtype is not None: + raise RuntimeError( + arg_error_msg + "\nKeyword argument 'dtype' cannot be specified" + ) + + if wrap_storage is None: + if len(args) > 1: + raise RuntimeError( + arg_error_msg + "\nToo many positional arguments" + ) + + if ( + len(args) == 1 + and not _isint(args[0]) + and not isinstance(args[0], collections.abc.Sequence) + ): + raise TypeError( + arg_error_msg + + f"\nArgument type not recognized: {type(args[0])}" + ) + + return TypedStorage( + *args, + dtype=cls._dtype, + device=_get_device_from_module(cls.__module__), + _internal=True, + ) + + else: + if len(args) != 0: + raise RuntimeError( + arg_error_msg + + "\nNo positional arguments should be given when using " + "'wrap_storage'" + ) + + if not isinstance(wrap_storage, torch.UntypedStorage): + raise TypeError( + arg_error_msg + + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" + ) + + cls_device = _get_device_from_module(cls.__module__) + + if wrap_storage.device.type != cls_device: + raise RuntimeError( + arg_error_msg + + f"\nDevice of 'wrap_storage' must be {cls_device}" + f", but got {wrap_storage.device.type}" + ) + + return TypedStorage( + *args, + wrap_storage=wrap_storage, + dtype=cls.dtype, + _internal=True, + ) + + def __init__( + self, + *args, + device=None, + dtype=None, + wrap_storage=None, + _internal=False, + ): + if not _internal: + _warn_typed_storage_removal() + arg_error_msg = ( + "TypedStorage.__init__ received an invalid combination " + "of arguments. 
Expected one of:\n" + " * (*, torch.device device, torch.dtype dtype)\n" + " * (int size, *, torch.device device, torch.dtype dtype)\n" + " * (Sequence data, *, torch.device device, torch.dtype dtype)\n" + " * (*, UntypedStorage wrap_storage, torch.dtype dtype)" + ) + + if wrap_storage is not None: + if len(args) != 0: + raise RuntimeError( + arg_error_msg + + "\nNo positional arguments should be given when using " + "'wrap_storage'" + ) + + if dtype is None: + raise RuntimeError( + arg_error_msg + "\nArgument 'dtype' must be specified" + ) + + if not isinstance(dtype, torch.dtype): + raise TypeError( + arg_error_msg + + f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}" + ) + + if device is not None: + raise RuntimeError( + arg_error_msg + + "\nArgument 'device' should not be specified when 'wrap_storage' is given" + ) + + self.dtype = dtype + + if not isinstance(wrap_storage, torch.UntypedStorage): + raise TypeError( + arg_error_msg + + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" + ) + + self._untyped_storage = wrap_storage + + else: + self.dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu" if device is None else device) + + if self.dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + if device.type == "cuda": + raise RuntimeError( + "Cannot create CUDA storage with quantized dtype" + ) + + if len(args) == 0: + self._untyped_storage = torch.UntypedStorage(device=device) + + elif len(args) == 1: + if _isint(args[0]): + self._untyped_storage = torch.UntypedStorage( + int(args[0]) * self._element_size(), device=device + ) + elif isinstance(args[0], collections.abc.Sequence): + self._untyped_storage = _get_storage_from_sequence( + args[0], self.dtype, device + ) + else: + raise TypeError( + arg_error_msg + + f"\nArgument type not recognized: {type(args[0])}" + ) + + else: + raise RuntimeError(arg_error_msg + "\nToo many positional arguments") + + @property + def is_cuda(self): + _warn_typed_storage_removal() + return self._untyped_storage.device.type == "cuda" + + @property + def is_hpu(self): + _warn_typed_storage_removal() + return self._untyped_storage.device.type == "hpu" + + def untyped(self): + """Return the internal :class:`torch.UntypedStorage`.""" + _warn_typed_storage_removal() + return self._untyped_storage + + def _new_wrapped_storage(self, untyped_storage) -> Self: + assert type(untyped_storage) == torch.UntypedStorage + + if type(self) == TypedStorage: + return cast( + Self, + TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ), + ) + else: + return type(self)(wrap_storage=untyped_storage) + + def __len__(self): + _warn_typed_storage_removal() + return self._size() + + def _maybe_wrap_index(self, idx, is_stop=False): + if idx is None: + if is_stop: + return self._size() + else: + return 0 + + else: + if type(idx) != int: + raise TypeError(f"can't index a {type(self)} with {type(idx)}") + if is_stop: + if (idx > self._size()) or (idx < -self._size()): + raise IndexError( + f"index {idx} out of range for storage of size {self.size()}" + ) + if idx > 0: + return idx + else: + return idx % self._size() + else: + if (idx >= self._size()) or (idx < -self._size()): + raise IndexError( + f"index {idx} out of range for storage of size {self.size()}" + ) + return idx % self._size() + + def __setitem__(self, idx, value): + _warn_typed_storage_removal() + return self._setitem(idx, value) + + def _setitem(self, idx, value): + 
if not isinstance(idx, (int, slice)): + raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") + if torch.is_storage(value): + raise RuntimeError(f"cannot set item with value type {type(value)}") + if self.dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + interpret_dtypes = { + torch.quint8: torch.uint8, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.qint32: torch.int32, + torch.qint8: torch.int8, + } + tmp_dtype = interpret_dtypes[self.dtype] + tmp_tensor = torch.tensor( + [], dtype=tmp_dtype, device=self._untyped_storage.device + ) + tmp_tensor.set_( + TypedStorage( + wrap_storage=self._untyped_storage, dtype=tmp_dtype, _internal=True + ) + ) + else: + tmp_tensor = torch.tensor( + [], dtype=self.dtype, device=self._untyped_storage.device + ).set_(self) + + tmp_tensor[idx] = value + + def __getitem__(self, idx): + _warn_typed_storage_removal() + return self._getitem(idx) + + def _getitem(self, idx): + if self._untyped_storage.device.type == "meta": + raise NotImplementedError("Not available for 'meta' device type") + + # NOTE: Before TypedStorage existed, indexing with a slice used to be + # possible for Storage objects. However, it would return + # a storage view, which would be a hassle to implement in TypedStorage, + # so it was disabled + if isinstance(idx, slice): + raise RuntimeError( + "slices are only supported in UntypedStorage.__getitem__" + ) + elif not isinstance(idx, int): + raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") + + if self.dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + interpret_dtypes = { + torch.quint8: torch.uint8, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.qint32: torch.int32, + torch.qint8: torch.int8, + } + return TypedStorage( + wrap_storage=self._untyped_storage, + dtype=interpret_dtypes[self.dtype], + _internal=True, + )._getitem(idx) + + idx_wrapped = self._maybe_wrap_index(idx) + from torch._subclasses.fake_tensor import unset_fake_temporarily + + with unset_fake_temporarily(): + tmp_tensor = torch.tensor( + [], dtype=self.dtype, device=self._untyped_storage.device + ).set_(self) + return tmp_tensor[idx_wrapped].item() + + def copy_(self, source: T, non_blocking: _Optional[bool] = None): + _warn_typed_storage_removal() + if isinstance(source, TypedStorage): + self._untyped_storage.copy_(source._untyped_storage, non_blocking) + else: + self._untyped_storage.copy_(source, non_blocking) + return self + + def nbytes(self): + _warn_typed_storage_removal() + return self._nbytes() + + # For internal use only, to avoid deprecation warning + def _nbytes(self): + return self._untyped_storage.nbytes() + + def type( + self, + dtype: _Optional[str] = None, + non_blocking: bool = False, + ) -> Union[_StorageBase, TypedStorage, str]: + _warn_typed_storage_removal() + if dtype is None: + legacy_class = self._get_legacy_storage_class() + + if legacy_class is not None: + return legacy_class.__module__ + "." 
+ legacy_class.__name__ + + return ".".join([self.__module__, type(self).__name__]) + + else: + return self._untyped_storage.type(dtype, non_blocking) + + def cuda(self, device=None, non_blocking=False) -> Self: + _warn_typed_storage_removal() + if self.dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + raise RuntimeError("Cannot create CUDA storage with quantized dtype") + cuda_storage = self._untyped_storage.cuda(device, non_blocking) + return self._new_wrapped_storage(cuda_storage) + + def hpu(self, device=None, non_blocking=False) -> Self: + _warn_typed_storage_removal() + if self.dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + raise RuntimeError("Cannot create HPU storage with quantized dtype") + hpu_storage = self._untyped_storage.hpu(device, non_blocking) + return self._new_wrapped_storage(hpu_storage) + + def to(self, *, device: DeviceLikeType, non_blocking: bool = False) -> Self: + _warn_typed_storage_removal() + if not isinstance(device, torch.device): + device = torch.device(device) + if self.dtype in [ + torch.quint8, + torch.quint4x2, + torch.quint2x4, + torch.qint32, + torch.qint8, + ]: + raise RuntimeError( + f"Cannot create {device.type.upper()} storage with quantized dtype" + ) + to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking) + return self._new_wrapped_storage(to_storage) + + def element_size(self): + _warn_typed_storage_removal() + return self._element_size() + + # For internal use only, to avoid deprecation warning + def _element_size(self): + return torch._utils._element_size(self.dtype) + + def get_device(self) -> _int: + _warn_typed_storage_removal() + return self._untyped_storage.get_device() + + def __str__(self): + _warn_typed_storage_removal() + info_str = ( + f"[{torch.typename(self)}(dtype={self.dtype}, " + f"device={self.device}) of size {len(self)}]" + ) + if self.device.type == "meta": + return "...\n" + info_str + else: + data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) + return data_str + "\n" + info_str + + def __repr__(self): + _warn_typed_storage_removal() + return str(self) + + def __iter__(self): + _warn_typed_storage_removal() + return iter(self[i] for i in range(self.size())) + + def __copy__(self): + _warn_typed_storage_removal() + return self._new_wrapped_storage(copy.copy(self._untyped_storage)) + + def __deepcopy__(self, memo): + _warn_typed_storage_removal() + return self._deepcopy(memo) + + # For internal use only, to avoid deprecation warning + def _deepcopy(self, memo): + return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) + + def __sizeof__(self): + _warn_typed_storage_removal() + return super().__sizeof__() + self.nbytes() + + def clone(self): + """Return a copy of this storage.""" + _warn_typed_storage_removal() + return self._new_wrapped_storage(self._untyped_storage.clone()) + + def tolist(self): + """Return a list containing the elements of this storage.""" + _warn_typed_storage_removal() + return list(self) + + def cpu(self): + """Return a CPU copy of this storage if it's not already on the CPU.""" + _warn_typed_storage_removal() + return self._new_wrapped_storage(self._untyped_storage.cpu()) + + def is_pinned(self, device: Union[str, torch.device] = "cuda"): + r"""Determine whether the CPU TypedStorage is already pinned on device. + + Args: + device (str or torch.device): The device to pin memory on (default: ``'cuda'``). 
+ This argument is discouraged and subject to deprecated. + + Returns: + A boolean variable. + """ + _warn_typed_storage_removal() + return self._untyped_storage.is_pinned(device) + + def pin_memory(self, device: Union[str, torch.device] = "cuda"): + r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. + + Args: + device (str or torch.device): The device to pin memory on (default: ``'cuda'``). + This argument is discouraged and subject to deprecated. + + Returns: + A pinned CPU storage. + """ + _warn_typed_storage_removal() + return self._new_wrapped_storage( + self._untyped_storage.pin_memory(device=device) + ) + + def share_memory_(self): + """See :meth:`torch.UntypedStorage.share_memory_`""" + _warn_typed_storage_removal() + return self._share_memory_() + + # For internal use only, to avoid deprecation warning + def _share_memory_(self): + self._untyped_storage.share_memory_() + return self + + def _new_shared(self, size, *, device=None): + """Create a new storage in shared memory with the same data type.""" + if device is None: + device = "cpu" + device = torch.device(device) + untyped_storage = torch.UntypedStorage._new_shared( + size * self._element_size(), device=device + ) + return TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ) + + @property + def _cdata(self): + return self._untyped_storage._cdata + + @property + def device(self): + _warn_typed_storage_removal() + return self._untyped_storage.device + + def size(self): + _warn_typed_storage_removal() + return self._size() + + # For internal use only, to avoid deprecation warning + def _size(self): + # NB: don't indirect through __len__, as that requires + # an int to be returned + return self._untyped_storage.nbytes() // self._element_size() + + def pickle_storage_type(self): + _warn_typed_storage_removal() + return self._pickle_storage_type() + + # For internal use only, to avoid deprecation warning + def _pickle_storage_type(self): + try: + return _dtype_to_storage_type_map()[self.dtype] + except KeyError as e: + raise KeyError(f"dtype {self.dtype} is not recognized") from e + + def __reduce__(self): + b = io.BytesIO() + torch.save(self, b, _use_new_zipfile_serialization=False) + return (_load_from_bytes, (b.getvalue(),)) + + def data_ptr(self): + _warn_typed_storage_removal() + return self._data_ptr() + + # For internal use only, to avoid deprecation warning + def _data_ptr(self): + return self._untyped_storage.data_ptr() + + def resizable(self): + _warn_typed_storage_removal() + return self._untyped_storage.resizable() + + def resize_(self, size): + _warn_typed_storage_removal() + self._resize_(size) + + # For internal use only, to avoid deprecation warning + def _resize_(self, size): + self._untyped_storage.resize_(size * self._element_size()) + + @classmethod + def _free_weak_ref(cls, *args, **kwargs): + return UntypedStorage._free_weak_ref(*args, **kwargs) + + def _weak_ref(self, *args, **kwargs): + return self._untyped_storage._weak_ref(*args, **kwargs) + + @classmethod + def from_buffer(cls, *args, **kwargs): + _warn_typed_storage_removal() + return cls._from_buffer(*args, **kwargs) + + @classmethod + def _from_buffer(cls, *args, dtype=None, device=None, **kwargs): + if cls == TypedStorage: + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu" if device is None else device) + if device.type != "cpu": + raise RuntimeError( + f"TypedStorage.from_buffer: Not available for device {device.type}" + ) + untyped_storage: 
torch.UntypedStorage = torch.UntypedStorage.from_buffer( + *args, dtype=dtype, **kwargs + ) + + else: + if dtype is not None or len(args) == 5: + raise RuntimeError( + "from_buffer: 'dtype' can only be specified in " + "UntypedStorage.from_buffer and TypedStorage.from_buffer" + ) + if device is not None: + raise RuntimeError( + "from_buffer: 'device' can only be specified in " + "UntypedStorage.from_buffer and TypedStorage.from_buffer" + ) + + dtype = cls._dtype + untyped_storage = torch.UntypedStorage.from_buffer( + *args, dtype=dtype, **kwargs + ) + + return TypedStorage(wrap_storage=untyped_storage, dtype=dtype, _internal=True) + + def _to(self, dtype): + if not isinstance(dtype, torch.dtype): + raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") + storage = ( + torch.tensor([], dtype=self.dtype, device=self.device) + .set_(self) + .to(dtype) + ._typed_storage() + ) + if storage.data_ptr() == self.data_ptr(): + storage = storage.clone() + return storage + + def double(self): + """Casts this storage to double type.""" + _warn_typed_storage_removal() + return self._to(torch.double) + + def float(self): + """Casts this storage to float type.""" + _warn_typed_storage_removal() + return self._to(torch.float) + + def half(self): + """Casts this storage to half type.""" + _warn_typed_storage_removal() + return self._to(torch.half) + + def long(self): + """Casts this storage to long type.""" + _warn_typed_storage_removal() + return self._to(torch.long) + + def int(self): + """Casts this storage to int type.""" + _warn_typed_storage_removal() + return self._to(torch.int) + + def short(self): + """Casts this storage to short type.""" + _warn_typed_storage_removal() + return self._to(torch.short) + + def char(self): + """Casts this storage to char type.""" + _warn_typed_storage_removal() + return self._to(torch.int8) + + def byte(self): + """Casts this storage to byte type.""" + _warn_typed_storage_removal() + return self._to(torch.uint8) + + def bool(self): + """Casts this storage to bool type.""" + _warn_typed_storage_removal() + return self._to(torch.bool) + + def bfloat16(self): + """Casts this storage to bfloat16 type.""" + _warn_typed_storage_removal() + return self._to(torch.bfloat16) + + def complex_double(self): + """Casts this storage to complex double type.""" + _warn_typed_storage_removal() + return self._to(torch.cdouble) + + def complex_float(self): + """Casts this storage to complex float type.""" + _warn_typed_storage_removal() + return self._to(torch.cfloat) + + def float8_e5m2(self): + """Casts this storage to float8_e5m2 type""" + _warn_typed_storage_removal() + return self._to(torch.float8_e5m2) + + def float8_e4m3fn(self): + """Casts this storage to float8_e4m3fn type""" + _warn_typed_storage_removal() + return self._to(torch.float8_e4m3fn) + + def float8_e5m2fnuz(self): + """Casts this storage to float8_e5m2fnuz type""" + _warn_typed_storage_removal() + return self._to(torch.float8_e5m2fnuz) + + def float8_e4m3fnuz(self): + """Casts this storage to float8_e4m3fnuz type""" + _warn_typed_storage_removal() + return self._to(torch.float8_e4m3fnuz) + + @classmethod + def from_file(cls, filename, shared, size): + """from_file(filename, shared=False, size=0) -> Storage + + Creates a CPU storage backed by a memory-mapped file. + + If ``shared`` is ``True``, then memory is shared between all processes. + All changes are written to the file. If ``shared`` is ``False``, then the changes on + the storage do not affect the file. 
+ + ``size`` is the number of elements in the storage. If ``shared`` is ``False``, + then the file must contain at least ``size * sizeof(Type)`` bytes + (``Type`` is the type of storage). If ``shared`` is ``True`` the file will be created if needed. + + Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the storage + """ + _warn_typed_storage_removal() + if cls == TypedStorage: + raise RuntimeError("from_file can only be called on derived classes") + untyped_storage = UntypedStorage.from_file( + filename, shared, size * torch._utils._element_size(cls.dtype) + ) + storage = cls(wrap_storage=untyped_storage) + return storage + + @classmethod + def _expired(cls, *args, **kwargs): + return UntypedStorage._expired(*args, **kwargs) + + def _write_file(self, *args, **kwargs): + return self._untyped_storage._write_file(*args, **kwargs) + + def _set_from_file(self, *args, **kwargs): + return self._untyped_storage._set_from_file(*args, **kwargs) + + def _set_cdata(self, *args, **kwargs): + return self._untyped_storage._set_cdata(*args, **kwargs) + + def _share_cuda_(self, *args, **kwargs): + return self._untyped_storage._share_cuda_(*args, **kwargs) + + def is_shared(self): + _warn_typed_storage_removal() + return self._is_shared() + + # For internal use only, to avoid deprecation warning + def _is_shared(self): + return self._untyped_storage.is_shared() + + @classmethod + def _new_shared_cuda(cls, *args, **kwargs): + return torch.UntypedStorage._new_shared_cuda(*args, **kwargs) + + def _share_filename_cpu_(self, *args, **kwargs): + ( + manager_handle, + storage_handle, + size, + ) = self._untyped_storage._share_filename_cpu_(*args, **kwargs) + return manager_handle, storage_handle, size // self._element_size() + + def _shared_decref(self): + self._untyped_storage._shared_decref() + return self + + @classmethod + def _release_ipc_counter(cls, *args, device=None, **kwargs): + return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) + + def _shared_incref(self, *args, **kwargs): + return self._untyped_storage._shared_incref(*args, **kwargs) + + def _share_fd_cpu_(self, *args, **kwargs): + fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs) + return fd, size // self._element_size() + + def _get_legacy_storage_class(self): + if self.dtype not in _dtype_to_storage_type_map(): + return None + + storage_name = _dtype_to_storage_type_map()[self.dtype] + + if self.device.type not in [ + "cpu", + "cuda", + "hpu", + torch._C._get_privateuse1_backend_name(), + ]: + return None + + module = ( + torch if self.device.type == "cpu" else getattr(torch, self.device.type) + ) + + try: + return getattr(module, storage_name) + except AttributeError: + return None + + +TypedStorage.type.__doc__ = _type.__doc__ +TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__ +TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__ +TypedStorage.to.__doc__ = _to.__doc__ + + +class _LegacyStorageMeta(type): + dtype: torch.dtype + + def __instancecheck__(cls, instance): + if type(instance) == TypedStorage: + cls_device = _get_device_from_module(cls.__module__) + return (cls_device == instance.device.type) and ( + cls.dtype == instance.dtype + ) + return False + + +class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta): + @classmethod + def _new_shared(cls, size): # type: ignore[override] + """Create a new storage in shared memory with the same 
data type.""" + untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size()) + return cls(wrap_storage=untyped_storage) + + @classmethod + def _release_ipc_counter(cls, *args, **kwargs): + return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) + + @classmethod + def _new_shared_filename(cls, manager, obj, size): + bytes_size = size * torch._utils._element_size(cls.dtype) + return cls( + wrap_storage=torch.UntypedStorage._new_shared_filename_cpu( + manager, obj, bytes_size + ) + ) + + +def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): + try: + return _storage_type_to_dtype_map()[pickle_storage_type] + except KeyError as e: + raise KeyError( + f'pickle storage type "{pickle_storage_type}" is not recognized' + ) from e diff --git a/phivenv/Lib/site-packages/torch/torch_version.py b/phivenv/Lib/site-packages/torch/torch_version.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfec5cc5855b21af23de61c272bffeac3af6b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/torch_version.py @@ -0,0 +1,66 @@ +from collections.abc import Iterable +from typing import Any + +from torch._vendor.packaging.version import InvalidVersion, Version +from torch.version import __version__ as internal_version + + +__all__ = ["TorchVersion"] + + +class TorchVersion(str): + """A string with magic powers to compare to both Version and iterables! + Prior to 1.10.0 torch.__version__ was stored as a str and so many did + comparisons against torch.__version__ as if it were a str. In order to not + break them we have TorchVersion which masquerades as a str while also + having the ability to compare against both packaging.version.Version as + well as tuples of values, eg. (1, 2, 1) + Examples: + Comparing a TorchVersion object to a Version object + TorchVersion('1.10.0a') > Version('1.10.0a') + Comparing a TorchVersion object to a Tuple object + TorchVersion('1.10.0a') > (1, 2) # 1.2 + TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1 + Comparing a TorchVersion object against a string + TorchVersion('1.10.0a') > '1.2' + TorchVersion('1.10.0a') > '1.2.1' + """ + + __slots__ = () + + # fully qualified type names here to appease mypy + def _convert_to_version(self, inp: Any) -> Any: + if isinstance(inp, Version): + return inp + elif isinstance(inp, str): + return Version(inp) + elif isinstance(inp, Iterable): + # Ideally this should work for most cases by attempting to group + # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH) + # Examples: + # * (1) -> Version("1") + # * (1, 20) -> Version("1.20") + # * (1, 20, 1) -> Version("1.20.1") + return Version(".".join(str(item) for item in inp)) + else: + raise InvalidVersion(inp) + + def _cmp_wrapper(self, cmp: Any, method: str) -> bool: + try: + return getattr(Version(self), method)(self._convert_to_version(cmp)) + except BaseException as e: + if not isinstance(e, InvalidVersion): + raise + # Fall back to regular string comparison if dealing with an invalid + # version like 'parrot' + return getattr(super(), method)(cmp) + + +for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]: + setattr( + TorchVersion, + cmp_method, + lambda x, y, method=cmp_method: x._cmp_wrapper(y, method), + ) + +__version__ = TorchVersion(internal_version) diff --git a/phivenv/Lib/site-packages/torch/types.py b/phivenv/Lib/site-packages/torch/types.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7df1bd827622f23d3091b946c1db7481d96daf --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/types.py @@ -0,0 +1,130 @@ +# In some cases, these basic types are shadowed by corresponding +# top-level values. The underscore variants let us refer to these +# types. See https://github.com/python/mypy/issues/4146 for why these +# workarounds is necessary +import os +from builtins import ( # noqa: F401 + bool as _bool, + bytes as _bytes, + complex as _complex, + float as _float, + int as _int, + str as _str, +) +from collections.abc import Sequence +from typing import Any, IO, TYPE_CHECKING, Union +from typing_extensions import Self, TypeAlias + +# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType` +from torch import ( # noqa: F401 + device as _device, + DispatchKey as DispatchKey, + dtype as _dtype, + layout as _layout, + qscheme as _qscheme, + Size as Size, + SymBool as SymBool, + SymFloat as SymFloat, + SymInt as SymInt, + Tensor as Tensor, +) + + +if TYPE_CHECKING: + from torch.autograd.graph import GradientEdge + + +__all__ = ["Number", "Device", "FileLike", "Storage"] + +# Convenience aliases for common composite types that we need +# to talk about in PyTorch +_TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 +_TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047 + Tensor, + Sequence[Tensor], + "GradientEdge", + Sequence["GradientEdge"], +] + +_size: TypeAlias = Union[Size, list[int], tuple[int, ...]] # noqa: PYI042,PYI047 +_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047 +_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047 + +# int or SymInt +IntLikeType: TypeAlias = Union[int, SymInt] +# float or SymFloat +FloatLikeType: TypeAlias = Union[float, SymFloat] +# bool or SymBool +BoolLikeType: TypeAlias = Union[bool, SymBool] + +py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally +PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] + +# Meta-type for "numeric" things; matches our docs +Number: TypeAlias = Union[int, float, bool] +# tuple for isinstance(x, Number) checks. +# FIXME: refactor once python 3.9 support is dropped. +_Number = (int, float, bool) + +FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]] + +# Meta-type for "device-like" things. Not to be confused with 'device' (a +# literal device object). This nomenclature is consistent with PythonArgParser. 
+# None means use the default device (typically CPU) +Device: TypeAlias = Union[_device, str, int, None] + + +# Storage protocol implemented by ${Type}StorageBase classes +class Storage: + _cdata: int + device: _device + dtype: _dtype + _torch_load_uninitialized: bool + + def __deepcopy__(self, memo: dict[int, Any]) -> Self: + raise NotImplementedError + + def _new_shared(self, size: int) -> Self: + raise NotImplementedError + + def _write_file( + self, + f: Any, + is_real_file: bool, + save_size: bool, + element_size: int, + ) -> None: + raise NotImplementedError + + def element_size(self) -> int: + raise NotImplementedError + + def is_shared(self) -> bool: + raise NotImplementedError + + def share_memory_(self) -> Self: + raise NotImplementedError + + def nbytes(self) -> int: + raise NotImplementedError + + def cpu(self) -> Self: + raise NotImplementedError + + def data_ptr(self) -> int: + raise NotImplementedError + + def from_file( + self, + filename: str, + shared: bool = False, + nbytes: int = 0, + ) -> Self: + raise NotImplementedError + + def _new_with_file( + self, + f: Any, + element_size: int, + ) -> Self: + raise NotImplementedError diff --git a/phivenv/Lib/site-packages/torch/version.py b/phivenv/Lib/site-packages/torch/version.py new file mode 100644 index 0000000000000000000000000000000000000000..1f52ae86621094acc78a8b5cd52f8d49e41c4b33 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/version.py @@ -0,0 +1,9 @@ +from typing import Optional + +__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu'] +__version__ = '2.8.0+cpu' +debug = False +cuda: Optional[str] = None +git_version = 'a1cb3cc05d46d198467bebbb6e8fba50a325d4e7' +hip: Optional[str] = None +xpu: Optional[str] = None